Skip to content

Commit 1976f0d

Browse files
Add @device_function macro for AOT compilation (#749)
Introduce a @device_function macro that creates both a CPU-visible stub function and a method overlay in GLOBAL_METHOD_TABLE for GPU compilation. This prevents CPU runtime functions (ccalls to Julia internals) from leaking into GPU IR, which was breaking AOT compilation. --------- Co-authored-by: Michel Schanen <michel.schanen@gmail.com>
1 parent d762273 commit 1976f0d

File tree

6 files changed

+73
-8
lines changed

6 files changed

+73
-8
lines changed

src/driver.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ function compile_unhooked(output::Symbol, @nospecialize(job::CompilerJob); kwarg
9393
## LLVM IR
9494

9595
ir, ir_meta = emit_llvm(job)
96-
9796
if output == :llvm
9897
if job.config.strip
9998
@tracepoint "strip debug info" strip_debuginfo!(ir)

src/jlgen.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -293,10 +293,6 @@ end
293293
end # !HAS_INTEGRATED_CACHE
294294

295295

296-
## method overrides
297-
298-
Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
299-
300296
# Implements a priority lookup for method tables, where the first match in the stack gets returned.
301297
# An alternative to this would be to use a "Union" where we would query the parent method table and
302298
# do a most-specific match.
@@ -490,7 +486,10 @@ CC.lock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
490486
CC.unlock_mi_inference(interp::GPUInterpreter, mi::MethodInstance) = nothing
491487

492488
function CC.add_remark!(interp::GPUInterpreter, sv::CC.InferenceState, msg)
493-
@safe_debug "Inference remark during GPU compilation of $(sv.linfo): $msg"
489+
# NOTE: @safe_debug is disabled here because including logging/warning code causes
490+
# CPU runtime functions (ccalls to Julia internals) to leak into the GPU IR,
491+
# breaking AOT compilation. See PR #749 for details.
492+
return nothing
494493
end
495494

496495
CC.may_optimize(interp::GPUInterpreter) = true

src/rtlib.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ function emit_function!(mod, config::CompilerConfig, f, method)
7777
new_mod, meta = compile_unhooked(:llvm, CompilerJob(source, config))
7878
ft = function_type(meta.entry)
7979
expected_ft = convert(LLVM.FunctionType, method)
80+
8081
if return_type(ft) != return_type(expected_ft)
8182
error("Invalid return type for runtime function '$(method.name)': expected $(return_type(expected_ft)), got $(return_type(ft))")
8283
end

src/runtime.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n
7171
meth = RuntimeMethodInstance(def,
7272
return_type, types, name,
7373
llvm_return_type, llvm_types, llvm_name)
74+
7475
if haskey(methods, name)
7576
error("Runtime function $name has already been registered!")
7677
end
@@ -82,8 +83,10 @@ function compile(def, return_type, types, llvm_return_type=nothing, llvm_types=n
8283
# using the new nonrecursive codegen to handle function lookup ourselves?
8384
if def isa Symbol
8485
args = [gensym() for typ in types]
85-
@eval @inline $def($(args...)) =
86-
ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...))
86+
@eval GPUCompiler.@device_function($return_type,
87+
@inline $def($(args...)) =
88+
ccall($("extern $llvm_name"), llvmcall, $return_type, ($(types...),), $(args...))
89+
)
8790
end
8891

8992
return

src/utils.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,3 +238,29 @@ end
238238
return inits
239239
end
240240
end
241+
## method overrides
242+
243+
Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
244+
using ExprTools: splitdef, combinedef
245+
macro device_function(rt, ex)
246+
ex = macroexpand(__module__, ex)
247+
def = splitdef(ex)
248+
249+
# generate a CPU-side stub function that returns a placeholder value of the expected type
250+
# FIXME: The type may not have a default constructor, what do we do then?
251+
# Currently we are using the constructor with an Int64(1) as an argument.
252+
# NOTE: using Int64(1) is a bit odd. This is because Ptr(Int64(0)) == C_NULL, and julia code lowering
253+
# seems to get rid of this automatically.
254+
def[:body] = quote
255+
$rt(1)
256+
end
257+
258+
esc(quote
259+
$(combinedef(def))
260+
261+
# NOTE: no use of `@consistent_overlay` here because the CPU stub only returns a placeholder value and is not equivalent to the overlay
262+
Base.Experimental.@overlay($(GPUCompiler).GLOBAL_METHOD_TABLE, $ex)
263+
end)
264+
end
265+
266+

test/utils.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -193,3 +193,40 @@ end
193193
# Check that we can call this function from the CPU, to support deferred codegen for Enzyme.
194194
@test ccall("extern deferred_codegen", llvmcall, UInt, (UInt,), 3) == 3
195195
end
196+
197+
@testset "@device_function macro" begin
198+
# Test that @device_function creates both CPU stub and overlay
199+
# The macro should:
200+
# 1. Define a CPU-visible function that returns the expected type
201+
# 2. Register an overlay in GLOBAL_METHOD_TABLE for GPU compilation
202+
203+
# Create a test module to contain the device functions
204+
test_mod = @eval module $(gensym("DeviceFunctionTest"))
205+
using GPUCompiler
206+
207+
# Test with Ptr return type (common for runtime functions)
208+
GPUCompiler.@device_function(Ptr{Nothing},
209+
@inline test_device_ptr() = ccall("extern gpu_test", llvmcall, Ptr{Nothing}, ())
210+
)
211+
212+
# Test with `Nothing` return type
213+
GPUCompiler.@device_function(Nothing,
214+
@inline test_device_nothing() = ccall("extern gpu_test2", llvmcall, Nothing, ())
215+
)
216+
end
217+
218+
# Verify the functions are defined in the test module
219+
@test isdefined(test_mod, :test_device_ptr)
220+
@test isdefined(test_mod, :test_device_nothing)
221+
222+
# Verify the overlay exists in the global method table
223+
mt_view = GPUCompiler.get_method_table_view(Base.get_world_counter(), GPUCompiler.GLOBAL_METHOD_TABLE)
224+
sig_ptr = Tuple{typeof(test_mod.test_device_ptr)}
225+
sig_nothing = Tuple{typeof(test_mod.test_device_nothing)}
226+
227+
# The overlay should be findable in the method table
228+
result_ptr = findsup(sig_ptr, mt_view)
229+
result_nothing = findsup(sig_nothing, mt_view)
230+
@test result_ptr !== nothing
231+
@test result_nothing !== nothing
232+
end

0 commit comments

Comments
 (0)