Skip to content

Commit 2b4b973

Browse files
committed
Emit trap before exit when lowering unreachable instructions.
The `exit` is only inserted so that `ptxas` constructs the CFG correctly; it should never actually be executed. Make sure to emit a `trap` before it, so that reaching an `unreachable` behaves the same as it did before.
1 parent 631a6bb commit 2b4b973

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

src/ptx.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -297,10 +297,14 @@ end
297297
# `bar.sync` cannot be executed divergently on Pascal hardware or earlier.
298298
#
299299
# To avoid these fall-through successors that change the control flow,
300-
# we replace `unreachable` instructions with a call to `exit`. This informs
301-
# `ptxas` that the thread exits, and allows it to correctly construct a CFG,
302-
# and consequently correctly determine the divergence regions as intended.
300+
# we replace `unreachable` instructions with a call to `trap` followed by `exit`. The
301+
# `exit` informs `ptxas` that the thread exits, and allows it to correctly construct a
302+
# CFG, and consequently correctly determine the divergence regions as intended.
303+
# Note that the call to `trap` is emitted first, so that the `exit` is never actually
304+
# executed and reaching an `unreachable` behaves the same as it did before this lowering.
303305
function lower_unreachable!(f::LLVM.Function)
306+
mod = LLVM.parent(f)
307+
304308
# TODO:
305309
# - if unreachable blocks have been merged, we still may be jumping from different
306310
# divergent regions, potentially causing the same problem as above:
@@ -344,6 +348,12 @@ function lower_unreachable!(f::LLVM.Function)
344348
# inline assembly to exit a thread
345349
exit_ft = LLVM.FunctionType(LLVM.VoidType())
346350
exit = InlineAsm(exit_ft, "exit;", "", true)
351+
trap_ft = LLVM.FunctionType(LLVM.VoidType())
352+
trap = if haskey(functions(mod), "llvm.trap")
353+
functions(mod)["llvm.trap"]
354+
else
355+
LLVM.Function(mod, "llvm.trap", trap_ft)
356+
end
347357

348358
# rewrite the unreachable terminators
349359
@dispose builder=IRBuilder() begin
@@ -353,6 +363,7 @@ function lower_unreachable!(f::LLVM.Function)
353363
@assert inst isa LLVM.UnreachableInst
354364

355365
position!(builder, inst)
366+
call!(builder, trap_ft, trap)
356367
call!(builder, exit_ft, exit)
357368
end
358369
end

0 commit comments

Comments
 (0)