@@ -382,114 +382,78 @@ end
382
382
# some back-ends don't support byval, or support it badly
383
383
# https://reviews.llvm.org/D79744
384
384
385
- # generate a kernel wrapper to fix & improve argument passing
386
- function lower_byval (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module , entry_f :: LLVM.Function )
385
+ # modify the kernel function to fix & improve argument passing
386
+ function lower_byval (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module , f :: LLVM.Function )
387
387
ctx = context (mod)
388
- entry_ft = eltype (llvmtype (entry_f):: LLVM.PointerType ):: LLVM.FunctionType
389
- @compiler_assert return_type (entry_ft) == LLVM. VoidType (ctx) job
390
-
391
- args = classify_arguments (job, entry_f)
392
- filter! (args) do arg
393
- arg. cc != GHOST
388
+ ft = eltype (llvmtype (f):: LLVM.PointerType ):: LLVM.FunctionType
389
+ @compiler_assert return_type (ft) == LLVM. VoidType (ctx) job
390
+
391
+ # find the byval parameters
392
+ byval = BitVector (undef, length (parameters (ft)))
393
+ for i in 1 : length (byval)
394
+ attrs = collect (parameter_attributes (f, i))
395
+ byval[i] = any (attrs) do attr
396
+ kind (attr) == kind (EnumAttribute (" byval" , 0 ; ctx))
397
+ end
394
398
end
395
399
396
- # generate the wrapper function type & definition
397
- wrapper_types = LLVM. LLVMType[]
398
- for arg in args
399
- typ = if arg . cc == BITS_REF
400
- eltype (arg . codegen . typ )
400
+ # generate the new function type & definition
401
+ new_types = LLVM. LLVMType[]
402
+ for (i, param) in enumerate ( parameters (ft))
403
+ if byval[i]
404
+ push! (new_types, eltype (param :: LLVM.PointerType ) )
401
405
else
402
- convert (LLVMType, arg . typ; ctx )
406
+ push! (new_types, param )
403
407
end
404
- push! (wrapper_types, typ)
405
408
end
406
- wrapper_fn = LLVM. name (entry_f)
407
- LLVM. name! (entry_f, wrapper_fn * " .inner" )
408
- wrapper_ft = LLVM. FunctionType (LLVM. VoidType (ctx), wrapper_types)
409
- wrapper_f = LLVM. Function (mod, wrapper_fn, wrapper_ft)
409
+ new_ft = LLVM. FunctionType (return_type (ft), new_types)
410
+ new_f = LLVM. Function (mod, " " , new_ft)
411
+ linkage! (new_f, linkage (f))
410
412
411
413
# emit IR performing the "conversions"
412
- let builder = Builder (ctx)
413
- entry = BasicBlock (wrapper_f, " entry" ; ctx)
414
+ new_args = LLVM. Value[]
415
+ Builder (ctx) do builder
416
+ entry = BasicBlock (new_f, " entry" ; ctx)
414
417
position! (builder, entry)
415
418
416
- wrapper_args = Vector {LLVM.Value} ()
417
-
418
419
# perform argument conversions
419
- for arg in args
420
- if arg . cc == BITS_REF
420
+ for (i, param) in enumerate ( parameters (ft))
421
+ if byval[i]
421
422
# copy the argument value to a stack slot, and reference it.
422
- ptr = alloca! (builder, eltype (arg . codegen . typ ))
423
- if LLVM. addrspace (arg . codegen . typ ) != 0
424
- ptr = addrspacecast! (builder, ptr, arg . codegen . typ )
423
+ ptr = alloca! (builder, eltype (param ))
424
+ if LLVM. addrspace (param ) != 0
425
+ ptr = addrspacecast! (builder, ptr, param )
425
426
end
426
- store! (builder, parameters (wrapper_f)[arg . codegen . i], ptr)
427
- push! (wrapper_args , ptr)
427
+ store! (builder, parameters (new_f)[ i], ptr)
428
+ push! (new_args , ptr)
428
429
else
429
- push! (wrapper_args , parameters (wrapper_f)[arg . codegen . i])
430
- for attr in collect (parameter_attributes (entry_f, arg . codegen . i))
431
- push! (parameter_attributes (wrapper_f, arg . codegen . i), attr)
430
+ push! (new_args , parameters (new_f)[ i])
431
+ for attr in collect (parameter_attributes (f, i))
432
+ push! (parameter_attributes (new_f, i), attr)
432
433
end
433
434
end
434
435
end
435
436
436
- call! (builder, entry_f, wrapper_args)
437
-
438
- ret! (builder)
439
-
440
- dispose (builder)
441
- end
442
-
443
- # early-inline the original entry function into the wrapper
444
- push! (function_attributes (entry_f), EnumAttribute (" alwaysinline" , 0 ; ctx))
445
- linkage! (entry_f, LLVM. API. LLVMInternalLinkage)
437
+ # inline the old IR
438
+ value_map = Dict {LLVM.Value, LLVM.Value} (
439
+ param => new_args[i] for (i,param) in enumerate (parameters (f))
440
+ )
441
+ clone_into! (new_f, f; value_map,
442
+ changes= LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges)
443
+ # NOTE: we need global changes because LLVM 12 wants to clone debug metadata
446
444
447
- # copy debug info
448
- sp = LLVM. get_subprogram (entry_f)
449
- if sp != = nothing
450
- LLVM. set_subprogram! (wrapper_f, sp)
445
+ # fall through
446
+ br! (builder, collect (blocks (new_f))[2 ])
451
447
end
452
448
453
- fixup_metadata! (entry_f)
454
- ModulePassManager () do pm
455
- always_inliner! (pm)
456
- run! (pm, mod)
457
- end
449
+ # remove the old function
450
+ # NOTE: if we ever have legitimate uses of the old function, create a shim instead
451
+ fn = LLVM. name (f)
452
+ @assert isempty (uses (f))
453
+ # XXX : there may still be metadata using this function. RAUW updates those,
454
+ # but asserts on a debug build due to the updated function type.
455
+ unsafe_delete! (mod, f)
456
+ LLVM. name! (new_f, fn)
458
457
459
- return wrapper_f
460
- end
461
-
462
- # HACK: get rid of invariant.load and const TBAA metadata on loads from pointer args,
463
- # since storing to a stack slot violates the semantics of those attributes.
464
- # TODO : can we emit a wrapper that doesn't violate Julia's metadata?
465
- function fixup_metadata! (f:: LLVM.Function )
466
- for param in parameters (f)
467
- if isa (llvmtype (param), LLVM. PointerType)
468
- # collect all uses of the pointer
469
- worklist = Vector {LLVM.Instruction} (user .(collect (uses (param))))
470
- while ! isempty (worklist)
471
- value = popfirst! (worklist)
472
-
473
- # remove the invariant.load attribute
474
- md = metadata (value)
475
- if haskey (md, LLVM. MD_invariant_load)
476
- delete! (md, LLVM. MD_invariant_load)
477
- end
478
- if haskey (md, LLVM. MD_tbaa)
479
- delete! (md, LLVM. MD_tbaa)
480
- end
481
-
482
- # recurse on the output of some instructions
483
- if isa (value, LLVM. BitCastInst) ||
484
- isa (value, LLVM. GetElementPtrInst) ||
485
- isa (value, LLVM. AddrSpaceCastInst)
486
- append! (worklist, user .(collect (uses (value))))
487
- end
488
-
489
- # IMPORTANT NOTE: if we ever want to inline functions at the LLVM level,
490
- # we need to recurse into call instructions here, and strip metadata from
491
- # called functions (see CUDAnative.jl#238).
492
- end
493
- end
494
- end
458
+ return new_f
495
459
end
0 commit comments