@@ -588,8 +588,30 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
588
588
# this is both for extern uses, and to make this transformation a two-step process.
589
589
state_intr = kernel_state_intr (mod, T_state)
590
590
591
- # add a state argument to every function
592
- worklist = filter (! isdeclaration, collect (functions (mod)))
591
+ # determine which functions need a kernel state argument
592
+ #
593
+ # previously, we added the argument to every function and relied on unused arg elim to
594
+ # clean up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
595
+ # function pointers. such IR is hard to rewrite, so instead be more conservative.
596
+ worklist = Set {LLVM.Function} ([entry, state_intr])
597
+ worklist_length = 0
598
+ while worklist_length != length (worklist)
599
+ # iteratively discover functions that use the intrinsic or any function calling it
600
+ worklist_length = length (worklist)
601
+ additions = LLVM. Function[]
602
+ for f in worklist, use in uses (f)
603
+ inst = user (use):: Instruction
604
+ bb = LLVM. parent (inst)
605
+ new_f = LLVM. parent (bb)
606
+ in (new_f, worklist) || push! (additions, new_f)
607
+ end
608
+ for f in additions
609
+ push! (worklist, f)
610
+ end
611
+ end
612
+ delete! (worklist, state_intr)
613
+
614
+ # add a state argument
593
615
workmap = Dict {LLVM.Function, LLVM.Function} ()
594
616
for f in worklist
595
617
fn = LLVM. name (f)
@@ -726,14 +748,10 @@ function lower_kernel_state!(fun::LLVM.Function)
726
748
return false
727
749
end
728
750
729
- # find the kernel state argument. this should be the first argument of the function.
730
- state_arg = parameters (fun)[1 ]
731
- T_state = convert (LLVMType, state; ctx)
732
- @assert llvmtype (state_arg) == T_state
733
-
734
751
# fixup all uses of the state getter to use the newly introduced function state argument
735
752
if haskey (functions (mod), " julia.gpu.state_getter" )
736
753
state_intr = functions (mod)[" julia.gpu.state_getter" ]
754
+ state_arg = nothing # only look-up when needed
737
755
738
756
Builder (ctx) do builder
739
757
for use in uses (state_intr)
@@ -746,6 +764,14 @@ function lower_kernel_state!(fun::LLVM.Function)
746
764
bb = LLVM. parent (inst)
747
765
f = LLVM. parent (bb)
748
766
767
+ if state_arg === nothing
768
+ # find the kernel state argument. this should be the first argument of
769
+ # the function, but only when this function needs the state!
770
+ state_arg = parameters (fun)[1 ]
771
+ T_state = convert (LLVMType, state; ctx)
772
+ @assert llvmtype (state_arg) == T_state
773
+ end
774
+
749
775
replace_uses! (inst, state_arg)
750
776
751
777
@assert isempty (uses (inst))
0 commit comments