Skip to content

Commit 435f1d4

Browse files
committed
run abstract interpretation over deferred code
1 parent e94f023 commit 435f1d4

File tree

2 files changed

+32
-3
lines changed

2 files changed

+32
-3
lines changed

src/jlgen.jl

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -460,12 +460,27 @@ function CC.concrete_eval_eligible(interp::GPUInterpreter,
460460
end
461461
end
462462

463+
struct DeferredCallInfo <: CC.CallInfo
464+
rt::DataType
465+
info::CC.CallInfo
466+
end
467+
463468
function CC.abstract_call_known(interp::GPUInterpreter, @nospecialize(f),
464469
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,
465470
max_methods::Int = CC.get_max_methods(interp, f, sv))
466-
if f === var"gpuc.deferred" ||
467-
f === var"gpuc.lookup"
468-
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), CC.NoCallInfo())
471+
(; fargs, argtypes) = arginfo
472+
if f === var"gpuc.deferred"
473+
argvec = argtypes[2:end]
474+
call = CC.abstract_call(interp, CC.ArgInfo(nothing, argvec), si, sv, max_methods)
475+
callinfo = DeferredCallInfo(call.rt, call.info)
476+
@static if VERSION < v"1.11.0-"
477+
return CC.CallMeta(Ptr{Cvoid}, CC.Effects(), callinfo)
478+
else
479+
return CC.CallMeta(Ptr{Cvoid}, Union{}, CC.Effects(), callinfo)
480+
end
481+
end
482+
if f === var"gpuc.lookup"
483+
error("Unimplemented")
469484
end
470485
return @invoke CC.abstract_call_known(interp::CC.AbstractInterpreter, f,
471486
arginfo::CC.ArgInfo, si::CC.StmtInfo, sv::CC.AbsIntState,

test/native_tests.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,20 @@ end
162162
ir = fetch(t)
163163
@test contains(ir, r"add i64 %\d+, 3")
164164
end
165+
166+
@testset "deferred" begin
167+
@gensym child kernel unrelated
168+
@eval @noinline $child(i) = i
169+
@eval $kernel(i) = GPUCompiler.var"gpuc.deferred"($child, i)
170+
171+
# smoke test
172+
job, _ = Native.create_job(eval(kernel), (Int64,))
173+
174+
ci, rt = only(GPUCompiler.code_typed(job))
175+
@test rt === Ptr{Cvoid}
176+
177+
ir = sprint(io->GPUCompiler.code_llvm(io, job))
178+
end
165179
end
166180

167181
############################################################################################

0 commit comments

Comments
 (0)