Skip to content

Commit e9ad136

Browse files
maleadt authored and simeonschaub committed
Convert the kernel state back to a reference when needed.
We're currently passing the kernel state object by value, disregarding the typical Julia calling convention, because there are known issues with `byval` lowering on NVPTX. For compatibility with back-ends that do not support passing kernel arguments by actual values, provide a pass that's conceptually the inverse of `lower_byval`, instead rewriting the kernel state object to be passed by reference, and loading from it at the beginning of the kernel.
1 parent aa05fa3 commit e9ad136

File tree

4 files changed

+158
-29
lines changed

4 files changed

+158
-29
lines changed

src/interface.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ runtime_slug(@nospecialize(job::CompilerJob)) = error("Not implemented")
267267
kernel_state_type(@nospecialize(job::CompilerJob)) = Nothing
268268

269269
# Does the target need to pass kernel arguments by value?
270-
needs_byval(@nospecialize(job::CompilerJob)) = true
270+
pass_by_value(@nospecialize(job::CompilerJob)) = true
271271

272272
# whether pointer is a valid call target
273273
valid_function_pointer(@nospecialize(job::CompilerJob), ptr::Ptr{Cvoid}) = false

src/irgen.jl

Lines changed: 113 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ function irgen(@nospecialize(job::CompilerJob))
8282

8383
# minimal required optimization
8484
@tracepoint "rewrite" begin
85-
if job.config.kernel && needs_byval(job)
85+
if job.config.kernel && pass_by_value(job)
8686
# pass all bitstypes by value; by default Julia passes aggregates by reference
8787
# (this improves performance, and is mandated by certain back-ends like SPIR-V).
8888
args = classify_arguments(job, function_type(entry))
@@ -256,10 +256,11 @@ end
256256
## kernel promotion
257257

258258
@enum ArgumentCC begin
259-
BITS_VALUE # bitstype, passed as value
260-
BITS_REF # bitstype, passed as pointer
261-
MUT_REF # jl_value_t*, or the anonymous equivalent
262-
GHOST # not passed
259+
BITS_VALUE # bitstype, passed as value
260+
BITS_REF # bitstype, passed as pointer
261+
MUT_REF # jl_value_t*, or the anonymous equivalent
262+
GHOST # not passed
263+
KERNEL_STATE # the kernel state argument
263264
end
264265

265266
# Determine the calling convention of the arguments of a Julia function, given the
@@ -270,7 +271,8 @@ end
270271
# - `name`: the name of the argument
271272
# - `idx`: the index of the argument in the LLVM function type, or `nothing` if the argument
272273
# is not passed at the LLVM level.
273-
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType)
274+
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType;
275+
post_optimization::Bool=false)
274276
source_sig = job.source.specTypes
275277
source_types = [source_sig.parameters...]
276278

@@ -282,9 +284,15 @@ function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.Fu
282284

283285
codegen_types = parameters(codegen_ft)
284286

