@@ -553,18 +553,28 @@ end
553
553
# cast to an appropriate type, while (2) ensuring the state resides in thread-local memory
554
554
# so that it can be used without synchronizing global-memory accesses.
555
555
function add_kernel_state! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
556
- entry:: LLVM.Function , T_state :: LLVMType )
556
+ entry:: LLVM.Function )
557
557
ctx = context (mod)
558
+ entry_fn = LLVM. name (entry)
559
+
560
+ # check if we even need a kernel state argument
561
+ state = kernel_state_type (job)
562
+ if state === Nothing
563
+ return false
564
+ end
565
+ T_state = convert (LLVMType, state; ctx)
566
+ T_ptr_state = LLVM. PointerType (T_state)
558
567
559
568
# intrinsic returning an opaque pointer to the kernel state.
560
569
# this is both for extern uses, and to make this transformation a two-step process.
561
- T_ptr_state = LLVM. PointerType (T_state)
562
- state_getter = if haskey (functions (mod), " julia.gpu.state_getter" )
570
+ T_int8 = LLVM. IntType (8 ; ctx)
571
+ T_pint8 = LLVM. PointerType (T_int8)
572
+ state_intr = if haskey (functions (mod), " julia.gpu.state_getter" )
563
573
functions (mod)[" julia.gpu.state_getter" ]
564
574
else
565
- LLVM. Function (mod, " julia.gpu.state_getter" , LLVM. FunctionType (T_ptr_state ))
575
+ LLVM. Function (mod, " julia.gpu.state_getter" , LLVM. FunctionType (T_int8 ))
566
576
end
567
- push! (function_attributes (state_getter ), EnumAttribute (" readnone" , 0 ; ctx))
577
+ push! (function_attributes (state_intr ), EnumAttribute (" readnone" , 0 ; ctx))
568
578
569
579
# add a state argument to every function
570
580
worklist = filter (! isdeclaration, collect (functions (mod)))
@@ -649,7 +659,8 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
649
659
650
660
# forward the state argument
651
661
position! (builder, val)
652
- state = call! (builder, state_getter, Value[], " state" )
662
+ state = call! (builder, state_intr, Value[], " state" )
663
+ state = bitcast! (builder, state, T_ptr_state)
653
664
new_val = if val isa LLVM. CallInst
654
665
call! (builder, new_f, [state, operands (val)[1 : end - 1 ]. .. ])
655
666
else
@@ -688,25 +699,89 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
688
699
end
689
700
690
701
# fixup all uses of the state getter to use the newly introduced function state argument
691
- for use in uses (state_getter)
692
- inst = user (use)
693
- @assert inst isa LLVM. CallInst
702
+ Builder (ctx) do builder
703
+ for use in uses (state_intr)
704
+ inst = user (use)
705
+ @assert inst isa LLVM. CallInst
694
706
695
- bb = LLVM. parent (inst)
696
- f = LLVM. parent (bb)
707
+ position! (builder, inst)
708
+ bb = LLVM. parent (inst)
709
+ f = LLVM. parent (bb)
697
710
698
- replace_uses! (inst, parameters (f)[1 ])
699
- @assert isempty (uses (inst))
700
- unsafe_delete! (LLVM. parent (inst), inst)
711
+ state = parameters (f)[1 ]
712
+ state = bitcast! (builder, state, T_int8)
713
+ replace_uses! (inst, state)
714
+
715
+ @assert isempty (uses (inst))
716
+ unsafe_delete! (LLVM. parent (inst), inst)
717
+ end
718
+ end
719
+
720
+ # HACK: add a dummy use of the kernel state pointer to ensure it is always available
721
+ # also see `kernel_state_argument` below.
722
+ dummy_user = if haskey (functions (mod), " julia.gpu.state_user" )
723
+ functions (mod)[" julia.gpu.state_user" ]
724
+ else
725
+ LLVM. Function (mod, " julia.gpu.state_user" ,
726
+ LLVM. FunctionType (LLVM. VoidType (ctx), [T_ptr_state]))
727
+ end
728
+ entry = functions (mod)[entry_fn]
729
+ Builder (ctx) do builder
730
+ position! (builder, first (instructions (first (blocks (entry)))))
731
+ call! (builder, dummy_user, [parameters (entry)[1 ]])
701
732
end
702
733
703
734
# clean-up
704
- @assert isempty (uses (state_getter))
705
- unsafe_delete! (mod, state_getter)
735
+ @assert isempty (uses (state_intr))
736
+ unsafe_delete! (mod, state_intr)
737
+
738
+ # don't pass the state when unnecessary
739
+ # XXX : isn't this done during optimization as well?
740
+ ModulePassManager () do pm
741
+ dead_arg_elimination! (pm)
742
+ run! (pm, mod)
743
+ end
744
+
745
+ return true
746
+ end
747
+
748
+ # return a value pointing to the state argument in a given function.
749
+ function kernel_state_argument (f:: LLVM.Function , state:: Type )
750
+ ctx = context (f)
751
+ mod = LLVM. parent (f)
752
+
753
+ T_state = convert (LLVMType, state; ctx)
754
+ T_ptr_state = LLVM. PointerType (T_state)
755
+
756
+ arg = parameters (f)[1 ]
757
+ if llvmtype (arg) == T_ptr_state
758
+ return arg
759
+ end
760
+
761
+ # if the first argument isn't a valid kernel state pointer, this probably means we're
762
+ # in a kernel function whose byval-annotated kernel state argument got lowered eagerly.
763
+ # to make sure we can still get a pointer to the kernel state, we've emitted a dummy
764
+ # use, which we can use here to get a pointer to the kernel state.
765
+ #
766
+ # this is obviously a hack, stemming from the fact that while lowering Julia intrinsics
767
+ # (which needs to happen _after_ optimization) we may have to emit calls to the GPU
768
+ # runtime while those functions may already have had their kernel state arguments added
769
+ # (which we do _before_ optimization to make sure that any lowered byval performs well).
770
+ @assert llvmtype (arg) == T_state
771
+ dummy_user = functions (mod)[" julia.gpu.state_user" ]
772
+ for use in uses (dummy_user)
773
+ call = user (use)
774
+ bb = LLVM. parent (call)
775
+ if LLVM. parent (bb) == f
776
+ arg = operands (call)[1 ]
777
+ return arg
778
+ end
779
+ end
706
780
707
- return
781
+ error ( " Internal compiler error: could not reconstruct kernel state argument " )
708
782
end
709
783
784
+ # run-time equivalent (untyped)
710
785
@inline kernel_state_pointer () = Base. llvmcall (("""
711
786
declare i8* @julia.gpu.state_getter()
712
787
716
791
ret i64 %ptr
717
792
}
718
793
719
- attributes #0 = { alwaysinline }""" , " entry" ),
794
+ attributes #0 = { alwaysinline readnone }""" , " entry" ),
720
795
Ptr{Cvoid}, Tuple{})
0 commit comments