@@ -195,7 +195,7 @@ function lower_throw!(mod::LLVM.Module)
195
195
end
196
196
197
197
# remove the call
198
- call_args = operands (call)[ 1 : end - 1 ] # last arg is function itself
198
+ call_args = arguments (call)
199
199
unsafe_delete! (LLVM. parent (call), call)
200
200
201
201
# HACK: kill the exceptions' unused arguments
@@ -406,9 +406,6 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
406
406
byval[arg. codegen. i+ has_kernel_state] = true
407
407
end
408
408
end
409
- if has_kernel_state
410
- byval[1 ] = true
411
- end
412
409
end
413
410
414
411
# fixup metadata
549
546
# the storage location for exception flags, or the location of a I/O buffer, we enable the
550
547
# back-end to specify a Julia object that will be passed to the kernel by-value, and to
551
548
# every called function by-reference. Access to this object is done using the
552
- # `julia.gpu.state_getter` intrinsic, which returns an opaque pointer to the state object.
553
- # after optimization, these intrinsics will be lowered to refer to the state argument.
549
+ # `julia.gpu.state_getter` intrinsic. after optimization, these intrinsics will be lowered
550
+ # to refer to the state argument.
551
+ #
552
+ # note that we deviate from the typical Julia calling convention, by always passing the
553
+ # state objects by value instead of by reference, this to ensure that the state object
554
+ # is not copied to the stack (because LLVM doesn't see that all uses are read-only).
555
+ # in principle, `readonly byval` should be equivalent, but LLVM doesn't realize that.
556
+ # also see https://github.com/JuliaGPU/CUDA.jl/pull/1167 and the comments in that PR.
557
+ # once LLVM supports this pattern, consider going back to passing the state by reference,
558
+ # so that the julia.gpu.state_getter` can be simplified to return an opaque pointer.
554
559
555
560
# add a state argument to every function in the module, starting from the kernel entry point
556
561
function add_kernel_state! (@nospecialize (job:: CompilerJob ), mod:: LLVM.Module ,
@@ -565,13 +570,10 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
565
570
return false
566
571
end
567
572
T_state = convert (LLVMType, state; ctx)
568
- T_ptr_state = LLVM. PointerType (T_state)
569
573
570
574
# intrinsic returning an opaque pointer to the kernel state.
571
575
# this is both for extern uses, and to make this transformation a two-step process.
572
- T_int8 = LLVM. IntType (8 ; ctx)
573
- T_pint8 = LLVM. PointerType (T_int8)
574
- state_intr = kernel_state_intr (mod)
576
+ state_intr = kernel_state_intr (mod, T_state)
575
577
576
578
# add a state argument to every function
577
579
worklist = filter (! isdeclaration, collect (functions (mod)))
@@ -582,7 +584,7 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
582
584
LLVM. name! (f, fn * " .stateless" )
583
585
584
586
# create a new function
585
- new_param_types = [T_ptr_state , parameters (ft)... ]
587
+ new_param_types = [T_state , parameters (ft)... ]
586
588
new_ft = LLVM. FunctionType (return_type (ft), new_param_types)
587
589
new_f = LLVM. Function (mod, fn, new_ft)
588
590
LLVM. name! (parameters (new_f)[1 ], " state" )
@@ -618,16 +620,6 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
618
620
clone_into! (new_f, f; value_map, materializer,
619
621
changes= LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges)
620
622
621
- # pass the state by value to the kernel (after cloning, which overwrites attributes)
622
- if f == entry
623
- attr = if LLVM. version () >= v " 12"
624
- TypeAttribute (" byval" , T_state; ctx)
625
- else
626
- EnumAttribute (" byval" , 0 ; ctx)
627
- end
628
- push! (parameter_attributes (new_f, 1 ), attr)
629
- end
630
-
631
623
# we can't remove this function yet, as we might still need to rewrite any called,
632
624
# but remove the IR already
633
625
empty! (f)
@@ -656,10 +648,9 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
656
648
657
649
# forward the state argument
658
650
position! (builder, val)
659
- untyped_state = call! (builder, state_intr, Value[], " state" )
660
- typed_state = bitcast! (builder, untyped_state, T_ptr_state)
651
+ state = call! (builder, state_intr, Value[], " state" )
661
652
new_val = if val isa LLVM. CallInst
662
- call! (builder, new_f, [typed_state, operands (val)[ 1 : end - 1 ] . .. ])
653
+ call! (builder, new_f, [state, arguments (val)... ], operand_bundles (val) )
663
654
else
664
655
# TODO : invoke and callbr
665
656
error (" Rewrite of $(typeof (val)) -based calls is not implemented: $val " )
@@ -695,30 +686,12 @@ function add_kernel_state!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
695
686
unsafe_delete! (mod, f)
696
687
end
697
688
698
- # HACK: add a dummy use of the kernel state pointer to ensure it survives optimization
699
- dummy_user = if haskey (functions (mod), " julia.gpu.state_user" )
700
- functions (mod)[" julia.gpu.state_user" ]
701
- else
702
- LLVM. Function (mod, " julia.gpu.state_user" ,
703
- LLVM. FunctionType (LLVM. VoidType (ctx), [T_ptr_state]))
704
- end
705
- entry = functions (mod)[entry_fn]
706
- Builder (ctx) do builder
707
- position! (builder, first (instructions (first (blocks (entry)))))
708
- call! (builder, dummy_user, [parameters (entry)[1 ]])
709
- end
710
-
711
689
return true
712
690
end
713
691
714
692
# lower calls to the state getter intrinsic. this is a two-step process, so that the state
715
693
# argument can be added before optimization, and that optimization can introduce new uses
716
694
# before the intrinsic getting lowered late during optimization.
717
- #
718
- # the reason we want to add the state argument before optimization, is that the initial
719
- # argument is marked byval, but some backends need to eagerly lower that byval property
720
- # (because the LLVM back-end doesn't support emitting code for it). That lowering typically
721
- # generates a lot of expensive code, so _needs_ to be optimized.
722
695
function lower_kernel_state! (fun:: LLVM.Function )
723
696
job = current_job:: CompilerJob
724
697
mod = LLVM. parent (fun)
@@ -731,64 +704,33 @@ function lower_kernel_state!(fun::LLVM.Function)
731
704
return false
732
705
end
733
706
734
- # find the kernel state argument. normally, this is the first argument of the function.
735
- state_arg = nothing
707
+ # find the kernel state argument. this should be the first argument of the function.
708
+ state_arg = parameters (fun)[ 1 ]
736
709
T_state = convert (LLVMType, state; ctx)
737
- T_ptr_state = LLVM. PointerType (T_state)
738
- first_arg = parameters (fun)[1 ]
739
- if llvmtype (first_arg) == T_ptr_state
740
- state_arg = first_arg
741
- end
742
-
743
- # with kernels, the story is more complicated: the kernel state argument is marked byval,
744
- # and it's possible we eagerly lowered that pointer to a value. to retrieve the state,
745
- # look for the alloca slot the argument was stored in via the dummy use we introduced.
746
- #
747
- # this is obviously a hack, stemming from the fact that while lowering Julia intrinsics
748
- # (which needs to happen _after_ optimization) we may have to emit calls to the GPU
749
- # runtime while those functions may already have had their kernel state arguments added
750
- # (which we do _before_ optimization to make sure that any lowered byval performs well).
751
- if state_arg === nothing
752
- @assert llvmtype (first_arg) == T_state
753
- dummy_user = functions (mod)[" julia.gpu.state_user" ]
754
- for use in uses (dummy_user)
755
- call = user (use)
756
- bb = LLVM. parent (call)
757
- if LLVM. parent (bb) == fun
758
- state_arg = operands (call)[1 ]
759
- break
760
- end
761
- end
762
- end
763
-
764
- if state_arg === nothing
765
- error (" Internal compiler error: could not reconstruct kernel state argument" )
766
- end
767
-
768
- # get the intrinsic returning an opaque pointer to the kernel state.
769
- T_int8 = LLVM. IntType (8 ; ctx)
770
- T_pint8 = LLVM. PointerType (T_int8)
771
- state_intr = kernel_state_intr (mod)
710
+ @assert llvmtype (state_arg) == T_state
772
711
773
712
# fixup all uses of the state getter to use the newly introduced function state argument
774
- Builder (ctx) do builder
775
- for use in uses (state_intr)
776
- inst = user (use)
777
- @assert inst isa LLVM. CallInst
778
- bb = LLVM. parent (inst)
779
- LLVM. parent (bb) == fun || continue
713
+ if haskey (functions (mod), " julia.gpu.state_getter" )
714
+ state_intr = functions (mod)[" julia.gpu.state_getter" ]
715
+
716
+ Builder (ctx) do builder
717
+ for use in uses (state_intr)
718
+ inst = user (use)
719
+ @assert inst isa LLVM. CallInst
720
+ bb = LLVM. parent (inst)
721
+ LLVM. parent (bb) == fun || continue
780
722
781
- position! (builder, inst)
782
- bb = LLVM. parent (inst)
783
- f = LLVM. parent (bb)
723
+ position! (builder, inst)
724
+ bb = LLVM. parent (inst)
725
+ f = LLVM. parent (bb)
784
726
785
- untyped_state = bitcast! (builder, state_arg, T_pint8)
786
- replace_uses! (inst, untyped_state)
727
+ replace_uses! (inst, state_arg)
787
728
788
- @assert isempty (uses (inst))
789
- unsafe_delete! (LLVM. parent (inst), inst)
729
+ @assert isempty (uses (inst))
730
+ unsafe_delete! (LLVM. parent (inst), inst)
790
731
791
- changed = true
732
+ changed = true
733
+ end
792
734
end
793
735
end
794
736
@@ -810,45 +752,44 @@ function cleanup_kernel_state!(mod::LLVM.Module)
810
752
end
811
753
end
812
754
813
- # remove the kernel state dummy use
814
- if haskey (functions (mod), " julia.gpu.state_user" )
815
- intr = functions (mod)[" julia.gpu.state_user" ]
816
- for use in uses (intr)
817
- call = user (use)
818
- unsafe_delete! (LLVM. parent (call), call)
819
- end
820
- @assert isempty (uses (intr))
821
- unsafe_delete! (mod, intr)
822
- changed = true
823
- end
824
-
825
755
return changed
826
756
end
827
757
828
- function kernel_state_intr (mod:: LLVM.Module )
758
+ function kernel_state_intr (mod:: LLVM.Module , T_state )
829
759
ctx = context (mod)
830
- T_int8 = LLVM. IntType (8 ; ctx)
831
- T_pint8 = LLVM. PointerType (T_int8)
832
760
833
761
state_intr = if haskey (functions (mod), " julia.gpu.state_getter" )
834
762
functions (mod)[" julia.gpu.state_getter" ]
835
763
else
836
- LLVM. Function (mod, " julia.gpu.state_getter" , LLVM. FunctionType (T_pint8 ))
764
+ LLVM. Function (mod, " julia.gpu.state_getter" , LLVM. FunctionType (T_state ))
837
765
end
838
766
push! (function_attributes (state_intr), EnumAttribute (" readnone" , 0 ; ctx))
839
767
840
768
return state_intr
841
769
end
842
770
843
- # run-time equivalent (untyped)
844
- @inline kernel_state_pointer () = Base. llvmcall (("""
845
- declare i8* @julia.gpu.state_getter()
771
+ # run-time equivalent
772
+ function kernel_state_value (state)
773
+ Context () do ctx
774
+ T_state = convert (LLVMType, state; ctx)
775
+
776
+ # create function
777
+ llvm_f, _ = create_function (T_state)
778
+ mod = LLVM. parent (llvm_f)
779
+
780
+ # get intrinsic
781
+ state_intr = kernel_state_intr (mod, T_state)
782
+
783
+ # generate IR
784
+ Builder (ctx) do builder
785
+ entry = BasicBlock (llvm_f, " entry" ; ctx)
786
+ position! (builder, entry)
787
+
788
+ val = call! (builder, state_intr, Value[], " state" )
846
789
847
- define i64 @entry() #0 {
848
- %ptls = call i8* @julia.gpu.state_getter()
849
- %ptr = ptrtoint i8* %ptls to i64
850
- ret i64 %ptr
851
- }
790
+ ret! (builder, val)
791
+ end
852
792
853
- attributes #0 = { alwaysinline readnone }""" , " entry" ),
854
- Ptr{Cvoid}, Tuple{})
793
+ call_function (llvm_f, state)
794
+ end
795
+ end
0 commit comments