Skip to content

SPIR-V: Add a pre-optimization pass to convert unreachable to return. #709

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from

Conversation

maleadt
Copy link
Member

@maleadt maleadt commented Jul 30, 2025

This because SPIR-V doesn't have trap, meaning unreachable instructions get executed. That's UB, triggering issues with the PoCL driver (x-ref pocl/pocl#1971).

Sadly, this seems to be triggering bugs in the LLVM SPIR-V back-end, so looking into that.
EDIT: filed llvm/llvm-project#151344

Copy link
Contributor

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Click here to view the suggested changes.
diff --git a/src/spirv.jl b/src/spirv.jl
index 89bb54a..145685c 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -370,224 +370,238 @@ end
 # results in `OpUnreachable` actually getting executed, which is undefined behavior.
 # Instead, we transform unreachable instructions to returns with an error flag that's
 # checked by the caller.
-function lower_unreachable_to_return!(@nospecialize(job::CompilerJob),
-                                      mod::LLVM.Module, entry::LLVM.Function)
+function lower_unreachable_to_return!(
+        @nospecialize(job::CompilerJob),
+        mod::LLVM.Module, entry::LLVM.Function
+    )
     job = current_job::CompilerJob
     changed = false
     @tracepoint "lower unreachable to return" begin
 
-    already_transformed_functions = Set{LLVM.Function}()
+        already_transformed_functions = Set{LLVM.Function}()
 
-    # The pass runs until all unreachable instructions are transformed. During each
-    # iteration, we transform all unreachable instructions to returns, and transform all
-    # callers to handle the flag, generating a new unreachable when it is set.
-    while true
-        # Find all functions with unreachable instructions
-        functions_with_unreachable = Set{LLVM.Function}()
-        for f in functions(mod)
-            for bb in blocks(f), inst in instructions(bb)
-                if inst isa LLVM.UnreachableInst
-                    push!(functions_with_unreachable, f)
-                    break
-                end
-            end
-        end
-        isempty(functions_with_unreachable) && break
-
-        # Transform functions with unreachable to return a flag next to the original value
-        transformed_functions = Dict{LLVM.Function, LLVM.Function}()
-        for f in functions_with_unreachable
-            ft = function_type(f)
-            ret_type = return_type(ft)
-            fn = LLVM.name(f)
-
-            # in the case of the entry-point function, we cannot touch its type or returned
-            # value, so simply replace the unreachable with a return.
-            if f == entry
-                @compiler_assert ret_type == LLVM.VoidType() job
-
-                # find un reachables
-                unreachables = LLVM.Value[]
+        # The pass runs until all unreachable instructions are transformed. During each
+        # iteration, we transform all unreachable instructions to returns, and transform all
+        # callers to handle the flag, generating a new unreachable when it is set.
+        while true
+            # Find all functions with unreachable instructions
+            functions_with_unreachable = Set{LLVM.Function}()
+            for f in functions(mod)
                 for bb in blocks(f), inst in instructions(bb)
                     if inst isa LLVM.UnreachableInst
-                        push!(unreachables, inst)
+                        push!(functions_with_unreachable, f)
+                        break
                     end
                 end
+            end
+            isempty(functions_with_unreachable) && break
+
+            # Transform functions with unreachable to return a flag next to the original value
+            transformed_functions = Dict{LLVM.Function, LLVM.Function}()
+            for f in functions_with_unreachable
+                ft = function_type(f)
+                ret_type = return_type(ft)
+                fn = LLVM.name(f)
+
+                # in the case of the entry-point function, we cannot touch its type or returned
+                # value, so simply replace the unreachable with a return.
+                if f == entry
+                    @compiler_assert ret_type == LLVM.VoidType() job
+
+                    # find un reachables
+                    unreachables = LLVM.Value[]
+                    for bb in blocks(f), inst in instructions(bb)
+                        if inst isa LLVM.UnreachableInst
+                            push!(unreachables, inst)
+                        end
+                    end
 
-                # transform unreachable to return
-                @dispose builder=IRBuilder() begin
-                    for inst in unreachables
-                        position!(builder, inst)
-                        ret!(builder)
-                        erase!(inst)
+                    # transform unreachable to return
+                    @dispose builder = IRBuilder() begin
+                        for inst in unreachables
+                            position!(builder, inst)
+                            ret!(builder)
+                            erase!(inst)
+                        end
                     end
+
+                    continue
                 end
 
-                continue
-            end
+                # If this is the first time looking at this function, we need to change its type
+                if !in(f, already_transformed_functions)
+                    # Create new return type: {i1, original_type}
+                    new_ret_type = if ret_type == LLVM.VoidType()
+                        LLVM.StructType([LLVM.Int1Type()])
+                    else
+                        LLVM.StructType([LLVM.Int1Type(), ret_type])
+                    end
 
-            # If this is the first time looking at this function, we need to change its type
-            if !in(f, already_transformed_functions)
-                # Create new return type: {i1, original_type}
-                new_ret_type = if ret_type == LLVM.VoidType()
-                    LLVM.StructType([LLVM.Int1Type()])
-                else
-                    LLVM.StructType([LLVM.Int1Type(), ret_type])
-                end
+                    LLVM.name!(f, fn * ".old")
+                    new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
+                    new_f = LLVM.Function(mod, fn, new_ft)
+                    linkage!(new_f, linkage(f))
+                    for (i, param) in enumerate(parameters(f))
+                        LLVM.name!(parameters(new_f)[i], LLVM.name(param))
+                    end
+
+                    # clone the IR
+                    value_map = Dict{LLVM.Value, LLVM.Value}(
+                        param => parameters(new_f)[i] for (i, param) in enumerate(parameters(f))
+                    )
+                    clone_into!(
+                        new_f, f; value_map,
+                        changes = LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges
+                    )
+
+                    # rewrite return instructions
+                    returns = LLVM.Value[]
+                    for bb in blocks(new_f), inst in instructions(bb)
+                        if inst isa LLVM.RetInst
+                            push!(returns, inst)
+                        end
+                    end
+                    @dispose builder = IRBuilder() begin
+                        for inst in returns
+                            position!(builder, inst)
+                            if ret_type == LLVM.VoidType()
+                                # void function: return {false}
+                                flag_and_val =
+                                    insert_value!(
+                                    builder, UndefValue(new_ret_type),
+                                    ConstantInt(LLVM.Int1Type(), false), 0
+                                )
+                            else
+                                # non-void function: return {false, val}
+                                val = only(operands(inst))
+                                flag_and_val =
+                                    insert_value!(
+                                    builder, UndefValue(new_ret_type),
+                                    ConstantInt(LLVM.Int1Type(), false), 0
+                                )
+                                flag_and_val = insert_value!(builder, flag_and_val, val, 1)
+                            end
+                            ret!(builder, flag_and_val)
+                            erase!(inst)
+                        end
+                    end
 
-                LLVM.name!(f, fn * ".old")
-                new_ft = LLVM.FunctionType(new_ret_type, parameters(ft))
-                new_f = LLVM.Function(mod, fn, new_ft)
-                linkage!(new_f, linkage(f))
-                for (i, param) in enumerate(parameters(f))
-                    LLVM.name!(parameters(new_f)[i], LLVM.name(param))
+                    transformed_functions[f] = new_f
+                    push!(already_transformed_functions, new_f)
+                    f = new_f
                 end
 
-                # clone the IR
-                value_map = Dict{LLVM.Value, LLVM.Value}(
-                    param => parameters(new_f)[i] for (i,param) in enumerate(parameters(f))
-                )
-                clone_into!(new_f, f; value_map,
-                            changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
-
-                # rewrite return instructions
-                returns = LLVM.Value[]
-                for bb in blocks(new_f), inst in instructions(bb)
-                    if inst isa LLVM.RetInst
-                        push!(returns, inst)
+                # rewrite unreachable instructions
+                ret_type = return_type(function_type(f))
+                unreachables = LLVM.Value[]
+                for bb in blocks(f), inst in instructions(bb)
+                    if inst isa LLVM.UnreachableInst
+                        push!(unreachables, inst)
                     end
                 end
-                @dispose builder=IRBuilder() begin
-                    for inst in returns
+                @dispose builder = IRBuilder() begin
+                    for inst in unreachables
                         position!(builder, inst)
-                        if ret_type == LLVM.VoidType()
-                            # void function: return {false}
-                            flag_and_val =
-                                insert_value!(builder, UndefValue(new_ret_type),
-                                              ConstantInt(LLVM.Int1Type(), false), 0)
+                        if length(elements(ret_type)) == 1
+                            # void function: return {true}
+                            flag_and_val = insert_value!(
+                                builder, UndefValue(ret_type),
+                                ConstantInt(LLVM.Int1Type(), true), 0
+                            )
                         else
-                            # non-void function: return {false, val}
-                            val = only(operands(inst))
-                            flag_and_val =
-                                insert_value!(builder, UndefValue(new_ret_type),
-                                              ConstantInt(LLVM.Int1Type(), false), 0)
-                            flag_and_val = insert_value!(builder, flag_and_val, val, 1)
+                            # non-void function: return {true, undef}
+                            val_type = elements(ret_type)[2]
+                            flag_and_val = insert_value!(
+                                builder, UndefValue(ret_type),
+                                ConstantInt(LLVM.Int1Type(), true), 0
+                            )
+                            flag_and_val = insert_value!(
+                                builder, flag_and_val,
+                                UndefValue(val_type), 1
+                            )
                         end
                         ret!(builder, flag_and_val)
                         erase!(inst)
                     end
                 end
 
-                transformed_functions[f] = new_f
-                push!(already_transformed_functions, new_f)
-                f = new_f
+                changed = true
             end
 
-            # rewrite unreachable instructions
-            ret_type = return_type(function_type(f))
-            unreachables = LLVM.Value[]
-            for bb in blocks(f), inst in instructions(bb)
-                if inst isa LLVM.UnreachableInst
-                    push!(unreachables, inst)
-                end
-            end
-            @dispose builder=IRBuilder() begin
-                for inst in unreachables
-                    position!(builder, inst)
-                    if length(elements(ret_type)) == 1
-                        # void function: return {true}
-                        flag_and_val = insert_value!(builder, UndefValue(ret_type),
-                                                     ConstantInt(LLVM.Int1Type(), true), 0)
-                    else
-                        # non-void function: return {true, undef}
-                        val_type = elements(ret_type)[2]
-                        flag_and_val = insert_value!(builder, UndefValue(ret_type),
-                                                     ConstantInt(LLVM.Int1Type(), true), 0)
-                        flag_and_val = insert_value!(builder, flag_and_val,
-                                                     UndefValue(val_type), 1)
+            # Rewrite calls
+            for (old_f, new_f) in transformed_functions
+                calls_to_rewrite = LLVM.CallInst[]
+                for use in uses(old_f)
+                    call_inst = user(use)
+                    if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
+                        push!(calls_to_rewrite, call_inst)
                     end
-                    ret!(builder, flag_and_val)
-                    erase!(inst)
                 end
-            end
 
-            changed = true
-        end
+                @dispose builder = IRBuilder() begin
+                    for call_inst in calls_to_rewrite
+                        f = LLVM.parent(LLVM.parent(call_inst))
+                        position!(builder, call_inst)
+
+                        # Call the new function
+                        new_call = call!(builder, function_type(new_f), new_f, arguments(call_inst))
+                        callconv!(new_call, callconv(call_inst))
+
+                        # Split the block and branch based on the flag
+                        flag = extract_value!(builder, new_call, 0)
+                        error_block = BasicBlock(f, "error")
+                        move_after(error_block, LLVM.parent(call_inst))
+                        continue_block = BasicBlock(f, "continue")
+                        move_after(continue_block, error_block)
+                        br_inst = br!(builder, flag, error_block, continue_block)
+
+                        # Extract the returned value in the continue block
+                        position!(builder, continue_block)
+                        if value_type(call_inst) != LLVM.VoidType()
+                            value = extract_value!(builder, new_call, 1)
+                            replace_uses!(call_inst, value)
+                        end
+                        @compiler_assert isempty(uses(call_inst)) job
+                        erase!(call_inst)
+
+                        # Move the remaining instructions over to the continue block
+                        while true
+                            inst = LLVM.nextinst(br_inst)
+                            inst === nothing && break
+                            remove!(inst)
+                            insert!(builder, inst)
+                        end
 
-        # Rewrite calls
-        for (old_f, new_f) in transformed_functions
-            calls_to_rewrite = LLVM.CallInst[]
-            for use in uses(old_f)
-                call_inst = user(use)
-                if call_inst isa LLVM.CallInst && called_operand(call_inst) == old_f
-                    push!(calls_to_rewrite, call_inst)
+                        # Generate an unreachable in the error block
+                        position!(builder, error_block)
+                        unreachable!(builder)
+                    end
                 end
+
+                @compiler_assert isempty(uses(old_f)) job
+                erase!(old_f)
             end
+        end
 
-            @dispose builder=IRBuilder() begin
-                for call_inst in calls_to_rewrite
-                    f = LLVM.parent(LLVM.parent(call_inst))
-                    position!(builder, call_inst)
-
-                    # Call the new function
-                    new_call = call!(builder, function_type(new_f), new_f, arguments(call_inst))
-                    callconv!(new_call, callconv(call_inst))
-
-                    # Split the block and branch based on the flag
-                    flag = extract_value!(builder, new_call, 0)
-                    error_block = BasicBlock(f, "error")
-                    move_after(error_block, LLVM.parent(call_inst))
-                    continue_block = BasicBlock(f, "continue")
-                    move_after(continue_block, error_block)
-                    br_inst = br!(builder, flag, error_block, continue_block)
-
-                    # Extract the returned value in the continue block
-                    position!(builder, continue_block)
-                    if value_type(call_inst) != LLVM.VoidType()
-                        value = extract_value!(builder, new_call, 1)
-                        replace_uses!(call_inst, value)
-                    end
-                    @compiler_assert isempty(uses(call_inst)) job
-                    erase!(call_inst)
-
-                    # Move the remaining instructions over to the continue block
-                    while true
-                        inst = LLVM.nextinst(br_inst)
-                        inst === nothing && break
-                        remove!(inst)
-                        insert!(builder, inst)
-                    end
+        # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
+        if haskey(functions(mod), "llvm.trap")
+            trap = functions(mod)["llvm.trap"]
 
-                    # Generate an unreachable in the error block
-                    position!(builder, error_block)
-                    unreachable!(builder)
+            for use in uses(trap)
+                val = user(use)
+                if isa(val, LLVM.CallInst)
+                    erase!(val)
+                    changed = true
                 end
             end
 
-            @compiler_assert isempty(uses(old_f)) job
-            erase!(old_f)
+            @compiler_assert isempty(uses(trap)) job
+            erase!(trap)
         end
-    end
-
-    # Get rid of `llvm.trap` and `noreturn` to prevent reconstructing `unreachable`
-    if haskey(functions(mod), "llvm.trap")
-        trap = functions(mod)["llvm.trap"]
-
-        for use in uses(trap)
-            val = user(use)
-            if isa(val, LLVM.CallInst)
-                erase!(val)
-                changed = true
-            end
+        for f in functions(mod)
+            delete!(function_attributes(f), EnumAttribute("noreturn", 0))
         end
 
-        @compiler_assert isempty(uses(trap)) job
-        erase!(trap)
-    end
-    for f in functions(mod)
-        delete!(function_attributes(f), EnumAttribute("noreturn", 0))
-    end
-
     end
     return changed
 end

@simeonschaub
Copy link
Contributor

I'm running into the following issue with this PR:

julia> OpenCL.code_llvm((Int,); kernel = true) do x
           fldmod1(x, 10)
           nothing
       end
ERROR: LLVM error: Invalid struct return type!
{ i1 } ([2 x i64]*, i64, i64)* @julia_fldmod1_71344

Stacktrace:
  [1] verify(mod::LLVM.Module)
    @ LLVM ~/.julia/packages/LLVM/UFrs4/src/analysis.jl:19
  [2] finish_module!(job::GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, mod::LLVM.Module, entry::LLVM.Function)
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/spirv.jl:80
  [3] macro expansion
    @ ~/.julia/dev/GPUCompiler/src/driver.jl:183 [inlined]
  [4] emit_llvm(job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/utils.jl:116
  [5] emit_llvm(job::GPUCompiler.CompilerJob)
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/utils.jl:114
  [6] compile_unhooked(output::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:95
  [7] compile_unhooked
    @ ~/.julia/dev/GPUCompiler/src/driver.jl:80 [inlined]
  [8] compile(target::Symbol, job::GPUCompiler.CompilerJob; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:67
  [9] compile
    @ ~/.julia/dev/GPUCompiler/src/driver.jl:55 [inlined]
 [10] (::GPUCompiler.var"#186#187"{Bool, Symbol, Bool, GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, GPUCompiler.CompilerConfig{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}})(ctx::LLVM.Context)
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/reflection.jl:191
 [11] JuliaContext(f::GPUCompiler.var"#186#187"{Bool, Symbol, Bool, GPUCompiler.CompilerJob{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}, GPUCompiler.CompilerConfig{GPUCompiler.SPIRVCompilerTarget, OpenCL.OpenCLCompilerParams}}; kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:34
 [12] JuliaContext(f::Function)
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/driver.jl:25
 [13] code_llvm(io::Base.TTY, job::GPUCompiler.CompilerJob; optimize::Bool, raw::Bool, debuginfo::Symbol, dump_module::Bool, kwargs::@Kwargs{})
    @ GPUCompiler ~/.julia/dev/GPUCompiler/src/reflection.jl:190
 [14] code_llvm
    @ ~/.julia/dev/GPUCompiler/src/reflection.jl:186 [inlined]
 [15] code_llvm(io::Base.TTY, func::Any, types::Any; kernel::Bool, kwargs::@Kwargs{})
    @ OpenCL ~/.julia/dev/OpenCL/src/compiler/reflection.jl:33
 [16] code_llvm(func::Any, types::Any; kwargs::@Kwargs{kernel::Bool})
    @ OpenCL ~/.julia/dev/OpenCL/src/compiler/reflection.jl:35
 [17] top-level scope
    @ REPL[8]:1

@maleadt
Copy link
Member Author

maleadt commented Aug 11, 2025

It's not unlikely some corner cases are handled incorrectly, as I wasn't able to validate the change due to the LLVM SPIR-V back-end not supporting struct return. I'll try to take a look.

@maleadt
Copy link
Member Author

maleadt commented Aug 11, 2025

Actually, I just remembered pocl/pocl#1971 (comment) where it was determined that this approach isn't viable. We cannot have one thread return, because that could cause other threads to deadlock when the hit a barrier.

One option is to use a global flag that all callers check, but that would probably kill performance, as well as not cover the case where a call to an exception-throwing function isn't done by all threads:

if work_item() == 0
  fun_that_throws()
end
barrier() # deadlocks

So maybe the only way forwards is to require cl_arm_controlled_kernel_termination/arm_terminate_kernel, which PoCL doesn't implement yet (let alone other OpenCL/SPIR-V drivers...).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants