From 9723149f12692b976744ca28c9d8d28bb2166544 Mon Sep 17 00:00:00 2001 From: Tim Besard Date: Wed, 30 Jul 2025 14:39:36 +0200 Subject: [PATCH] SPIR-V: Add a pre-optimization pass to convert unreachable to return. --- src/ptx.jl | 4 +- src/spirv.jl | 233 +++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 235 insertions(+), 2 deletions(-) diff --git a/src/ptx.jl b/src/ptx.jl index 0c0a556b..506bc56a 100644 --- a/src/ptx.jl +++ b/src/ptx.jl @@ -193,7 +193,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}), mod::LLVM.Module, entry::LLVM.Function) if LLVM.version() < v"17" for f in functions(mod) - lower_unreachable!(f) + lower_unreachable_to_exit!(f) end end @@ -308,7 +308,7 @@ end # CFG, and consequently correctly determine the divergence regions as intended. # Note that we first emit a call to `trap`, so that the behaviour is the same # as before. -function lower_unreachable!(f::LLVM.Function) +function lower_unreachable_to_exit!(f::LLVM.Function) mod = LLVM.parent(f) # TODO: diff --git a/src/spirv.jl b/src/spirv.jl index 2afd4f60..89bb54ac 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -76,6 +76,8 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module, # they do support struct byval, for OpenCL, so wrap byval parameters in a struct. if job.config.kernel entry = wrap_byval(job, mod, entry) + lower_unreachable_to_return!(job, mod, entry) + verify(mod) end # add module metadata @@ -358,3 +360,234 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F return new_f end + + +## LLVM IR passes + +# lower unreachable instructions to returns with error flags +# +# SPIR-V does not have a trap instruction, so the common trap + unreachable sequence +# results in `OpUnreachable` actually getting executed, which is undefined behavior. +# Instead, we transform unreachable instructions to returns with an error flag that's +# checked by the caller. 
# Rewrite every `unreachable` in `mod` into a return that carries an error flag.
#
# Arguments:
# - `job`:   the compiler job; only used for `@compiler_assert` diagnostics.
# - `mod`:   the module to transform in place.
# - `entry`: the entry-point function. Its type is left untouched; an `unreachable`
#            inside it simply becomes `ret void`.
#
# Every other function containing an `unreachable` is replaced by a clone returning
# `{i1}` (void functions) or `{i1, T}` (functions returning `T`). Ordinary returns set
# the flag to `false`, former `unreachable`s return with the flag set to `true`. Call
# sites are rewritten to branch on the flag into a fresh block ending in `unreachable`,
# which is why the pass iterates until no `unreachable` is left. Finally, calls to
# `llvm.trap` are erased and the `noreturn` attribute is dropped from all functions.
#
# Returns `true` when an `unreachable` was rewritten or an `llvm.trap` call was erased.
function lower_unreachable_to_return!(@nospecialize(job::CompilerJob),
                                      mod::LLVM.Module, entry::LLVM.Function)
    changed = false
    @tracepoint "lower unreachable to return" begin
        # clones that already return the `{flag, value}` struct
        already_transformed_functions = Set{LLVM.Function}()

        # Each round lowers all current `unreachable`s and rewrites the callers of the
        # functions whose type changed. Rewriting callers introduces new `unreachable`s
        # (in the error blocks), which are handled by the next round.
        while true
            functions_with_unreachable = Set{LLVM.Function}()
            for f in functions(mod)
                _has_unreachable(f) && push!(functions_with_unreachable, f)
            end
            isempty(functions_with_unreachable) && break

            # old function => clone with the flag-carrying return type
            transformed_functions = Dict{LLVM.Function, LLVM.Function}()
            for f in functions_with_unreachable
                # we cannot touch the type or returned value of the entry point,
                # so simply replace the unreachable with a return.
                if f == entry
                    ret_type = return_type(function_type(f))
                    @compiler_assert ret_type == LLVM.VoidType() job
                    _ret_on_unreachable!(f)
                    changed = true
                    continue
                end

                # first time we see this function: change its return type
                if !in(f, already_transformed_functions)
                    new_f = _clone_with_error_flag!(mod, f)
                    transformed_functions[f] = new_f
                    push!(already_transformed_functions, new_f)
                    f = new_f
                end

                _error_return_on_unreachable!(f)
                changed = true
            end

            for (old_f, new_f) in transformed_functions
                _check_error_flag_at_calls!(job, old_f, new_f)

                @compiler_assert isempty(uses(old_f)) job
                erase!(old_f)
            end
        end

        # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
        if _erase_trap!(job, mod)
            changed = true
        end
        for f in functions(mod)
            delete!(function_attributes(f), EnumAttribute("noreturn", 0))
        end
    end
    return changed
end

# Return whether `f` contains at least one `unreachable` instruction.
function _has_unreachable(f::LLVM.Function)
    for bb in blocks(f), inst in instructions(bb)
        inst isa LLVM.UnreachableInst && return true
    end
    return false
end

# Snapshot all instructions of type `T` in `f`, so callers can erase them safely.
function _instructions_of_type(f::LLVM.Function, ::Type{T}) where {T}
    found = LLVM.Value[]
    for bb in blocks(f), inst in instructions(bb)
        inst isa T && push!(found, inst)
    end
    return found
end

# Emit a `ret_type` struct whose element 0 is `flag` and, when `val` is given, whose
# element 1 is `val`.
function _flagged_value!(builder::IRBuilder, ret_type, flag::Bool, val=nothing)
    flag_and_val = insert_value!(builder, UndefValue(ret_type),
                                 ConstantInt(LLVM.Int1Type(), flag), 0)
    if val !== nothing
        flag_and_val = insert_value!(builder, flag_and_val, val, 1)
    end
    return flag_and_val
