Skip to content

Commit d9b1b08

Browse files
committed
Remove kernel state special casing now that the transformation happens late.
1 parent 27b82ba commit d9b1b08

File tree

5 files changed

+9
-54
lines changed

5 files changed

+9
-54
lines changed

src/irgen.jl

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -404,21 +404,14 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
404404
# XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
405405
# so we need to re-classify the Julia arguments.
406406
# remove this once we only support 1.7.
407-
has_kernel_state = kernel_state_type(job) !== Nothing
408-
orig_ft = if has_kernel_state
409-
# the kernel state has been added here already, so strip the first parameter
410-
LLVM.FunctionType(LLVM.return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
411-
else
412-
ft
413-
end
414-
args = classify_arguments(job, orig_ft)
407+
args = classify_arguments(job, ft)
415408
filter!(args) do arg
416409
arg.cc != GHOST
417410
end
418411
for arg in args
419412
if arg.cc == BITS_REF
420413
# NOTE: +1 since this pass runs after introducing the kernel state
421-
byval[arg.codegen.i+has_kernel_state] = true
414+
byval[arg.codegen.i] = true
422415
end
423416
end
424417
end

src/optim.jl

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -255,8 +255,7 @@ function optimize!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
255255
ModulePassManager() do pm
256256
addTargetPasses!(pm, tm, triple)
257257

258-
# - remove unused kernel state arguments
259-
# - simplify function calls that don't use the returned value
258+
# simplify function calls that don't use the returned value
260259
dead_arg_elimination!(pm)
261260

262261
run!(pm, mod)
@@ -356,16 +355,8 @@ function lower_gc_frame!(fun::LLVM.Function)
356355

357356
# replace with PTX alloc_obj
358357
Builder(ctx) do builder
359-
# NOTE: this happens late during the pipeline, where we may have to
360-
# pass a kernel state arguments to the runtime function.
361-
state = if job.source.kernel
362-
kernel_state_type(job)
363-
else
364-
Nothing
365-
end
366-
367358
position!(builder, call)
368-
ptr = call!(builder, Runtime.get(:gc_pool_alloc), [sz]; state)
359+
ptr = call!(builder, Runtime.get(:gc_pool_alloc), [sz])
369360
replace_uses!(call, ptr)
370361
end
371362

src/rtlib.jl

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,7 @@ end
2828

2929
## higher-level functionality to work with runtime functions
3030

31-
function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[];
32-
state::Type=Nothing)
31+
function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[])
3332
bb = position(builder)
3433
f = LLVM.parent(bb)
3534
mod = LLVM.parent(f)
@@ -40,21 +39,10 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[
4039
f = functions(mod)[rt.llvm_name]
4140
ft = eltype(llvmtype(f))
4241
else
43-
ft = convert(LLVM.FunctionType, rt; ctx, state)
42+
ft = convert(LLVM.FunctionType, rt; ctx)
4443
f = LLVM.Function(mod, rt.llvm_name, ft)
4544
end
4645

47-
# we may be calling this function after kernel state lowering,
48-
# in which case we need to manually get and pass the state.
49-
args = Value[args...]
50-
if state !== Nothing
51-
T_state = convert(LLVMType, state; ctx)
52-
53-
state_intr = kernel_state_intr(mod, T_state)
54-
state_val = call!(builder, state_intr, Value[], "state")
55-
pushfirst!(args, state_val)
56-
end
57-
5846
# runtime functions are written in Julia, while we're calling from LLVM,
5947
# this often results in argument type mismatches. try to fix some here.
6048
for (i,arg) in enumerate(args)

src/runtime.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,13 @@ struct RuntimeMethodInstance
3434
end
3535

3636
function Base.convert(::Type{LLVM.FunctionType}, rt::RuntimeMethodInstance;
37-
ctx::LLVM.Context, state::Type=Nothing)
37+
ctx::LLVM.Context)
3838
types = if rt.llvm_types === nothing
3939
LLVMType[convert(LLVMType, typ; ctx, allow_boxed=true) for typ in rt.types]
4040
else
4141
rt.llvm_types(ctx)
4242
end
4343

44-
# if we're running post-optimization, prepend the kernel state to the argument list
45-
if state !== Nothing
46-
T_state = convert(LLVMType, state; ctx)
47-
pushfirst!(types, T_state)
48-
end
49-
5044
return_type = if rt.llvm_return_type === nothing
5145
convert(LLVMType, rt.return_type; ctx, allow_boxed=true)
5246
else

src/spirv.jl

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -232,26 +232,15 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
232232
end
233233
else
234234
# XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
235-
has_kernel_state = kernel_state_type(job) !== Nothing
236-
orig_ft = if has_kernel_state
237-
# the kernel state has been added here already, so strip the first parameter
238-
LLVM.FunctionType(LLVM.return_type(ft), parameters(ft)[2:end]; vararg=isvararg(ft))
239-
else
240-
ft
241-
end
242-
args = classify_arguments(job, orig_ft)
235+
args = classify_arguments(job, ft)
243236
filter!(args) do arg
244237
arg.cc != GHOST
245238
end
246239
for arg in args
247240
if arg.cc == BITS_REF
248-
# NOTE: +1 since this pass runs after introducing the kernel state
249-
byval[arg.codegen.i+has_kernel_state] = true
241+
byval[arg.codegen.i] = true
250242
end
251243
end
252-
if has_kernel_state
253-
byval[1] = true
254-
end
255244
end
256245

257246
# generate the wrapper function type & definition

0 commit comments

Comments
 (0)