Skip to content

Commit 45c4325

Browse files
authored
Merge pull request #524 from JuliaGPU/tb/ptx_trap
PTX: Improve handling of trap
2 parents 75765bf + 0523284 commit 45c4325

File tree

3 files changed

+15
-36
lines changed

3 files changed

+15
-36
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "GPUCompiler"
22
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
33
authors = ["Tim Besard <[email protected]>"]
4-
version = "0.24.5"
4+
version = "0.25.0"
55

66
[deps]
77
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"

src/precompile.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,6 @@ end
5252
function _precompile_()
5353
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
5454
@assert precompile(Tuple{typeof(GPUCompiler.assign_args!),Expr,Vector{Any}})
55-
@assert precompile(Tuple{typeof(GPUCompiler.lower_trap!),LLVM.Module})
5655
@assert precompile(Tuple{typeof(GPUCompiler.lower_unreachable!),LLVM.Function})
5756
@assert precompile(Tuple{typeof(GPUCompiler.lower_gc_frame!),LLVM.Function})
5857
@assert precompile(Tuple{typeof(GPUCompiler.lower_throw!),LLVM.Module})

src/ptx.jl

Lines changed: 14 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,6 @@ end
187187

188188
function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
189189
mod::LLVM.Module, entry::LLVM.Function)
190-
lower_trap!(mod)
191190
for f in functions(mod)
192191
lower_unreachable!(f)
193192
end
@@ -246,36 +245,6 @@ end
246245

247246
## LLVM passes
248247

249-
# replace calls to `trap` with inline assembly calling `exit`, which isn't fatal
250-
function lower_trap!(mod::LLVM.Module)
251-
job = current_job::CompilerJob
252-
changed = false
253-
@timeit_debug to "lower trap" begin
254-
255-
if haskey(functions(mod), "llvm.trap")
256-
trap = functions(mod)["llvm.trap"]
257-
258-
# inline assembly to exit a thread
259-
exit_ft = LLVM.FunctionType(LLVM.VoidType())
260-
exit = InlineAsm(exit_ft, "exit;", "", true)
261-
262-
for use in uses(trap)
263-
val = user(use)
264-
if isa(val, LLVM.CallInst)
265-
@dispose builder=IRBuilder() begin
266-
position!(builder, val)
267-
call!(builder, exit_ft, exit)
268-
end
269-
unsafe_delete!(LLVM.parent(val), val)
270-
changed = true
271-
end
272-
end
273-
end
274-
275-
end
276-
return changed
277-
end
278-
279248
# lower `unreachable` to `trap` + `exit` so that the emitted PTX has correct control flow
280249
#
281250
# During back-end compilation, `ptxas` inserts instructions to manage the hardware's
@@ -328,10 +297,14 @@ end
328297
# `bar.sync` cannot be executed divergently on Pascal hardware or earlier.
329298
#
330299
# To avoid these fall-through successors that change the control flow,
331-
# we replace `unreachable` instructions with a call to `exit`. This informs
332-
# `ptxas` that the thread exits, and allows it to correctly construct a CFG,
333-
# and consequently correctly determine the divergence regions as intended.
300+
# we replace `unreachable` instructions with calls to `trap` and `exit`. This
301+
# informs `ptxas` that the thread exits, and allows it to correctly construct a
302+
# CFG, and consequently correctly determine the divergence regions as intended.
303+
# Note that we first emit a call to `trap`, so that the behaviour is the same
304+
# as before.
334305
function lower_unreachable!(f::LLVM.Function)
306+
mod = LLVM.parent(f)
307+
335308
# TODO:
336309
# - if unreachable blocks have been merged, we still may be jumping from different
337310
# divergent regions, potentially causing the same problem as above:
@@ -375,6 +348,12 @@ function lower_unreachable!(f::LLVM.Function)
375348
# inline assembly to exit a thread
376349
exit_ft = LLVM.FunctionType(LLVM.VoidType())
377350
exit = InlineAsm(exit_ft, "exit;", "", true)
351+
trap_ft = LLVM.FunctionType(LLVM.VoidType())
352+
trap = if haskey(functions(mod), "llvm.trap")
353+
functions(mod)["llvm.trap"]
354+
else
355+
LLVM.Function(mod, "llvm.trap", trap_ft)
356+
end
378357

379358
# rewrite the unreachable terminators
380359
@dispose builder=IRBuilder() begin
@@ -384,6 +363,7 @@ function lower_unreachable!(f::LLVM.Function)
384363
@assert inst isa LLVM.UnreachableInst
385364

386365
position!(builder, inst)
366+
call!(builder, trap_ft, trap)
387367
call!(builder, exit_ft, exit)
388368
end
389369
end

0 commit comments

Comments
 (0)