From c5f14753597e05526375cd58069b45a7715a2ca9 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Tue, 3 Sep 2024 13:31:23 +0200 Subject: [PATCH 1/4] call add_input_arguments on finish_ir --- src/metal.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/metal.jl b/src/metal.jl index 87d2106e..f2867b67 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -136,6 +136,9 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L # add kernel metadata if job.config.kernel + add_input_arguments!(job, mod, entry) + entry = LLVM.functions(mod)[entry_fn] + entry = add_address_spaces!(job, mod, entry) add_argument_metadata!(job, mod, entry) From b213bba1856d5dbb1e13a1228c2e0dc537e8f25f Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 6 Sep 2024 00:16:45 +0200 Subject: [PATCH 2/4] run finish_module after linking of deferred jobs --- src/driver.jl | 11 +++++++-- src/metal.jl | 65 ++++++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 63 insertions(+), 13 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index 54a6c053..1e474004 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -178,12 +178,15 @@ const __llvm_initialized = Ref(false) entry = functions(ir)[entry_fn] end + has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen") + # finalize the current module. this needs to happen before linking deferred modules, # since those modules have been finalized themselves, and we don't want to re-finalize. - entry = finish_module!(job, ir, entry) + if !has_deferred_jobs + entry = finish_module!(job, ir, entry) + end # deferred code generation - has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen") jobs = Dict{CompilerJob, String}(job => entry_fn) if has_deferred_jobs dyn_marker = functions(ir)["deferred_codegen"] @@ -258,6 +261,10 @@ const __llvm_initialized = Ref(false) unsafe_delete!(ir, dyn_marker) end + if has_deferred_jobs + entry = finish_module!(job, ir, entry) + end + if libraries # load the runtime outside of a timing block (because it recurses into the compiler) if !uses_julia_runtime(job) diff --git a/src/metal.jl b/src/metal.jl index f2867b67..98c57b02 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -461,11 +461,19 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # iteratively discover functions that use an intrinsic or any function calling it worklist_length = length(worklist) additions = LLVM.Function[] - for f in worklist, use in uses(f) - inst = user(use)::Instruction - bb = LLVM.parent(inst) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) + for f in worklist + recursive_uses = collect(uses(f)) + for use in recursive_uses + val = user(use) + if val isa LLVM.ConstantExpr + append!(recursive_uses, uses(val)) + continue + end + inst = val::Instruction + bb = LLVM.parent(inst) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) + end end for f in additions push!(worklist, f) @@ -533,7 +541,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, end # update other uses of the old function, modifying call sites to pass the arguments - function rewrite_uses!(f, new_f) + function rewrite_uses!(f, new_f, fty) # update uses @dispose builder=IRBuilder() begin for use in uses(f) @@ -543,7 +551,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # forward the arguments position!(builder, val) new_val = if val isa LLVM.CallInst - call!(builder, function_type(new_f), new_f, + call!(builder, fty, new_f, [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], operand_bundles(val)) else @@ -555,12 +563,22 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, replace_uses!(val, new_val) @assert isempty(uses(val)) unsafe_delete!(LLVM.parent(val), val) - elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast + elseif val isa LLVM.ConstantExpr && (opcode(val) == LLVM.API.LLVMBitCast || + opcode(val) == LLVM.API.LLVMPtrToInt || + opcode(val) == LLVM.API.LLVMIntToPtr) # XXX: why isn't this caught by the value materializer above? target = operands(val)[1] @assert target == f - new_val = LLVM.const_bitcast(new_f, value_type(val)) - rewrite_uses!(val, new_val) + + new_val = if opcode(val) == LLVM.API.LLVMBitCast + LLVM.const_bitcast(new_f, value_type(val)) + elseif opcode(val) == LLVM.API.LLVMPtrToInt + LLVM.const_ptrtoint(new_f, value_type(val)) + elseif opcode(val) == LLVM.API.LLVMIntToPtr + LLVM.const_inttoptr(new_f, value_type(val)) + end + + rewrite_uses!(val, new_val, fty) # we can't simply replace this constant expression, as it may be used # as a call, taking arguments (so we need to rewrite it to pass the input arguments) @@ -569,6 +587,31 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, if isempty(uses(val)) LLVM.unsafe_destroy!(val) end + elseif val isa LLVM.Instruction && (opcode(val) == LLVM.API.LLVMBitCast || + opcode(val) == LLVM.API.LLVMPtrToInt || + opcode(val) == LLVM.API.LLVMIntToPtr) && + all(isa.(operands(val::Instruction), LLVM.ConstantExpr)) + # XXX: why isn't this caught by the value materializer above? + target = operands(val)[1] + @assert target == f + + new_val = if opcode(val) == LLVM.API.LLVMBitCast + LLVM.const_bitcast(new_f, value_type(val)) + elseif opcode(val) == LLVM.API.LLVMPtrToInt + LLVM.const_ptrtoint(new_f, value_type(val)) + elseif opcode(val) == LLVM.API.LLVMIntToPtr + LLVM.const_inttoptr(new_f, value_type(val)) + end + + rewrite_uses!(val, new_val, fty) + # we can't simply replace this constant expression, as it may be used + # as a call, taking arguments (so we need to rewrite it to pass the input arguments) + + # drop the old constant if it is unused + # XXX: can we do this differently? + if isempty(uses(val)) + unsafe_delete!(LLVM.parent(val), val) + end else error("Cannot rewrite unknown use of function: $val") end @@ -576,7 +619,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, end end for (f, new_f) in workmap - rewrite_uses!(f, new_f) + rewrite_uses!(f, new_f, function_type(new_f)) @assert isempty(uses(f)) unsafe_delete!(mod, f) end From 1d4e52f68fb65dad713202c603f7e91faae29c07 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 6 Sep 2024 00:17:15 +0200 Subject: [PATCH 3/4] Revert "call add_input_arguments on finish_ir" This reverts commit c5f14753597e05526375cd58069b45a7715a2ca9. --- src/metal.jl | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/metal.jl b/src/metal.jl index 98c57b02..76b2668e 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -136,9 +136,6 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L # add kernel metadata if job.config.kernel - add_input_arguments!(job, mod, entry) - entry = LLVM.functions(mod)[entry_fn] - entry = add_address_spaces!(job, mod, entry) add_argument_metadata!(job, mod, entry) From 14c148bbbd4eaaafdb6e1bc64b50872777a5e4e9 Mon Sep 17 00:00:00 2001 From: Tim Gymnich Date: Fri, 6 Sep 2024 17:14:42 +0200 Subject: [PATCH 4/4] Revert "run finish_module after linking of deferred jobs" This reverts commit b213bba1856d5dbb1e13a1228c2e0dc537e8f25f. --- src/driver.jl | 11 ++------- src/metal.jl | 65 +++++++++------------------------------------------ 2 files changed, 13 insertions(+), 63 deletions(-) diff --git a/src/driver.jl b/src/driver.jl index 1e474004..54a6c053 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -178,15 +178,12 @@ const __llvm_initialized = Ref(false) entry = functions(ir)[entry_fn] end - has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen") - # finalize the current module. this needs to happen before linking deferred modules, # since those modules have been finalized themselves, and we don't want to re-finalize. - if !has_deferred_jobs - entry = finish_module!(job, ir, entry) - end + entry = finish_module!(job, ir, entry) # deferred code generation + has_deferred_jobs = toplevel && !only_entry && haskey(functions(ir), "deferred_codegen") jobs = Dict{CompilerJob, String}(job => entry_fn) if has_deferred_jobs dyn_marker = functions(ir)["deferred_codegen"] @@ -261,10 +258,6 @@ const __llvm_initialized = Ref(false) unsafe_delete!(ir, dyn_marker) end - if has_deferred_jobs - entry = finish_module!(job, ir, entry) - end - if libraries # load the runtime outside of a timing block (because it recurses into the compiler) if !uses_julia_runtime(job) diff --git a/src/metal.jl b/src/metal.jl index 76b2668e..87d2106e 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -458,19 +458,11 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # iteratively discover functions that use an intrinsic or any function calling it worklist_length = length(worklist) additions = LLVM.Function[] - for f in worklist - recursive_uses = collect(uses(f)) - for use in recursive_uses - val = user(use) - if val isa LLVM.ConstantExpr - append!(recursive_uses, uses(val)) - continue - end - inst = val::Instruction - bb = LLVM.parent(inst) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) - end + for f in worklist, use in uses(f) + inst = user(use)::Instruction + bb = LLVM.parent(inst) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) end for f in additions push!(worklist, f) @@ -538,7 +530,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, end # update other uses of the old function, modifying call sites to pass the arguments - function rewrite_uses!(f, new_f, fty) + function rewrite_uses!(f, new_f) # update uses @dispose builder=IRBuilder() begin for use in uses(f) @@ -548,7 +540,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, # forward the arguments position!(builder, val) new_val = if val isa LLVM.CallInst - call!(builder, fty, new_f, + call!(builder, function_type(new_f), new_f, [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], operand_bundles(val)) else @@ -560,22 +552,12 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, replace_uses!(val, new_val) @assert isempty(uses(val)) unsafe_delete!(LLVM.parent(val), val) - elseif val isa LLVM.ConstantExpr && (opcode(val) == LLVM.API.LLVMBitCast || - opcode(val) == LLVM.API.LLVMPtrToInt || - opcode(val) == LLVM.API.LLVMIntToPtr) + elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast # XXX: why isn't this caught by the value materializer above? target = operands(val)[1] @assert target == f - - new_val = if opcode(val) == LLVM.API.LLVMBitCast - LLVM.const_bitcast(new_f, value_type(val)) - elseif opcode(val) == LLVM.API.LLVMPtrToInt - LLVM.const_ptrtoint(new_f, value_type(val)) - elseif opcode(val) == LLVM.API.LLVMIntToPtr - LLVM.const_inttoptr(new_f, value_type(val)) - end - - rewrite_uses!(val, new_val, fty) + new_val = LLVM.const_bitcast(new_f, value_type(val)) + rewrite_uses!(val, new_val) # we can't simply replace this constant expression, as it may be used # as a call, taking arguments (so we need to rewrite it to pass the input arguments) @@ -584,31 +566,6 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, if isempty(uses(val)) LLVM.unsafe_destroy!(val) end - elseif val isa LLVM.Instruction && (opcode(val) == LLVM.API.LLVMBitCast || - opcode(val) == LLVM.API.LLVMPtrToInt || - opcode(val) == LLVM.API.LLVMIntToPtr) && - all(isa.(operands(val::Instruction), LLVM.ConstantExpr)) - # XXX: why isn't this caught by the value materializer above? - target = operands(val)[1] - @assert target == f - - new_val = if opcode(val) == LLVM.API.LLVMBitCast - LLVM.const_bitcast(new_f, value_type(val)) - elseif opcode(val) == LLVM.API.LLVMPtrToInt - LLVM.const_ptrtoint(new_f, value_type(val)) - elseif opcode(val) == LLVM.API.LLVMIntToPtr - LLVM.const_inttoptr(new_f, value_type(val)) - end - - rewrite_uses!(val, new_val, fty) - # we can't simply replace this constant expression, as it may be used - # as a call, taking arguments (so we need to rewrite it to pass the input arguments) - - # drop the old constant if it is unused - # XXX: can we do this differently? - if isempty(uses(val)) - unsafe_delete!(LLVM.parent(val), val) - end else error("Cannot rewrite unknown use of function: $val") end @@ -616,7 +573,7 @@ function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, end end for (f, new_f) in workmap - rewrite_uses!(f, new_f, function_type(new_f)) + rewrite_uses!(f, new_f) @assert isempty(uses(f)) unsafe_delete!(mod, f) end