diff --git a/src/irgen.jl b/src/irgen.jl index cb38cb3e..2d19961a 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -921,3 +921,160 @@ function kernel_state_to_reference!(@nospecialize(job::CompilerJob), mod::LLVM.M return new_f end end + +function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, + entry::LLVM.Function, kernel_intrinsics::Dict) + entry_fn = LLVM.name(entry) + + # figure out which intrinsics are used and need to be added as arguments + used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn + haskey(functions(mod), intr_fn) + end |> collect + nargs = length(used_intrinsics) + + # determine which functions need these arguments + worklist = Set{LLVM.Function}([entry]) + for intr_fn in used_intrinsics + push!(worklist, functions(mod)[intr_fn]) + end + worklist_length = 0 + while worklist_length != length(worklist) + # iteratively discover functions that use an intrinsic or any function calling it + worklist_length = length(worklist) + additions = LLVM.Function[] + for f in worklist, use in uses(f) + inst = user(use)::Instruction + bb = LLVM.parent(inst) + new_f = LLVM.parent(bb) + in(new_f, worklist) || push!(additions, new_f) + end + for f in additions + push!(worklist, f) + end + end + for intr_fn in used_intrinsics + delete!(worklist, functions(mod)[intr_fn]) + end + + # add the arguments + # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt + workmap = Dict{LLVM.Function, LLVM.Function}() + for f in worklist + fn = LLVM.name(f) + ft = function_type(f) + LLVM.name!(f, fn * ".orig") + # create a new function + new_param_types = LLVMType[parameters(ft)...] + + for intr_fn in used_intrinsics + llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ) + push!(new_param_types, llvm_typ) + end + new_ft = LLVM.FunctionType(return_type(ft), new_param_types) + new_f = LLVM.Function(mod, fn, new_ft) + linkage!(new_f, linkage(f)) + for (arg, new_arg) in zip(parameters(f), parameters(new_f)) + LLVM.name!(new_arg, LLVM.name(arg)) + end + for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end]) + LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name) + end + + workmap[f] = new_f + end + + # clone and rewrite the function bodies. + # we don't need to rewrite much as the arguments are added last. + for (f, new_f) in workmap + # map the arguments + value_map = Dict{LLVM.Value, LLVM.Value}() + for (param, new_param) in zip(parameters(f), parameters(new_f)) + LLVM.name!(new_param, LLVM.name(param)) + value_map[param] = new_param + end + + value_map[f] = new_f + clone_into!(new_f, f; value_map, + changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly) + + # we can't remove this function yet, as we might still need to rewrite any called, + # but remove the IR already + empty!(f) + end + + # drop unused constants that may be referring to the old functions + # XXX: can we do this differently? + for f in worklist + prune_constexpr_uses!(f) + end + + # update other uses of the old function, modifying call sites to pass the arguments + function rewrite_uses!(f, new_f) + # update uses + @dispose builder=IRBuilder() begin + for use in uses(f) + val = user(use) + if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst + callee_f = LLVM.parent(LLVM.parent(val)) + # forward the arguments + position!(builder, val) + new_val = if val isa LLVM.CallInst + call!(builder, function_type(new_f), new_f, + [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], + operand_bundles(val)) + else + # TODO: invoke and callbr + error("Rewrite of $(typeof(val))-based calls is not implemented: $val") + end + callconv!(new_val, callconv(val)) + + replace_uses!(val, new_val) + @assert isempty(uses(val)) + erase!(val) + elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast + # XXX: why isn't this caught by the value materializer above? + target = operands(val)[1] + @assert target == f + new_val = LLVM.const_bitcast(new_f, value_type(val)) + rewrite_uses!(val, new_val) + # we can't simply replace this constant expression, as it may be used + # as a call, taking arguments (so we need to rewrite it to pass the input arguments) + + # drop the old constant if it is unused + # XXX: can we do this differently? + if isempty(uses(val)) + LLVM.unsafe_destroy!(val) + end + else + error("Cannot rewrite unknown use of function: $val") + end + end + end + end + for (f, new_f) in workmap + rewrite_uses!(f, new_f) + @assert isempty(uses(f)) + erase!(f) + end + + # replace uses of the intrinsics with references to the input arguments + for (i, intr_fn) in enumerate(used_intrinsics) + intr = functions(mod)[intr_fn] + for use in uses(intr) + val = user(use) + callee_f = LLVM.parent(LLVM.parent(val)) + if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst + replace_uses!(val, parameters(callee_f)[end-nargs+i]) + else + error("Cannot rewrite unknown use of function: $val") + end + + @assert isempty(uses(val)) + erase!(val) + end + @assert isempty(uses(intr)) + erase!(intr) + end + + return functions(mod)[entry_fn] +end diff --git a/src/metal.jl b/src/metal.jl index 26ceec9c..00abf288 100644 --- a/src/metal.jl +++ b/src/metal.jl @@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo # update calling conventions if job.config.kernel entry = pass_by_reference!(job, mod, entry) - - add_input_arguments!(job, mod, entry) - entry = LLVM.functions(mod)[entry_fn] + entry = add_input_arguments!(job, mod, entry, kernel_intrinsics) end # emit the AIR and Metal version numbers as constants in the module. this makes it @@ -553,164 +551,6 @@ function argument_type_name(typ) end end -function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module, - entry::LLVM.Function) - entry_fn = LLVM.name(entry) - - # figure out which intrinsics are used and need to be added as arguments - used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn - haskey(functions(mod), intr_fn) - end |> collect - nargs = length(used_intrinsics) - - # determine which functions need these arguments - worklist = Set{LLVM.Function}([entry]) - for intr_fn in used_intrinsics - push!(worklist, functions(mod)[intr_fn]) - end - worklist_length = 0 - while worklist_length != length(worklist) - # iteratively discover functions that use an intrinsic or any function calling it - worklist_length = length(worklist) - additions = LLVM.Function[] - for f in worklist, use in uses(f) - inst = user(use)::Instruction - bb = LLVM.parent(inst) - new_f = LLVM.parent(bb) - in(new_f, worklist) || push!(additions, new_f) - end - for f in additions - push!(worklist, f) - end - end - for intr_fn in used_intrinsics - delete!(worklist, functions(mod)[intr_fn]) - end - - # add the arguments - # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt - workmap = Dict{LLVM.Function, LLVM.Function}() - for f in worklist - fn = LLVM.name(f) - ft = function_type(f) - LLVM.name!(f, fn * ".orig") - # create a new function - new_param_types = LLVMType[parameters(ft)...] - - for intr_fn in used_intrinsics - llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ) - push!(new_param_types, llvm_typ) - end - new_ft = LLVM.FunctionType(return_type(ft), new_param_types) - new_f = LLVM.Function(mod, fn, new_ft) - linkage!(new_f, linkage(f)) - for (arg, new_arg) in zip(parameters(f), parameters(new_f)) - LLVM.name!(new_arg, LLVM.name(arg)) - end - for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end]) - LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name) - end - - workmap[f] = new_f - end - - # clone and rewrite the function bodies. - # we don't need to rewrite much as the arguments are added last. - for (f, new_f) in workmap - # map the arguments - value_map = Dict{LLVM.Value, LLVM.Value}() - for (param, new_param) in zip(parameters(f), parameters(new_f)) - LLVM.name!(new_param, LLVM.name(param)) - value_map[param] = new_param - end - - value_map[f] = new_f - clone_into!(new_f, f; value_map, - changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly) - - # we can't remove this function yet, as we might still need to rewrite any called, - # but remove the IR already - empty!(f) - end - - # drop unused constants that may be referring to the old functions - # XXX: can we do this differently? - for f in worklist - prune_constexpr_uses!(f) - end - - # update other uses of the old function, modifying call sites to pass the arguments - function rewrite_uses!(f, new_f) - # update uses - @dispose builder=IRBuilder() begin - for use in uses(f) - val = user(use) - if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst - callee_f = LLVM.parent(LLVM.parent(val)) - # forward the arguments - position!(builder, val) - new_val = if val isa LLVM.CallInst - call!(builder, function_type(new_f), new_f, - [arguments(val)..., parameters(callee_f)[end-nargs+1:end]...], - operand_bundles(val)) - else - # TODO: invoke and callbr - error("Rewrite of $(typeof(val))-based calls is not implemented: $val") - end - callconv!(new_val, callconv(val)) - - replace_uses!(val, new_val) - @assert isempty(uses(val)) - erase!(val) - elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast - # XXX: why isn't this caught by the value materializer above? - target = operands(val)[1] - @assert target == f - new_val = LLVM.const_bitcast(new_f, value_type(val)) - rewrite_uses!(val, new_val) - # we can't simply replace this constant expression, as it may be used - # as a call, taking arguments (so we need to rewrite it to pass the input arguments) - - # drop the old constant if it is unused - # XXX: can we do this differently? - if isempty(uses(val)) - LLVM.unsafe_destroy!(val) - end - else - error("Cannot rewrite unknown use of function: $val") - end - end - end - end - for (f, new_f) in workmap - rewrite_uses!(f, new_f) - @assert isempty(uses(f)) - erase!(f) - end - - # replace uses of the intrinsics with references to the input arguments - for (i, intr_fn) in enumerate(used_intrinsics) - intr = functions(mod)[intr_fn] - for use in uses(intr) - val = user(use) - callee_f = LLVM.parent(LLVM.parent(val)) - if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst - replace_uses!(val, parameters(callee_f)[end-nargs+i]) - else - error("Cannot rewrite unknown use of function: $val") - end - - @assert isempty(uses(val)) - erase!(val) - end - @assert isempty(uses(intr)) - erase!(intr) - end - - return -end - - # argument metadata generation # # module metadata is used to identify buffers that are passed as kernel arguments.