end

# Replace every `unreachable` in `f` with a plain `ret void`.
function _ret_on_unreachable!(f::LLVM.Function)
    unreachables = _instructions_of_type(f, LLVM.UnreachableInst)
    @dispose builder=IRBuilder() begin
        for inst in unreachables
            position!(builder, inst)
            ret!(builder)
            erase!(inst)
        end
    end
    return
end

# Replace every `unreachable` in `f`, which must already return the flag struct, with
# a return of `{true}` or `{true, undef}`.
function _error_return_on_unreachable!(f::LLVM.Function)
    ret_type = return_type(function_type(f))
    elem_types = elements(ret_type)
    # a single element means the original function returned void
    val_type = length(elem_types) == 1 ? nothing : elem_types[2]

    unreachables = _instructions_of_type(f, LLVM.UnreachableInst)
    @dispose builder=IRBuilder() begin
        for inst in unreachables
            position!(builder, inst)
            val = val_type === nothing ? nothing : UndefValue(val_type)
            ret!(builder, _flagged_value!(builder, ret_type, true, val))
            erase!(inst)
        end
    end
    return
end

# Clone `f` into a new function of the same name that returns `{i1}` / `{i1, T}`, with
# all existing returns rewritten to carry a `false` flag. `f` is renamed to
# `<name>.old` and kept around so its callers can still be found; the clone is returned.
function _clone_with_error_flag!(mod::LLVM.Module, f::LLVM.Function)
    ft = function_type(f)
    ret_type = return_type(ft)
    fn = LLVM.name(f)
    returns_void = ret_type == LLVM.VoidType()

    # new return type: {i1, original_type}
    new_ret_type = if returns_void
        LLVM.StructType([LLVM.Int1Type()])
    else
        LLVM.StructType([LLVM.Int1Type(), ret_type])
    end

    # free up the name before creating the replacement
    LLVM.name!(f, fn * ".old")
    new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
    new_f = LLVM.Function(mod, fn, new_ft)
    linkage!(new_f, linkage(f))
    for (i, param) in enumerate(parameters(f))
        LLVM.name!(parameters(new_f)[i], LLVM.name(param))
    end

    # clone the IR
    value_map = Dict{LLVM.Value, LLVM.Value}(
        param => parameters(new_f)[i] for (i, param) in enumerate(parameters(f))
    )
    clone_into!(new_f, f; value_map,
                changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

    # rewrite return instructions to `{false}` / `{false, val}`
    returns = _instructions_of_type(new_f, LLVM.RetInst)
    @dispose builder=IRBuilder() begin
        for inst in returns
            position!(builder, inst)
            val = returns_void ? nothing : only(operands(inst))
            ret!(builder, _flagged_value!(builder, new_ret_type, false, val))
            erase!(inst)
        end
    end

    return new_f
end

# Redirect every direct call of `old_f` to `new_f` and branch on the returned flag:
# the `error` block ends in `unreachable`, the `continue` block receives the unpacked
# return value and the remainder of the original block.
function _check_error_flag_at_calls!(@nospecialize(job::CompilerJob),
                                     old_f::LLVM.Function, new_f::LLVM.Function)
    # snapshot the calls first; the use list changes as we rewrite
    calls_to_rewrite = LLVM.CallInst[]
    for use in uses(old_f)
        call_inst = user(use)
        if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
            push!(calls_to_rewrite, call_inst)
        end
    end

    @dispose builder=IRBuilder() begin
        for call_inst in calls_to_rewrite
            caller = LLVM.parent(LLVM.parent(call_inst))
            position!(builder, call_inst)

            # call the new function
            new_call = call!(builder, function_type(new_f), new_f,
                             arguments(call_inst))
            callconv!(new_call, callconv(call_inst))

            # split the block and branch based on the flag
            flag = extract_value!(builder, new_call, 0)
            error_block = BasicBlock(caller, "error")
            move_after(error_block, LLVM.parent(call_inst))
            continue_block = BasicBlock(caller, "continue")
            move_after(continue_block, error_block)
            br_inst = br!(builder, flag, error_block, continue_block)

            # extract the returned value in the continue block
            position!(builder, continue_block)
            if value_type(call_inst) != LLVM.VoidType()
                value = extract_value!(builder, new_call, 1)
                replace_uses!(call_inst, value)
            end
            @compiler_assert isempty(uses(call_inst)) job
            erase!(call_inst)

            # Move the remaining instructions (everything after the new branch,
            # including the old terminator) over to the continue block.
            # NOTE(review): successors of the original block now have `continue` as
            # predecessor; phi nodes in those successors are not updated here --
            # confirm this cannot occur or is handled elsewhere.
            while true
                inst = LLVM.nextinst(br_inst)
                inst === nothing && break
                remove!(inst)
                insert!(builder, inst)
            end

            # generate an unreachable in the error block
            position!(builder, error_block)
            unreachable!(builder)
        end
    end
    return
end

# Erase all calls to `llvm.trap` and the declaration itself, if the module has one.
# Returns whether any call was erased.
function _erase_trap!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
    haskey(functions(mod), "llvm.trap") || return false
    trap = functions(mod)["llvm.trap"]

    # snapshot the calls before erasing, so the use list isn't mutated mid-iteration
    trap_calls = LLVM.CallInst[]
    for use in uses(trap)
        val = user(use)
        val isa LLVM.CallInst && push!(trap_calls, val)
    end
    for call_inst in trap_calls
        erase!(call_inst)
    end

    @compiler_assert isempty(uses(trap)) job
    erase!(trap)
    return !isempty(trap_calls)
end