Skip to content

Commit c8d3c49

Browse files
committed
Set kernel metadata for all kernels, including deferred ones.
1 parent d762ef8 commit c8d3c49

File tree

2 files changed

+13
-9
lines changed

2 files changed

+13
-9
lines changed

src/driver.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -316,15 +316,16 @@ const __llvm_initialized = Ref(false)
316316
end
317317

318318
@timeit_debug to "IR post-processing" begin
319-
# mark the entry-point function (optimization may need it)
320-
if deferred_codegen
321-
# IDEA: save other parts of the CompileJob (so that we can reconstruct it
322-
# instead of setting it globally, which is incompatible with threading)?
319+
# mark the kernel entry-point functions (optimization may need it)
320+
if job.source.kernel
323321
# XXX: we want to save the actual function here, but due to our passes rewriting
324322
# functions, and the inability to RAUW values with a different type, that
325323
# metadata gets lost. So instead we save the function name. See also:
326324
# https://discourse.llvm.org/t/replacing-module-metadata-uses-of-function/62431
327-
push!(metadata(ir)["julia.entry"], MDNode([MDString(LLVM.name(entry); ctx)]; ctx))
325+
push!(metadata(ir)["julia.kernel"], MDNode([MDString(LLVM.name(entry); ctx)]; ctx))
326+
327+
# IDEA: save all jobs, not only top-level kernels, and save other attributes
328+
# so that we can reconstruct the CompileJob instead of setting it globally
328329
end
329330

330331
if optimize

src/irgen.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -551,16 +551,19 @@ function add_kernel_state!(mod::LLVM.Module)
551551
# this is both for extern uses, and to make this transformation a two-step process.
552552
state_intr = kernel_state_intr(mod, T_state)
553553

554-
entry_md = operands(metadata(mod)["julia.entry"])[1]
555-
entry_fn = string(operands(entry_md)[1])
556-
entry = functions(mod)[entry_fn]
554+
kernels = []
555+
kernels_md = metadata(mod)["julia.kernel"]
556+
for kernel_md in operands(kernels_md)
557+
kernel_fn = string(operands(kernel_md)[1])
558+
push!(kernels, functions(mod)[kernel_fn])
559+
end
557560

558561
# determine which functions need a kernel state argument
559562
#
560563
# previously, we add the argument to every function and relied on unused arg elim to
561564
# clean-up the IR. however, some libraries do Funny Stuff, e.g., libdevice bitcasting
562565
# function pointers. such IR is hard to rewrite, so instead be more conservative.
563-
worklist = Set{LLVM.Function}([entry, state_intr])
566+
worklist = Set{LLVM.Function}([state_intr, kernels...])
564567
worklist_length = 0
565568
while worklist_length != length(worklist)
566569
# iteratively discover functions that use the intrinsic or any function calling it

0 commit comments

Comments
 (0)