@@ -501,30 +501,42 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
501
501
else
502
502
changes = LLVM. API. LLVMCloneFunctionChangeTypeLocalChangesOnly
503
503
end
504
- clone_into! (new_f, f; value_map, changes)
504
+
505
+ # use a value materializer for replacing uses of the function in constants
506
+ # NOTE: we assume kernel functions can't be called. on-device kernel launches,
507
+ # e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
508
+ # and we update those constant expressions arguments here.
509
+ function materializer (val)
510
+ opcodes = (LLVM. API. LLVMPtrToInt, LLVM. API. LLVMAddrSpaceCast, LLVM. API. LLVMBitCast)
511
+ if val isa LLVM. ConstantExpr && opcode (val) in opcodes
512
+ target = operands (val)[1 ]
513
+ if target == f
514
+ return if opcode (val) == LLVM. API. LLVMPtrToInt
515
+ LLVM. const_ptrtoint (new_f, llvmtype (val))
516
+ elseif opcode (val) == LLVM. API. LLVMAddrSpaceCast
517
+ LLVM. const_addrspacecast (new_f, llvmtype (val))
518
+ elseif opcode (val) == LLVM. API. LLVMBitCast
519
+ LLVM. const_bitcast (new_f, llvmtype (val))
520
+ end
521
+ end
522
+ end
523
+ return val
524
+ end
525
+
526
+ # we don't want module-level changes, because otherwise LLVM will clone metadata,
527
+ # resulting in mismatching references between `!dbg` metadata and `dbg` instructions
528
+ clone_into! (new_f, f; value_map, changes, materializer)
505
529
506
530
# fall through
507
531
br! (builder, blocks (new_f)[2 ])
508
532
end
509
533
510
- # update uses of the kernel
511
- # NOTE: we assume kernel functions can't be called. on-device kernel launches,
512
- # e.g. CUDA's dynamic parallelism, will pass the function to an API instead,
513
- # and we update those constant expressions arguments here.
534
+ # drop unused constants that may be referring to the old functions
535
+ # XXX : can we do this differently?
514
536
for use in uses (f)
515
537
val = user (use)
516
- if val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMPtrToInt
517
- target = operands (val)[1 ]
518
- if target == f
519
- new_val = LLVM. const_ptrtoint (new_f, llvmtype (val))
520
- replace_uses! (val, new_val)
521
-
522
- # drop the old constant if it is unused
523
- # XXX : can we do this differently?
524
- if isempty (uses (val))
525
- LLVM. unsafe_destroy! (val)
526
- end
527
- end
538
+ if val isa LLVM. ConstantExpr && isempty (uses (val))
539
+ LLVM. unsafe_destroy! (val)
528
540
end
529
541
end
530
542
@@ -576,8 +588,30 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
576
588
# this is both for extern uses, and to make this transformation a two-step process.
577
589
state_intr = kernel_state_intr (mod, T_state)
578
590
579
- # add a state argument to every function
580
- worklist = filter (! isdeclaration, collect (functions (mod)))
591
+ # determine which functions need a kernel state argument
592
+ #
593
+ # previously, we add the argument to every function and relied on unused arg elim to
594
+ # clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
595
+ # function pointers. such IR is hard to rewrite, so instead be more conservative.
596
+ worklist = Set {LLVM.Function} ([entry, state_intr])
597
+ worklist_length = 0
598
+ while worklist_length != length (worklist)
599
+ # iteratively discover functions that use the intrinsic or any function calling it
600
+ worklist_length = length (worklist)
601
+ additions = LLVM. Function[]
602
+ for f in worklist, use in uses (f)
603
+ inst = user (use):: Instruction
604
+ bb = LLVM. parent (inst)
605
+ new_f = LLVM. parent (bb)
606
+ in (new_f, worklist) || push! (additions, new_f)
607
+ end
608
+ for f in additions
609
+ push! (worklist, f)
610
+ end
611
+ end
612
+ delete! (worklist, state_intr)
613
+
614
+ # add a state argument
581
615
workmap = Dict {LLVM.Function, LLVM.Function} ()
582
616
for f in worklist
583
617
fn = LLVM. name (f)
@@ -608,10 +642,17 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
608
642
609
643
# use a value materializer for replacing uses of the function in constants
610
644
function materializer (val)
611
- if val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMPtrToInt
645
+ opcodes = (LLVM. API. LLVMPtrToInt, LLVM. API. LLVMAddrSpaceCast, LLVM. API. LLVMBitCast)
646
+ if val isa LLVM. ConstantExpr && opcode (val) in opcodes
612
647
src = operands (val)[1 ]
613
648
if haskey (workmap, src)
614
- return LLVM. const_ptrtoint (workmap[src], llvmtype (val))
649
+ return if opcode (val) == LLVM. API. LLVMPtrToInt
650
+ LLVM. const_ptrtoint (workmap[src], llvmtype (val))
651
+ elseif opcode (val) == LLVM. API. LLVMAddrSpaceCast
652
+ LLVM. const_addrspacecast (workmap[src], llvmtype (val))
653
+ elseif opcode (val) == LLVM. API. LLVMBitCast
654
+ LLVM. const_bitcast (workmap[src], llvmtype (val))
655
+ end
615
656
end
616
657
end
617
658
return val
@@ -677,20 +718,6 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
677
718
replace_uses! (val, new_val)
678
719
@assert isempty (uses (val))
679
720
unsafe_delete! (LLVM. parent (val), val)
680
- elseif val isa LLVM. ConstantExpr && opcode (val) == LLVM. API. LLVMBitCast
681
- # XXX : why isn't this caught by the value materializer above?
682
- target = operands (val)[1 ]
683
- @assert target == f
684
- new_val = LLVM. const_bitcast (new_f, llvmtype (val))
685
- rewrite_uses! (val, new_val)
686
- # we can't simply replace this constant expression, as it may be used
687
- # as a call, taking arguments (so we need to rewrite it to pass the state)
688
-
689
- # drop the old constant if it is unused
690
- # XXX : can we do this differently?
691
- if isempty (uses (val))
692
- LLVM. unsafe_destroy! (val)
693
- end
694
721
else
695
722
error (" Cannot rewrite unknown use of function: $val " )
696
723
end
@@ -721,14 +748,10 @@ function lower_kernel_state!(fun::LLVM.Function)
721
748
return false
722
749
end
723
750
724
- # find the kernel state argument. this should be the first argument of the function.
725
- state_arg = parameters (fun)[1 ]
726
- T_state = convert (LLVMType, state; ctx)
727
- @assert llvmtype (state_arg) == T_state
728
-
729
751
# fixup all uses of the state getter to use the newly introduced function state argument
730
752
if haskey (functions (mod), " julia.gpu.state_getter" )
731
753
state_intr = functions (mod)[" julia.gpu.state_getter" ]
754
+ state_arg = nothing # only look-up when needed
732
755
733
756
Builder (ctx) do builder
734
757
for use in uses (state_intr)
@@ -741,6 +764,14 @@ function lower_kernel_state!(fun::LLVM.Function)
741
764
bb = LLVM. parent (inst)
742
765
f = LLVM. parent (bb)
743
766
767
+ if state_arg === nothing
768
+ # find the kernel state argument. this should be the first argument of
769
+ # the function, but only when this function needs the state!
770
+ state_arg = parameters (fun)[1 ]
771
+ T_state = convert (LLVMType, state; ctx)
772
+ @assert llvmtype (state_arg) == T_state
773
+ end
774
+
744
775
replace_uses! (inst, state_arg)
745
776
746
777
@assert isempty (uses (inst))
0 commit comments