285-
args = []
286-
codegen_i = 1
287-
for (source_i, (source_typ, source_name)) in enumerate(zip(source_types, source_argnames))
287+
if post_optimization && kernel_state_type(job) !== Nothing
288+
args = []
289+
push!(args, (cc=KERNEL_STATE, typ=kernel_state_type(job), name=:kernel_state, idx=1))
290+
codegen_i = 2
291+
else
292+
args = []
293+
codegen_i = 1
294+
end
295+
for (source_typ, source_name) in zip(source_types, source_argnames)
288296
if isghosttype(source_typ) || Core.Compiler.isconstType(source_typ)
289297
push!(args, (cc=GHOST, typ=source_typ, name=source_name, idx=nothing))
290298
continue
@@ -817,3 +825,99 @@ function kernel_state_value(state)
817825
call_function(llvm_f, state)
818826
end
819827
end
828+
829+
# convert kernel state argument from pass-by-value to pass-by-reference
830+
#
831+
# the kernel state argument is always passed by value to avoid codegen issues with byval.
832+
# some back-ends however do not support passing kernel arguments by value, so this pass
833+
# serves to convert that argument (and is conceptually the inverse of `lower_byval`).
834+
"""
    kernel_state_to_reference!(job::CompilerJob, mod::LLVM.Module, f::LLVM.Function)

Rewrite `f` so that its leading kernel state argument is passed by reference instead of by
value.

A new function is created in `mod` whose first parameter is a pointer to the kernel state
type. Its entry block loads the state from that pointer and then branches into a clone of
the body of `f`. The old function is erased; the new one takes over its name and linkage.

Returns `f` unchanged when `kernel_state_type(job)` is `Nothing`, or when `f` has no
parameters or its first parameter is not of the kernel state type. Otherwise returns the
newly created function.
"""
function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    f::LLVM.Function)
    ft = function_type(f)

    # nothing to do if this job does not use a kernel state
    state = kernel_state_type(job)
    if state === Nothing
        return f
    end

    T_state = convert(LLVMType, state)

    # the kernel state is expected to be the first parameter; bail out if it is not there
    param_types = parameters(ft)
    if isempty(param_types) || value_type(parameters(f)[1]) != T_state
        return f
    end

    @tracepoint "kernel state to reference" begin
        # new function type: the state becomes a pointer, all other parameters are kept
        new_types = LLVM.LLVMType[LLVM.PointerType(T_state)]
        append!(new_types, param_types[2:end])

        new_ft = LLVM.FunctionType(return_type(ft), new_types)
        new_f = LLVM.Function(mod, "", new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters
        LLVM.name!(parameters(new_f)[1], "state_ptr")
        for (arg, new_arg) in zip(parameters(f)[2:end], parameters(new_f)[2:end])
            LLVM.name!(new_arg, LLVM.name(arg))
        end

        # emit IR performing the "conversions"
        new_args = LLVM.Value[]
        @dispose builder=IRBuilder() begin
            entry = BasicBlock(new_f, "conversion")
            position!(builder, entry)

            # load the kernel state value from the pointer
            state_val = load!(builder, T_state, parameters(new_f)[1], "state")
            push!(new_args, state_val)

            # all other arguments are passed through directly
            append!(new_args, parameters(new_f)[2:end])

            # map every old parameter (and the old function itself) to its replacement
            value_map = Dict{LLVM.Value, LLVM.Value}(
                param => new_arg for (param, new_arg) in zip(parameters(f), new_args)
            )
            value_map[f] = new_f

            clone_into!(new_f, f; value_map,
                        changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

            # fall through from the conversion block into the cloned body, which starts
            # at the second block now that "conversion" is the first
            br!(builder, blocks(new_f)[2])
        end

        # set the attributes for the state pointer parameter
        attrs = parameter_attributes(new_f, 1)
        # the pointer itself cannot be captured since we immediately load from it
        push!(attrs, EnumAttribute("nocapture", 0))
        # each kernel state is separate
        push!(attrs, EnumAttribute("noalias", 0))
        # the state is read-only
        push!(attrs, EnumAttribute("readonly", 0))

        # remove the old function
        fn = LLVM.name(f)
        # drop dead constant-expression users first, as the other function-replacing
        # passes do, so that they do not trip the assertion below
        prune_constexpr_uses!(f)
        @assert isempty(uses(f))
        replace_metadata_uses!(f, new_f)
        erase!(f)
        LLVM.name!(new_f, fn)

        # minimal optimization
        @dispose pb=NewPMPassBuilder() begin
            add!(pb, SimplifyCFGPass())
            run!(pb, new_f, llvm_machine(job.config.target))
        end

        return new_f
    end
end

src/metal.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ llvm_datalayout(target::MetalCompilerTarget) =
3535
"-v16:16:16-v24:32:32-v32:32:32-v48:64:64-v64:64:64-v96:128:128-v128:128:128-v192:256:256-v256:256:256-v512:512:512-v1024:1024:1024"*
3636
"-n8:16:32"
3737

38-
needs_byval(job::CompilerJob{MetalCompilerTarget}) = false
38+
pass_by_value(job::CompilerJob{MetalCompilerTarget}) = false
3939

4040

4141
## job
@@ -160,6 +160,11 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
160160
entry::LLVM.Function)
161161
entry_fn = LLVM.name(entry)
162162

163+
# convert the kernel state argument to a reference
164+
if job.config.kernel && kernel_state_type(job) !== Nothing
165+
entry = kernel_state_to_reference!(job, mod, entry)
166+
end
167+
163168
# add kernel metadata
164169
if job.config.kernel
165170
entry = add_parameter_address_spaces!(job, mod, entry)
@@ -235,7 +240,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
235240

236241
# find the byref parameters
237242
byref = BitVector(undef, length(parameters(ft)))
238-
args = classify_arguments(job, ft)
243+
args = classify_arguments(job, ft; post_optimization=true)
239244
filter!(args) do arg
240245
arg.cc != GHOST
241246
end
@@ -318,6 +323,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
318323

319324
# remove the old function
320325
fn = LLVM.name(f)
326+
prune_constexpr_uses!(f)
321327
@assert isempty(uses(f))
322328
replace_metadata_uses!(f, new_f)
323329
erase!(f)
@@ -418,7 +424,7 @@ end
418424

419425
# value-to-reference conversion
420426
#
421-
# Metal doesn't support passing valuse, so we need to convert those to references instead
427+
# Metal doesn't support passing values, so we need to convert those to references instead
422428
function pass_by_reference!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
423429
ft = function_type(f)
424430

@@ -717,7 +723,7 @@ function add_argument_metadata!(@nospecialize(job::CompilerJob), mod::LLVM.Modul
717723
arg_infos = Metadata[]
718724

719725
# Iterate through arguments and create metadata for them
720-
args = classify_arguments(job, entry_ft)
726+
args = classify_arguments(job, entry_ft; post_optimization=true)
721727
i = 1
722728
for arg in args
723729
arg.idx === nothing && continue

src/spirv.jl

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,8 @@ llvm_datalayout(::SPIRVCompilerTarget) = Int===Int64 ?
6262
runtime_slug(job::CompilerJob{SPIRVCompilerTarget}) =
6363
"spirv-" * String(job.config.target.backend)
6464

65-
function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, entry::LLVM.Function)
65+
function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
66+
entry::LLVM.Function)
6667
# update calling convention
6768
for f in functions(mod)
6869
# JuliaGPU/GPUCompiler.jl#97
@@ -72,6 +73,37 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
7273
callconv!(entry, LLVM.API.LLVMSPIRKERNELCallConv)
7374
end
7475

