@@ -545,20 +545,22 @@ end
545
545
546
546
# kernel state arguments
547
547
#
548
- # add a state argument every function in the module, and lower calls to the
549
- # `julia.gpu.state_getter` intrinsics to use this newly-introduced state argument.
550
- #
551
- # the type of the state is determined by the `kernel_state_type` interface, and is passed
552
- # as a byval pointer so that (1) the intrinsic can use an opaque pointer for users to
553
- # cast to an appropriate type, while (2) ensuring the state resides in thread-local memory
554
- # so that it can be used without synchronizing global-memory accesses.
548
+ # to facilitate passing stateful information to kernels without having to recompile, e.g.,
549
+ # the storage location for exception flags, or the location of a I/O buffer, we enable the
550
+ # back-end to specify a Julia object that will be passed to the kernel by-value, and to
551
+ # every called function by-reference. Access to this object is done using the
552
+ # `julia.gpu.state_getter` intrinsic, which returns an opaque pointer to the state object.
553
+ # after optimization, these intrinsics will be lowered to refer to the state argument.
554
+
555
+ # add a state argument to every function in the module, starting from the kernel entry point
555
556
function add_kernel_state! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
556
557
entry:: LLVM.Function )
557
558
ctx = context (mod)
558
559
entry_fn = LLVM. name (entry)
559
560
560
561
# check if we even need a kernel state argument
561
562
state = kernel_state_type (job)
563
+ @assert job. source. kernel
562
564
if state === Nothing
563
565
return false
564
566
end
@@ -569,12 +571,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
569
571
# this is both for extern uses, and to make this transformation a two-step process.
570
572
T_int8 = LLVM. IntType (8 ; ctx)
571
573
T_pint8 = LLVM. PointerType (T_int8)
572
- state_intr = if haskey (functions (mod), " julia.gpu.state_getter" )
573
- functions (mod)[" julia.gpu.state_getter" ]
574
- else
575
- LLVM. Function (mod, " julia.gpu.state_getter" , LLVM. FunctionType (T_pint8))
576
- end
577
- push! (function_attributes (state_intr), EnumAttribute (" readnone" , 0 ; ctx))
574
+ state_intr = kernel_state_intr (mod)
578
575
579
576
# add a state argument to every function
580
577
worklist = filter (! isdeclaration, collect (functions (mod)))
@@ -659,10 +656,10 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
659
656
660
657
# forward the state argument
661
658
position! (builder, val)
662
- state = call! (builder, state_intr, Value[], " state" )
663
- state = bitcast! (builder, state , T_ptr_state)
659
+ untyped_state = call! (builder, state_intr, Value[], " state" )
660
+ typed_state = bitcast! (builder, untyped_state , T_ptr_state)
664
661
new_val = if val isa LLVM. CallInst
665
- call! (builder, new_f, [state , operands (val)[1 : end - 1 ]. .. ])
662
+ call! (builder, new_f, [typed_state , operands (val)[1 : end - 1 ]. .. ])
666
663
else
667
664
# TODO : invoke and callbr
668
665
error (" Rewrite of $(typeof (val)) -based calls is not implemented: $val " )
@@ -698,27 +695,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
698
695
unsafe_delete! (mod, f)
699
696
end
700
697
701
- # fixup all uses of the state getter to use the newly introduced function state argument
702
- Builder (ctx) do builder
703
- for use in uses (state_intr)
704
- inst = user (use)
705
- @assert inst isa LLVM. CallInst
706
-
707
- position! (builder, inst)
708
- bb = LLVM. parent (inst)
709
- f = LLVM. parent (bb)
710
-
711
- state = parameters (f)[1 ]
712
- state = bitcast! (builder, state, T_pint8)
713
- replace_uses! (inst, state)
714
-
715
- @assert isempty (uses (inst))
716
- unsafe_delete! (LLVM. parent (inst), inst)
717
- end
718
- end
719
-
720
- # HACK: add a dummy use of the kernel state pointer to ensure it is always available
721
- # also see `kernel_state_argument` below.
698
+ # HACK: add a dummy use of the kernel state pointer to ensure it survives optimization
722
699
dummy_user = if haskey (functions (mod), " julia.gpu.state_user" )
723
700
functions (mod)[" julia.gpu.state_user" ]
724
701
else
@@ -731,54 +708,112 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
731
708
call! (builder, dummy_user, [parameters (entry)[1 ]])
732
709
end
733
710
734
- # clean-up
735
- @assert isempty (uses (state_intr))
736
- unsafe_delete! (mod, state_intr)
737
-
738
- # don't pass the state when unnecessary
739
- # XXX : isn't this done during optimization as well?
740
- ModulePassManager () do pm
741
- dead_arg_elimination! (pm)
742
- run! (pm, mod)
743
- end
744
-
745
711
return true
746
712
end
747
713
748
- # return a value pointing to the state argument in a given function.
749
- function kernel_state_argument (f:: LLVM.Function , state:: Type )
750
- ctx = context (f)
751
- mod = LLVM. parent (f)
714
+ # lower calls to the state getter intrinsic. this is a two-step process, so that the state
715
+ # argument can be added before optimization, and that optimization can introduce new uses
716
+ # before the intrinsic getting lowered late during optimization.
717
+ #
718
+ # the reason we want to add the state argument before optimization, is that the initial
719
+ # argument is marked byval, but some backends need to eagerly lower that byval property
720
+ # (because the LLVM back-end doesn't support emitting code for it). That lowering typically
721
+ # generates a lot of expensive code, so _needs_ to be optimized.
722
+ function lower_kernel_state! (fun:: LLVM.Function )
723
+ job = current_job:: CompilerJob
724
+ mod = LLVM. parent (fun)
725
+ ctx = context (fun)
726
+ changed = false
727
+
728
+ # check if we even need a kernel state argument
729
+ if ! job. source. kernel
730
+ # only kernels have had a kernel state argument added
731
+ return false
732
+ end
733
+ state = kernel_state_type (job)
734
+ if state === Nothing
735
+ return false
736
+ end
752
737
738
+ # find the kernel state argument. normally, this is the first argument of the function.
739
+ state_arg = nothing
753
740
T_state = convert (LLVMType, state; ctx)
754
741
T_ptr_state = LLVM. PointerType (T_state)
755
-
756
- arg = parameters (f)[1 ]
757
- if llvmtype (arg) == T_ptr_state
758
- return arg
742
+ first_arg = parameters (fun)[1 ]
743
+ if llvmtype (first_arg) == T_ptr_state
744
+ state_arg = first_arg
759
745
end
760
746
761
- # if the first argument isn't a valid kernel state pointer, this probably means we're
762
- # in a kernel function whose byval-annotated kernel state argument got lowered eagerly.
763
- # to make sure we can still get a pointer to the kernel state, we've emitted a dummy
764
- # use, which we can use here to get a pointer to the kernel state.
747
+ # with kernels, the story is more complicated: the kernel state argument is marked byval,
748
+ # and it's possible we eagerly lowered that pointer to a value. to retrieve the state,
749
+ # look for the alloca slot the argument was stored in via the dummy use we introduced.
765
750
#
766
751
# this is obviously a hack, stemming from the fact that while lowering Julia intrinsics
767
752
# (which needs to happen _after_ optimization) we may have to emit calls to the GPU
768
753
# runtime while those functions may already have had their kernel state arguments added
769
754
# (which we do _before_ optimization to make sure that any lowered byval performs well).
770
- @assert llvmtype (arg) == T_state
771
- dummy_user = functions (mod)[" julia.gpu.state_user" ]
772
- for use in uses (dummy_user)
773
- call = user (use)
774
- bb = LLVM. parent (call)
775
- if LLVM. parent (bb) == f
776
- arg = operands (call)[1 ]
777
- return arg
755
+ if state_arg === nothing
756
+ @assert llvmtype (first_arg) == T_state
757
+ dummy_user = functions (mod)[" julia.gpu.state_user" ]
758
+ for use in uses (dummy_user)
759
+ call = user (use)
760
+ bb = LLVM. parent (call)
761
+ if LLVM. parent (bb) == fun
762
+ state_arg = operands (call)[1 ]
763
+ break
764
+ end
765
+ end
766
+ end
767
+
768
+ if state_arg === nothing
769
+ error (" Internal compiler error: could not reconstruct kernel state argument" )
770
+ end
771
+
772
+ # get the intrinsic returning an opaque pointer to the kernel state.
773
+ T_int8 = LLVM. IntType (8 ; ctx)
774
+ T_pint8 = LLVM. PointerType (T_int8)
775
+ state_intr = kernel_state_intr (mod)
776
+
777
+ # fixup all uses of the state getter to use the newly introduced function state argument
778
+ Builder (ctx) do builder
779
+ for use in uses (state_intr)
780
+ inst = user (use)
781
+ @assert inst isa LLVM. CallInst
782
+
783
+ position! (builder, inst)
784
+ bb = LLVM. parent (inst)
785
+ f = LLVM. parent (bb)
786
+
787
+ untyped_state = bitcast! (builder, state_arg, T_pint8)
788
+ replace_uses! (inst, untyped_state)
789
+
790
+ @assert isempty (uses (inst))
791
+ unsafe_delete! (LLVM. parent (inst), inst)
792
+
793
+ changed = true
778
794
end
779
795
end
780
796
781
- error (" Internal compiler error: could not reconstruct kernel state argument" )
797
+ # clean-up
798
+ @assert isempty (uses (state_intr))
799
+ unsafe_delete! (mod, state_intr)
800
+
801
+ return changed
802
+ end
803
+
804
+ function kernel_state_intr (mod:: LLVM.Module )
805
+ ctx = context (mod)
806
+ T_int8 = LLVM. IntType (8 ; ctx)
807
+ T_pint8 = LLVM. PointerType (T_int8)
808
+
809
+ state_intr = if haskey (functions (mod), " julia.gpu.state_getter" )
810
+ functions (mod)[" julia.gpu.state_getter" ]
811
+ else
812
+ LLVM. Function (mod, " julia.gpu.state_getter" , LLVM. FunctionType (T_pint8))
813
+ end
814
+ push! (function_attributes (state_intr), EnumAttribute (" readnone" , 0 ; ctx))
815
+
816
+ return state_intr
782
817
end
783
818
784
819
# run-time equivalent (untyped)
0 commit comments