Skip to content

SPIR-V: Add a pre-optimization pass to convert unreachable to return. #709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/ptx.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
mod::LLVM.Module, entry::LLVM.Function)
if LLVM.version() < v"17"
for f in functions(mod)
lower_unreachable!(f)
lower_unreachable_to_exit!(f)
end
end

Expand Down Expand Up @@ -308,7 +308,7 @@ end
# CFG, and consequently correctly determine the divergence regions as intended.
# Note that we first emit a call to `trap`, so that the behaviour is the same
# as before.
function lower_unreachable!(f::LLVM.Function)
function lower_unreachable_to_exit!(f::LLVM.Function)
mod = LLVM.parent(f)

# TODO:
Expand Down
233 changes: 233 additions & 0 deletions src/spirv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ function finish_module!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
# they do support struct byval, for OpenCL, so wrap byval parameters in a struct.
if job.config.kernel
entry = wrap_byval(job, mod, entry)
lower_unreachable_to_return!(job, mod, entry)
verify(mod)
end

# add module metadata
Expand Down Expand Up @@ -358,3 +360,234 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F

return new_f
end


## LLVM IR passes

# lower unreachable instructions to returns with error flags
#
# SPIR-V does not have a trap instruction, so the common trap + unreachable sequence
# results in `OpUnreachable` actually getting executed, which is undefined behavior.
# Instead, we transform unreachable instructions to returns with an error flag that's
# checked by the caller.
#
# Arguments:
# - `job`:   the compiler job, only used for reporting failed compiler assertions.
# - `mod`:   the module to transform in place.
# - `entry`: the entry-point function. Its signature is left untouched; unreachable
#            instructions in it are replaced by a plain `ret void`.
#
# Every other function containing an `unreachable` is replaced by a clone returning
# `{i1}` (originally `void`) or `{i1, T}` (originally returning `T`), where the `i1` is
# set when the function bailed out. Calls to such functions are rewritten to branch on
# that flag. Afterwards, `llvm.trap` and its calls are removed, and the `noreturn`
# function attribute is dropped from every function in the module.
#
# Returns `true` if the module was modified by rewriting an `unreachable` or removing a
# call to `llvm.trap`, and `false` otherwise.
function lower_unreachable_to_return!(@nospecialize(job::CompilerJob),
                                      mod::LLVM.Module, entry::LLVM.Function)
    changed = false

    # snapshot all instructions of type `T` in `f`, so that the caller can safely erase
    # or replace them without invalidating the iteration over the function body.
    insts_of_type(f, T) =
        LLVM.Value[inst for bb in blocks(f) for inst in instructions(bb) if inst isa T]

    @tracepoint "lower unreachable to return" begin

        # functions that already return the `{i1, ...}` flag struct
        already_transformed_functions = Set{LLVM.Function}()

        # The pass runs until all unreachable instructions are transformed. During each
        # iteration, we transform all unreachable instructions to returns, and transform
        # all callers to handle the flag, generating a new unreachable when it is set.
        while true
            # Find all functions with unreachable instructions
            functions_with_unreachable = Set{LLVM.Function}()
            for f in functions(mod)
                if any(inst isa LLVM.UnreachableInst
                       for bb in blocks(f) for inst in instructions(bb))
                    push!(functions_with_unreachable, f)
                end
            end
            isempty(functions_with_unreachable) && break

            # Transform functions with unreachable to return a flag next to the original
            # value. Maps each replaced function to its replacement.
            transformed_functions = Dict{LLVM.Function, LLVM.Function}()
            for f in functions_with_unreachable
                ft = function_type(f)
                ret_type = return_type(ft)
                fn = LLVM.name(f)

                # in the case of the entry-point function, we cannot touch its type or
                # returned value, so simply replace the unreachable with a return.
                if f == entry
                    @compiler_assert ret_type == LLVM.VoidType() job

                    @dispose builder=IRBuilder() begin
                        for inst in insts_of_type(f, LLVM.UnreachableInst)
                            position!(builder, inst)
                            ret!(builder)
                            erase!(inst)
                        end
                    end

                    # the entry point was modified, so report that to the caller
                    changed = true
                    continue
                end

                # If this is the first time looking at this function, we need to change
                # its type
                if !in(f, already_transformed_functions)
                    # Create new return type: {i1, original_type}
                    new_ret_type = if ret_type == LLVM.VoidType()
                        LLVM.StructType([LLVM.Int1Type()])
                    else
                        LLVM.StructType([LLVM.Int1Type(), ret_type])
                    end

                    # free up the original name so that the replacement can take it
                    LLVM.name!(f, fn * ".old")
                    new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
                    new_f = LLVM.Function(mod, fn, new_ft)
                    linkage!(new_f, linkage(f))
                    for (i, param) in enumerate(parameters(f))
                        LLVM.name!(parameters(new_f)[i], LLVM.name(param))
                    end

                    # clone the IR
                    value_map = Dict{LLVM.Value, LLVM.Value}(
                        param => parameters(new_f)[i]
                        for (i, param) in enumerate(parameters(f))
                    )
                    clone_into!(new_f, f; value_map,
                                changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)

                    # rewrite return instructions: a regular return clears the flag
                    @dispose builder=IRBuilder() begin
                        for inst in insts_of_type(new_f, LLVM.RetInst)
                            position!(builder, inst)

                            # void function: return {false}
                            flag_and_val =
                                insert_value!(builder, UndefValue(new_ret_type),
                                              ConstantInt(LLVM.Int1Type(), false), 0)
                            if ret_type != LLVM.VoidType()
                                # non-void function: return {false, val}
                                val = only(operands(inst))
                                flag_and_val =
                                    insert_value!(builder, flag_and_val, val, 1)
                            end

                            ret!(builder, flag_and_val)
                            erase!(inst)
                        end
                    end

                    transformed_functions[f] = new_f
                    push!(already_transformed_functions, new_f)
                    f = new_f
                end

                # rewrite unreachable instructions: bailing out sets the flag. At this
                # point `f` always returns the flag struct.
                ret_type = return_type(function_type(f))
                @dispose builder=IRBuilder() begin
                    for inst in insts_of_type(f, LLVM.UnreachableInst)
                        position!(builder, inst)

                        # void function: return {true}
                        flag_and_val =
                            insert_value!(builder, UndefValue(ret_type),
                                          ConstantInt(LLVM.Int1Type(), true), 0)
                        if length(elements(ret_type)) != 1
                            # non-void function: return {true, undef}
                            val_type = elements(ret_type)[2]
                            flag_and_val = insert_value!(builder, flag_and_val,
                                                         UndefValue(val_type), 1)
                        end

                        ret!(builder, flag_and_val)
                        erase!(inst)
                    end
                end

                changed = true
            end

            # Rewrite calls
            for (old_f, new_f) in transformed_functions
                # collect first, as rewriting modifies the use list of `old_f`
                calls_to_rewrite = LLVM.CallInst[]
                for use in uses(old_f)
                    call_inst = user(use)
                    if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
                        push!(calls_to_rewrite, call_inst)
                    end
                end

                @dispose builder=IRBuilder() begin
                    for call_inst in calls_to_rewrite
                        caller = LLVM.parent(LLVM.parent(call_inst))
                        position!(builder, call_inst)

                        # Call the new function
                        new_call = call!(builder, function_type(new_f), new_f,
                                         arguments(call_inst))
                        callconv!(new_call, callconv(call_inst))

                        # Split the block and branch based on the flag
                        # NOTE(review): PHI nodes in successor blocks presumably keep
                        # referring to the original block even though its terminator
                        # ends up in the "continue" block -- confirm that this cannot
                        # happen for the IR we see here, or patch the incoming blocks.
                        flag = extract_value!(builder, new_call, 0)
                        error_block = BasicBlock(caller, "error")
                        move_after(error_block, LLVM.parent(call_inst))
                        continue_block = BasicBlock(caller, "continue")
                        move_after(continue_block, error_block)
                        br_inst = br!(builder, flag, error_block, continue_block)

                        # Extract the returned value in the continue block
                        position!(builder, continue_block)
                        if value_type(call_inst) != LLVM.VoidType()
                            value = extract_value!(builder, new_call, 1)
                            replace_uses!(call_inst, value)
                        end
                        @compiler_assert isempty(uses(call_inst)) job
                        erase!(call_inst)

                        # Move the remaining instructions over to the continue block
                        while true
                            inst = LLVM.nextinst(br_inst)
                            inst === nothing && break
                            remove!(inst)
                            insert!(builder, inst)
                        end

                        # Generate an unreachable in the error block; the next iteration
                        # of the outer loop lowers it in the context of the caller.
                        position!(builder, error_block)
                        unreachable!(builder)
                    end
                end

                @compiler_assert isempty(uses(old_f)) job
                erase!(old_f)
            end
        end

        # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
        if haskey(functions(mod), "llvm.trap")
            trap = functions(mod)["llvm.trap"]

            # collect first, as erasing a call modifies the use list of `trap`
            trap_calls = LLVM.CallInst[]
            for use in uses(trap)
                val = user(use)
                val isa LLVM.CallInst && push!(trap_calls, val)
            end
            for call in trap_calls
                erase!(call)
                changed = true
            end

            @compiler_assert isempty(uses(trap)) job
            erase!(trap)
        end
        for f in functions(mod)
            delete!(function_attributes(f), EnumAttribute("noreturn", 0))
        end

    end
    return changed
end
Loading