-
Notifications
You must be signed in to change notification settings - Fork 54
SPIR-V: Add a pre-optimization pass to convert unreachable to return. #709
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.
diff --git a/src/spirv.jl b/src/spirv.jl
index 89bb54a..145685c 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -370,224 +370,238 @@ end
# results in `OpUnreachable` actually getting executed, which is undefined behavior.
# Instead, we transform unreachable instructions to returns with an error flag that's
# checked by the caller.
-function lower_unreachable_to_return!(@nospecialize(job::CompilerJob),
- mod::LLVM.Module, entry::LLVM.Function)
+function lower_unreachable_to_return!(
+ @nospecialize(job::CompilerJob),
+ mod::LLVM.Module, entry::LLVM.Function
+ )
job = current_job::CompilerJob
changed = false
@tracepoint "lower unreachable to return" begin
- already_transformed_functions = Set{LLVM.Function}()
+ already_transformed_functions = Set{LLVM.Function}()
- # The pass runs until all unreachable instructions are transformed. During each
- # iteration, we transform all unreachable instructions to returns, and transform all
- # callers to handle the flag, generating a new unreachable when it is set.
- while true
- # Find all functions with unreachable instructions
- functions_with_unreachable = Set{LLVM.Function}()
- for f in functions(mod)
- for bb in blocks(f), inst in instructions(bb)
- if inst isa LLVM.UnreachableInst
- push!(functions_with_unreachable, f)
- break
- end
- end
- end
- isempty(functions_with_unreachable) && break
-
- # Transform functions with unreachable to return a flag next to the original value
- transformed_functions = Dict{LLVM.Function, LLVM.Function}()
- for f in functions_with_unreachable
- ft = function_type(f)
- ret_type = return_type(ft)
- fn = LLVM.name(f)
-
- # in the case of the entry-point function, we cannot touch its type or returned
- # value, so simply replace the unreachable with a return.
- if f == entry
- @compiler_assert ret_type == LLVM.VoidType() job
-
- # find un reachables
- unreachables = LLVM.Value[]
+ # The pass runs until all unreachable instructions are transformed. During each
+ # iteration, we transform all unreachable instructions to returns, and transform all
+ # callers to handle the flag, generating a new unreachable when it is set.
+ while true
+ # Find all functions with unreachable instructions
+ functions_with_unreachable = Set{LLVM.Function}()
+ for f in functions(mod)
for bb in blocks(f), inst in instructions(bb)
if inst isa LLVM.UnreachableInst
- push!(unreachables, inst)
+ push!(functions_with_unreachable, f)
+ break
end
end
+ end
+ isempty(functions_with_unreachable) && break
+
+ # Transform functions with unreachable to return a flag next to the original value
+ transformed_functions = Dict{LLVM.Function, LLVM.Function}()
+ for f in functions_with_unreachable
+ ft = function_type(f)
+ ret_type = return_type(ft)
+ fn = LLVM.name(f)
+
+ # in the case of the entry-point function, we cannot touch its type or returned
+ # value, so simply replace the unreachable with a return.
+ if f == entry
+ @compiler_assert ret_type == LLVM.VoidType() job
+
+ # find un reachables
+ unreachables = LLVM.Value[]
+ for bb in blocks(f), inst in instructions(bb)
+ if inst isa LLVM.UnreachableInst
+ push!(unreachables, inst)
+ end
+ end
- # transform unreachable to return
- @dispose builder=IRBuilder() begin
- for inst in unreachables
- position!(builder, inst)
- ret!(builder)
- erase!(inst)
+ # transform unreachable to return
+ @dispose builder = IRBuilder() begin
+ for inst in unreachables
+ position!(builder, inst)
+ ret!(builder)
+ erase!(inst)
+ end
end
+
+ continue
end
- continue
- end
+ # If this is the first time looking at this function, we need to change its type
+ if !in(f, already_transformed_functions)
+ # Create new return type: {i1, original_type}
+ new_ret_type = if ret_type == LLVM.VoidType()
+ LLVM.StructType([LLVM.Int1Type()])
+ else
+ LLVM.StructType([LLVM.Int1Type(), ret_type])
+ end
- # If this is the first time looking at this function, we need to change its type
- if !in(f, already_transformed_functions)
- # Create new return type: {i1, original_type}
- new_ret_type = if ret_type == LLVM.VoidType()
- LLVM.StructType([LLVM.Int1Type()])
- else
- LLVM.StructType([LLVM.Int1Type(), ret_type])
- end
+ LLVM.name!(f, fn * ".old")
+ new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
+ new_f = LLVM.Function(mod, fn, new_ft)
+ linkage!(new_f, linkage(f))
+ for (i, param) in enumerate(parameters(f))
+ LLVM.name!(parameters(new_f)[i], LLVM.name(param))
+ end
+
+ # clone the IR
+ value_map = Dict{LLVM.Value, LLVM.Value}(
+ param => parameters(new_f)[i] for (i, param) in enumerate(parameters(f))
+ )
+ clone_into!(
+ new_f, f; value_map,
+ changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
+ )
+
+ # rewrite return instructions
+ returns = LLVM.Value[]
+ for bb in blocks(new_f), inst in instructions(bb)
+ if inst isa LLVM.RetInst
+ push!(returns, inst)
+ end
+ end
+ @dispose builder = IRBuilder() begin
+ for inst in returns
+ position!(builder, inst)
+ if ret_type == LLVM.VoidType()
+ # void function: return {false}
+ flag_and_val =
+ insert_value!(
+ builder, UndefValue(new_ret_type),
+ ConstantInt(LLVM.Int1Type(), false), 0
+ )
+ else
+ # non-void function: return {false, val}
+ val = only(operands(inst))
+ flag_and_val =
+ insert_value!(
+ builder, UndefValue(new_ret_type),
+ ConstantInt(LLVM.Int1Type(), false), 0
+ )
+ flag_and_val = insert_value!(builder, flag_and_val, val, 1)
+ end
+ ret!(builder, flag_and_val)
+ erase!(inst)
+ end
+ end
- LLVM.name!(f, fn * ".old")
- new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
- new_f = LLVM.Function(mod, fn, new_ft)
- linkage!(new_f, linkage(f))
- for (i, param) in enumerate(parameters(f))
- LLVM.name!(parameters(new_f)[i], LLVM.name(param))
+ transformed_functions[f] = new_f
+ push!(already_transformed_functions, new_f)
+ f = new_f
end
- # clone the IR
- value_map = Dict{LLVM.Value, LLVM.Value}(
- param => parameters(new_f)[i] for (i,param) in enumerate(parameters(f))
- )
- clone_into!(new_f, f; value_map,
- changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
-
- # rewrite return instructions
- returns = LLVM.Value[]
- for bb in blocks(new_f), inst in instructions(bb)
- if inst isa LLVM.RetInst
- push!(returns, inst)
+ # rewrite unreachable instructions
+ ret_type = return_type(function_type(f))
+ unreachables = LLVM.Value[]
+ for bb in blocks(f), inst in instructions(bb)
+ if inst isa LLVM.UnreachableInst
+ push!(unreachables, inst)
end
end
- @dispose builder=IRBuilder() begin
- for inst in returns
+ @dispose builder = IRBuilder() begin
+ for inst in unreachables
position!(builder, inst)
- if ret_type == LLVM.VoidType()
- # void function: return {false}
- flag_and_val =
- insert_value!(builder, UndefValue(new_ret_type),
- ConstantInt(LLVM.Int1Type(), false), 0)
+ if length(elements(ret_type)) == 1
+ # void function: return {true}
+ flag_and_val = insert_value!(
+ builder, UndefValue(ret_type),
+ ConstantInt(LLVM.Int1Type(), true), 0
+ )
else
- # non-void function: return {false, val}
- val = only(operands(inst))
- flag_and_val =
- insert_value!(builder, UndefValue(new_ret_type),
- ConstantInt(LLVM.Int1Type(), false), 0)
- flag_and_val = insert_value!(builder, flag_and_val, val, 1)
+ # non-void function: return {true, undef}
+ val_type = elements(ret_type)[2]
+ flag_and_val = insert_value!(
+ builder, UndefValue(ret_type),
+ ConstantInt(LLVM.Int1Type(), true), 0
+ )
+ flag_and_val = insert_value!(
+ builder, flag_and_val,
+ UndefValue(val_type), 1
+ )
end
ret!(builder, flag_and_val)
erase!(inst)
end
end
- transformed_functions[f] = new_f
- push!(already_transformed_functions, new_f)
- f = new_f
+ changed = true
end
- # rewrite unreachable instructions
- ret_type = return_type(function_type(f))
- unreachables = LLVM.Value[]
- for bb in blocks(f), inst in instructions(bb)
- if inst isa LLVM.UnreachableInst
- push!(unreachables, inst)
- end
- end
- @dispose builder=IRBuilder() begin
- for inst in unreachables
- position!(builder, inst)
- if length(elements(ret_type)) == 1
- # void function: return {true}
- flag_and_val = insert_value!(builder, UndefValue(ret_type),
- ConstantInt(LLVM.Int1Type(), true), 0)
- else
- # non-void function: return {true, undef}
- val_type = elements(ret_type)[2]
- flag_and_val = insert_value!(builder, UndefValue(ret_type),
- ConstantInt(LLVM.Int1Type(), true), 0)
- flag_and_val = insert_value!(builder, flag_and_val,
- UndefValue(val_type), 1)
+ # Rewrite calls
+ for (old_f, new_f) in transformed_functions
+ calls_to_rewrite = LLVM.CallInst[]
+ for use in uses(old_f)
+ call_inst = user(use)
+ if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
+ push!(calls_to_rewrite, call_inst)
end
- ret!(builder, flag_and_val)
- erase!(inst)
end
- end
- changed = true
- end
+ @dispose builder = IRBuilder() begin
+ for call_inst in calls_to_rewrite
+ f = LLVM.parent(LLVM.parent(call_inst))
+ position!(builder, call_inst)
+
+ # Call the new function
+ new_call = call!(builder, function_type(new_f), new_f, arguments(call_inst))
+ callconv!(new_call, callconv(call_inst))
+
+ # Split the block and branch based on the flag
+ flag = extract_value!(builder, new_call, 0)
+ error_block = BasicBlock(f, "error")
+ move_after(error_block, LLVM.parent(call_inst))
+ continue_block = BasicBlock(f, "continue")
+ move_after(continue_block, error_block)
+ br_inst = br!(builder, flag, error_block, continue_block)
+
+ # Extract the returned value in the continue block
+ position!(builder, continue_block)
+ if value_type(call_inst) != LLVM.VoidType()
+ value = extract_value!(builder, new_call, 1)
+ replace_uses!(call_inst, value)
+ end
+ @compiler_assert isempty(uses(call_inst)) job
+ erase!(call_inst)
+
+ # Move the remaining instructions over to the continue block
+ while true
+ inst = LLVM.nextinst(br_inst)
+ inst === nothing && break
+ remove!(inst)
+ insert!(builder, inst)
+ end
- # Rewrite calls
- for (old_f, new_f) in transformed_functions
- calls_to_rewrite = LLVM.CallInst[]
- for use in uses(old_f)
- call_inst = user(use)
- if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
- push!(calls_to_rewrite, call_inst)
+ # Generate an unreachable in the error block
+ position!(builder, error_block)
+ unreachable!(builder)
+ end
end
+
+ @compiler_assert isempty(uses(old_f)) job
+ erase!(old_f)
end
+ end
- @dispose builder=IRBuilder() begin
- for call_inst in calls_to_rewrite
- f = LLVM.parent(LLVM.parent(call_inst))
- position!(builder, call_inst)
-
- # Call the new function
- new_call = call!(builder, function_type(new_f), new_f, arguments(call_inst))
- callconv!(new_call, callconv(call_inst))
-
- # Split the block and branch based on the flag
- flag = extract_value!(builder, new_call, 0)
- error_block = BasicBlock(f, "error")
- move_after(error_block, LLVM.parent(call_inst))
- continue_block = BasicBlock(f, "continue")
- move_after(continue_block, error_block)
- br_inst = br!(builder, flag, error_block, continue_block)
-
- # Extract the returned value in the continue block
- position!(builder, continue_block)
- if value_type(call_inst) != LLVM.VoidType()
- value = extract_value!(builder, new_call, 1)
- replace_uses!(call_inst, value)
- end
- @compiler_assert isempty(uses(call_inst)) job
- erase!(call_inst)
-
- # Move the remaining instructions over to the continue block
- while true
- inst = LLVM.nextinst(br_inst)
- inst === nothing && break
- remove!(inst)
- insert!(builder, inst)
- end
+ # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
+ if haskey(functions(mod), "llvm.trap")
+ trap = functions(mod)["llvm.trap"]
- # Generate an unreachable in the error block
- position!(builder, error_block)
- unreachable!(builder)
+ for use in uses(trap)
+ val = user(use)
+ if isa(val, LLVM.CallInst)
+ erase!(val)
+ changed = true
end
end
- @compiler_assert isempty(uses(old_f)) job
- erase!(old_f)
+ @compiler_assert isempty(uses(trap)) job
+ erase!(trap)
end
- end
-
- # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
- if haskey(functions(mod), "llvm.trap")
- trap = functions(mod)["llvm.trap"]
-
- for use in uses(trap)
- val = user(use)
- if isa(val, LLVM.CallInst)
- erase!(val)
- changed = true
- end
+ for f in functions(mod)
+ delete!(function_attributes(f), EnumAttribute("noreturn", 0))
end
- @compiler_assert isempty(uses(trap)) job
- erase!(trap)
- end
- for f in functions(mod)
- delete!(function_attributes(f), EnumAttribute("noreturn", 0))
- end
-
end
return changed
end |
I'm running into the following issue with this PR:
julia> OpenCL.code_llvm((Int,); kernel = true) do x
fldmod1(x, 10)
nothing
end
ERROR: LLVM error: Invalid struct return type!
{ i1 } ([2 x i64]*, i64, i64)* @julia_fldmod1_71344
Stacktrace:
[1] verify(mod::LLVM.Module)
@ LLVM ~/.julia/packages/LLVM/UFrs4/src/analysis.jl:19
[2] finish_module!(job::GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, mod::LLVM.Module, entry::LLVM.Function)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/spirv.jl:80
[3] macro expansion
@ ~/.julia/dev/GPUCompiler/src/driver.jl:183 [inlined]
[4] emit_llvm(job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/utils.jl:116
[5] emit_llvm(job::GPUCompiler.CompilerJob)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/utils.jl:114
[6] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:95
[7] compile_unhooked
@ ~/.julia/dev/GPUCompiler/src/driver.jl:80 [inlined]
[8] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:67
[9] compile
@ ~/.julia/dev/GPUCompiler/src/driver.jl:55 [inlined]
[10] (::GPUCompiler.var"#186#187"{Bool, Symbol, Bool, GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, GPUCompiler.CompilerConfig{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}})(ctx::LLVM.Context)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/reflection.jl:191
[11] JuliaContext(f::GPUCompiler.var"#186#187"{Bool, Symbol, Bool, GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, GPUCompiler.CompilerConfig{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}}; kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:34
[12] JuliaContext(f::Function)
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:25
[13] code_llvm(io::Base.TTY, job::GPUCompiler.CompilerJob; optimize::Bool, raw::Bool, debuginfo::Symbol, dump_module::Bool, kwargs::@Kwargs{})
@ GPUCompiler ~/.julia/dev/GPUCompiler/src/reflection.jl:190
[14] code_llvm
@ ~/.julia/dev/GPUCompiler/src/reflection.jl:186 [inlined]
[15] code_llvm(io::Base.TTY, func::Any, types::Any; kernel::Bool, kwargs::@Kwargs{})
@ OpenCL ~/.julia/dev/OpenCL/src/compiler/reflection.jl:33
[16] code_llvm(func::Any, types::Any; kwargs::@Kwargs{kernel::Bool})
@ OpenCL ~/.julia/dev/OpenCL/src/compiler/reflection.jl:35
[17] top-level scope
@ REPL[8]:1 |
It's not unlikely some corner cases are handled incorrectly, as I wasn't able to validate the change due to the LLVM SPIR-V back-end not supporting struct return. I'll try to take a look. |
Actually, I just remembered pocl/pocl#1971 (comment), where it was determined that this approach isn't viable. We cannot have one thread return early while the others keep running. One option is to use a global flag that all callers check, but that would probably kill performance, and it would not cover the case where a call to an exception-throwing function isn't made by all threads:

if work_item() == 0
    fun_that_throws()
end
barrier() # deadlocks

So maybe the only way forwards is to require … |
This is because SPIR-V doesn't have `trap`, meaning `unreachable` instructions get executed. That's UB, triggering issues with the PoCL driver (x-ref pocl/pocl#1971).

Sadly, this seems to be triggering bugs in the LLVM SPIR-V back-end, so looking into that.
EDIT: filed llvm/llvm-project#151344