Skip to content

Commit 67adf26

Browse files
committed
Add support for calling runtime functions post-optimization.
1 parent 3e53a3f commit 67adf26

File tree

2 files changed

+20
-3
lines changed

2 files changed

+20
-3
lines changed

src/rtlib.jl

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ end
2828

2929
## higher-level functionality to work with runtime functions
3030

31-
function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[])
31+
function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[];
32+
state::Type=Nothing)
3233
bb = position(builder)
3334
f = LLVM.parent(bb)
3435
mod = LLVM.parent(f)
@@ -39,10 +40,18 @@ function LLVM.call!(builder, rt::Runtime.RuntimeMethodInstance, args=LLVM.Value[
3940
f = functions(mod)[rt.llvm_name]
4041
ft = eltype(llvmtype(f))
4142
else
42-
ft = convert(LLVM.FunctionType, rt; ctx)
43+
ft = convert(LLVM.FunctionType, rt; ctx, state)
4344
f = LLVM.Function(mod, rt.llvm_name, ft)
4445
end
4546

47+
# we may be calling this function after kernel state lowering,
48+
# in which case we need to manually get and pass the state.
49+
args = Value[args...]
50+
if state !== Nothing
51+
state_val = kernel_state_argument(f, state)
52+
pushfirst!(args, state_val)
53+
end
54+
4655
# runtime functions are written in Julia, while we're calling from LLVM,
4756
# this often results in argument type mismatches. try to fix some here.
4857
for (i,arg) in enumerate(args)

src/runtime.jl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,13 +33,21 @@ struct RuntimeMethodInstance
3333
llvm_name::String
3434
end
3535

36-
function Base.convert(::Type{LLVM.FunctionType}, rt::RuntimeMethodInstance; ctx::LLVM.Context)
36+
function Base.convert(::Type{LLVM.FunctionType}, rt::RuntimeMethodInstance;
37+
ctx::LLVM.Context, state::Type=Nothing)
3738
types = if rt.llvm_types === nothing
3839
LLVMType[convert(LLVMType, typ; ctx, allow_boxed=true) for typ in rt.types]
3940
else
4041
rt.llvm_types(ctx)
4142
end
4243

44+
# if we're running post-optimization, prepend the kernel state to the argument list
45+
if state !== Nothing
46+
T_state = convert(LLVMType, state; ctx)
47+
T_ptr_state = LLVM.PointerType(T_state)
48+
pushfirst!(types, T_ptr_state)
49+
end
50+
4351
return_type = if rt.llvm_return_type === nothing
4452
convert(LLVMType, rt.return_type; ctx, allow_boxed=true)
4553
else

0 commit comments

Comments
 (0)