76+
return entry
77+
end
78+
79+
# Collect IR validation errors for a SPIR-V job: values of half or double type are reported
# (via `check_ir_values`) when the target does not declare support for them.
function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module)
    target = job.config.target
    errors = IRError[]

    # support for half and double depends on the target
    target.supports_fp16 || append!(errors, check_ir_values(mod, LLVM.HalfType()))
    target.supports_fp64 || append!(errors, check_ir_values(mod, LLVM.DoubleType()))

    return errors
end
92+
93+
function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
94+
entry::LLVM.Function)
95+
# convert the kernel state argument to a byval reference
96+
if job.config.kernel
97+
state = kernel_state_type(job)
98+
if state !== Nothing
99+
entry = kernel_state_to_reference!(job, mod, entry)
100+
101+
T_state = convert(LLVMType, state)
102+
attr = TypeAttribute("byval", T_state)
103+
push!(parameter_attributes(entry, 1), attr)
104+
end
105+
end
106+
75107
# HACK: Intel's compute runtime doesn't properly support SPIR-V's byval attribute.
76108
# they do support struct byval, for OpenCL, so wrap byval parameters in a struct.
77109
if job.config.kernel
@@ -91,20 +123,6 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
91123
return entry
92124
end
93125

94-
function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module)
95-
errors = IRError[]
96-
97-
# support for half and double depends on the target
98-
if !job.config.target.supports_fp16
99-
append!(errors, check_ir_values(mod, LLVM.HalfType()))
100-
end
101-
if !job.config.target.supports_fp64
102-
append!(errors, check_ir_values(mod, LLVM.DoubleType()))
103-
end
104-
105-
return errors
106-
end
107-
108126
@unlocked function mcgen(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
109127
format=LLVM.API.LLVMAssemblyFile)
110128
# The SPIRV Tools don't handle Julia's debug info, rejecting DW_LANG_Julia...
@@ -343,6 +361,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
343361
# remove the old function
344362
# NOTE: if we ever have legitimate uses of the old function, create a shim instead
345363
fn = LLVM.name(f)
364+
prune_constexpr_uses!(f)
346365
@assert isempty(uses(f))
347366
replace_metadata_uses!(f, new_f)
348367
erase!(f)

0 commit comments

Comments
 (0)