Skip to content

Commit 88bf4fd

Browse files
authored
PTX: Lower unreachable control flow to avoid bad CFG reconstruction (#467)
1 parent aaaf1de commit 88bf4fd

File tree

2 files changed

+137
-143
lines changed

2 files changed

+137
-143
lines changed

src/precompile.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,8 @@ end
5252
function _precompile_()
5353
ccall(:jl_generating_output, Cint, ()) == 1 || return nothing
5454
@assert precompile(Tuple{typeof(GPUCompiler.assign_args!),Expr,Vector{Any}})
55-
@assert precompile(Tuple{typeof(GPUCompiler.hide_trap!),LLVM.Module})
56-
@assert precompile(Tuple{typeof(GPUCompiler.hide_unreachable!),LLVM.Function})
55+
@assert precompile(Tuple{typeof(GPUCompiler.lower_trap!),LLVM.Module})
56+
@assert precompile(Tuple{typeof(GPUCompiler.lower_unreachable!),LLVM.Function})
5757
@assert precompile(Tuple{typeof(GPUCompiler.lower_gc_frame!),LLVM.Function})
5858
@assert precompile(Tuple{typeof(GPUCompiler.lower_throw!),LLVM.Module})
5959
#@assert precompile(Tuple{typeof(GPUCompiler.split_kwargs),Tuple{},Vector{Symbol},Vararg{Vector{Symbol}, N} where N})

src/ptx.jl

Lines changed: 135 additions & 141 deletions
Original file line numberDiff line numberDiff line change
@@ -11,25 +11,23 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
1111
# codegen quirks
1212
## can we emit debug info in the PTX assembly?
1313
debuginfo::Bool = false
14-
## do we permit unreachable statements, which often result in divergent control flow?
15-
unreachable::Bool = false
16-
## can exceptions use `exit` (which doesn't kill the GPU), or should they use `trap`?
17-
exitable::Bool = false
1814

1915
# optional properties
2016
minthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
2117
maxthreads::Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
2218
blocks_per_sm::Union{Nothing,Int} = nothing
2319
maxregs::Union{Nothing,Int} = nothing
20+
21+
# deprecated; remove with next major version
22+
exitable::Union{Nothing,Bool} = nothing
23+
unreachable::Union{Nothing,Bool} = nothing
2424
end
2525

2626
function Base.hash(target::PTXCompilerTarget, h::UInt)
2727
h = hash(target.cap, h)
2828
h = hash(target.ptx, h)
2929

3030
h = hash(target.debuginfo, h)
31-
h = hash(target.unreachable, h)
32-
h = hash(target.exitable, h)
3331

3432
h = hash(target.minthreads, h)
3533
h = hash(target.maxthreads, h)
@@ -92,8 +90,7 @@ isintrinsic(@nospecialize(job::CompilerJob{PTXCompilerTarget}), fn::String) =
9290
# XXX: the debuginfo part should be handled by GPUCompiler as it applies to all back-ends.
9391
runtime_slug(@nospecialize(job::CompilerJob{PTXCompilerTarget})) =
9492
"ptx-sm_$(job.config.target.cap.major)$(job.config.target.cap.minor)" *
95-
"-debuginfo=$(Int(llvm_debug_info(job)))" *
96-
"-exitable=$(job.config.target.exitable)"
93+
"-debuginfo=$(Int(llvm_debug_info(job)))"
9794

9895
function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
9996
mod::LLVM.Module, entry::LLVM.Function)
@@ -132,14 +129,6 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
132129
end
133130

134131
@dispose pm=ModulePassManager() begin
135-
# hide `unreachable` from LLVM so that it doesn't introduce divergent control flow
136-
if !job.config.target.unreachable
137-
add!(pm, FunctionPass("HideUnreachable", hide_unreachable!))
138-
end
139-
140-
# even if we support `unreachable`, we still prefer `exit` to `trap`
141-
add!(pm, ModulePass("HideTrap", hide_trap!))
142-
143132
# we emit properties (of the device and ptx isa) as private global constants,
144133
# so run the optimizer so that they are inlined before the rest of the optimizer runs.
145134
global_optimizer!(pm)
@@ -188,6 +177,13 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
188177
mod::LLVM.Module, entry::LLVM.Function)
189178
ctx = context(mod)
190179

180+
@dispose pm=ModulePassManager() begin
181+
add!(pm, ModulePass("LowerTrap", lower_trap!))
182+
add!(pm, FunctionPass("LowerUnreachable", lower_unreachable!))
183+
184+
run!(pm, mod)
185+
end
186+
191187
if job.config.kernel
192188
# add metadata annotations for the assembler to the module
193189

@@ -242,111 +238,29 @@ end
242238

243239
## LLVM passes
244240

245-
# HACK: this pass removes `unreachable` information from LLVM
246-
#
247-
# `ptxas` is buggy and cannot deal with thread-divergent control flow in the presence of
248-
# shared memory (see JuliaGPU/CUDAnative.jl#4). avoid that by rewriting control flow to fall
249-
# through any other block. this is semantically invalid, but the code is unreachable anyhow
250-
# (and we expect it to be preceded by eg. a noreturn function, or a trap).
251-
#
252-
# TODO: can LLVM do this with structured CFGs? It seems to have some support, but seemingly
253-
# only to prevent introducing non-structureness during optimization (ie. the front-end
254-
# is still responsible for generating structured control flow).
255-
function hide_unreachable!(fun::LLVM.Function)
241+
# replace calls to `trap` with inline assembly calling `exit`, which isn't fatal
242+
function lower_trap!(mod::LLVM.Module)
256243
job = current_job::CompilerJob
257-
ctx = context(fun)
244+
ctx = context(mod)
258245
changed = false
259-
@timeit_debug to "hide unreachable" begin
246+
@timeit_debug to "lower trap" begin
260247

261-
# remove `noreturn` attributes
262-
#
263-
# when calling a `noreturn` function, LLVM places an `unreachable` after the call.
264-
# this leads to an early `ret` from the function.
265-
attrs = function_attributes(fun)
266-
delete!(attrs, EnumAttribute("noreturn", 0; ctx))
248+
if haskey(functions(mod), "llvm.trap")
249+
trap = functions(mod)["llvm.trap"]
267250

268-
# build a map of basic block predecessors
269-
predecessors = Dict(bb => Set{LLVM.BasicBlock}() for bb in blocks(fun))
270-
@timeit_debug to "predecessors" for bb in blocks(fun)
271-
insts = instructions(bb)
272-
if !isempty(insts)
273-
inst = last(insts)
274-
if isterminator(inst)
275-
for bb′ in successors(inst)
276-
push!(predecessors[bb′], bb)
277-
end
278-
end
279-
end
280-
end
251+
# inline assembly to exit a thread
252+
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
253+
exit = InlineAsm(exit_ft, "exit;", "", true)
281254

282-
# scan for unreachable terminators and alternative successors
283-
worklist = Pair{LLVM.BasicBlock, Union{Nothing,LLVM.BasicBlock}}[]
284-
@timeit_debug to "find" for bb in blocks(fun)
285-
unreachable = terminator(bb)
286-
if isa(unreachable, LLVM.UnreachableInst)
287-
unsafe_delete!(bb, unreachable)
288-
changed = true
289-
290-
try
291-
terminator(bb)
292-
# the basic-block is still terminated properly, nothing to do
293-
# (this can happen with `ret; unreachable`)
294-
# TODO: `unreachable; unreachable`
295-
catch ex
296-
isa(ex, UndefRefError) || rethrow(ex)
255+
for use in uses(trap)
256+
val = user(use)
257+
if isa(val, LLVM.CallInst)
297258
@dispose builder=IRBuilder(ctx) begin
298-
position!(builder, bb)
299-
300-
# find the strict predecessors to this block
301-
preds = collect(predecessors[bb])
302-
303-
# find a fallthrough block: recursively look at predecessors
304-
# and find a successor that branches to any other block
305-
fallthrough = nothing
306-
while !isempty(preds)
307-
# find an alternative successor
308-
for pred in preds, succ in successors(terminator(pred))
309-
if succ != bb
310-
fallthrough = succ
311-
break
312-
end
313-
end
314-
fallthrough === nothing || break
315-
316-
# recurse upwards
317-
old_preds = copy(preds)
318-
empty!(preds)
319-
for pred in old_preds
320-
append!(preds, predecessors[pred])
321-
end
322-
end
323-
push!(worklist, bb => fallthrough)
324-
end
325-
end
326-
end
327-
end
328-
329-
# apply the pending terminator rewrites
330-
@timeit_debug to "replace" if !isempty(worklist)
331-
let builder = IRBuilder(ctx)
332-
for (bb, fallthrough) in worklist
333-
position!(builder, bb)
334-
if fallthrough !== nothing
335-
br!(builder, fallthrough)
336-
else
337-
# couldn't find any other successor. this happens with functions
338-
# that only contain a single block, or when the block is dead.
339-
ft = function_type(fun)
340-
if return_type(ft) == LLVM.VoidType(ctx)
341-
# even though returning can lead to invalid control flow,
342-
# it mostly happens with functions that just throw,
343-
# and leaving the unreachable there would make the optimizer
344-
# place another after the call.
345-
ret!(builder)
346-
else
347-
unreachable!(builder)
348-
end
259+
position!(builder, val)
260+
call!(builder, exit_ft, exit)
349261
end
262+
unsafe_delete!(LLVM.parent(val), val)
263+
changed = true
350264
end
351265
end
352266
end
@@ -355,41 +269,121 @@ function hide_unreachable!(fun::LLVM.Function)
355269
return changed
356270
end
357271

358-
# HACK: this pass removes calls to `trap` and replaces them with inline assembly
272+
# lower `unreachable` to `exit` so that the emitted PTX has correct control flow
359273
#
360-
# if LLVM knows we're trapping, code is marked `unreachable` (see `hide_unreachable!`).
361-
function hide_trap!(mod::LLVM.Module)
362-
job = current_job::CompilerJob
363-
ctx = context(mod)
364-
changed = false
365-
@timeit_debug to "hide trap" begin
274+
# During back-end compilation, `ptxas` inserts instructions to manage the hardware's
275+
# reconvergence stack (SSY and SYNC). In order to do so, it needs to identify
276+
# divergent regions:
277+
#
278+
# entry:
279+
# // start of divergent region
280+
# @%p0 bra cont;
281+
# ...
282+
# bra.uni cont;
283+
# cont:
284+
# // end of divergent region
285+
# bar.sync 0;
286+
#
287+
# Meanwhile, LLVM's branch-folder and block-placement MIR passes will try to optimize
288+
# the block layout, e.g., by placing unlikely blocks at the end of the function:
289+
#
290+
# entry:
291+
# // start of divergent region
292+
# @%p0 bra cont;
293+
# @%p1 bra unlikely;
294+
# bra.uni cont;
295+
# cont:
296+
# // end of divergent region
297+
# bar.sync 0;
298+
# unlikely:
299+
# bra.uni cont;
300+
#
301+
# That is not a problem as long as the unlikely block continues back into the
302+
# divergent region. Crucially, this is not the case with unreachable control flow:
303+
#
304+
# entry:
305+
# // start of divergent region
306+
# @%p0 bra cont;
307+
# @%p1 bra throw;
308+
# bra.uni cont;
309+
# cont:
310+
# bar.sync 0;
311+
# throw:
312+
# call throw_and_trap();
313+
# // unreachable
314+
# exit:
315+
# // end of divergent region
316+
# ret;
317+
#
318+
# Dynamically, this is fine, because the called function does not return.
319+
# However, `ptxas` does not know that and adds a successor edge to the `exit`
320+
# block, widening the divergence range. In this example, that's not allowed, as
321+
# `bar.sync` cannot be executed divergently on Pascal hardware or earlier.
322+
#
323+
# To avoid these fall-through successors that change the control flow,
324+
# we replace `unreachable` instructions with a call to `exit`. This informs
325+
# `ptxas` that the thread exits, and allows it to correctly construct a CFG,
326+
# and consequently correctly determine the divergence regions as intended.
327+
function lower_unreachable!(f::LLVM.Function)
328+
ctx = context(f)
329+
330+
# TODO:
331+
# - if unreachable blocks have been merged, we still may be jumping from different
332+
# divergent regions, potentially causing the same problem as above:
333+
# entry:
334+
# // start of divergent region 1
335+
# @%p0 bra cont1;
336+
# @%p1 bra throw;
337+
# bra.uni cont1;
338+
# cont1:
339+
# // end of divergent region 1
340+
# bar.sync 0; // is this executed divergently?
341+
# // start of divergent region 2
342+
# @%p2 bra cont2;
343+
# @%p3 bra throw;
344+
# bra.uni cont2;
345+
# cont2:
346+
# // end of divergent region 2
347+
# ...
348+
# throw:
349+
# trap;
350+
# br throw;
351+
# if this is a problem, we probably need to clone blocks with multiple
352+
# predecessors so that there's a unique path from each region of
353+
# divergence to every `unreachable` terminator
354+
355+
# remove `noreturn` attributes, to avoid the (minimal) optimization that
356+
# happens during `prepare_execution!` undoing our work here.
357+
# this shouldn't be needed when we upstream the pass.
358+
attrs = function_attributes(f)
359+
delete!(attrs, EnumAttribute("noreturn", 0; ctx))
366360

367-
# inline assembly to exit a thread, hiding control flow from LLVM
368-
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
369-
exit = if job.config.target.exitable
370-
InlineAsm(exit_ft, "exit;", "", true)
371-
else
372-
InlineAsm(exit_ft, "trap;", "", true)
361+
# find unreachable blocks
362+
unreachable_blocks = BasicBlock[]
363+
for block in blocks(f)
364+
if terminator(block) isa LLVM.UnreachableInst
365+
push!(unreachable_blocks, block)
366+
end
373367
end
368+
isempty(unreachable_blocks) && return false
374369

375-
if haskey(functions(mod), "llvm.trap")
376-
trap = functions(mod)["llvm.trap"]
370+
# inline assembly to exit a thread
371+
exit_ft = LLVM.FunctionType(LLVM.VoidType(ctx))
372+
exit = InlineAsm(exit_ft, "exit;", "", true)
377373

378-
for use in uses(trap)
379-
val = user(use)
380-
if isa(val, LLVM.CallInst)
381-
@dispose builder=IRBuilder(ctx) begin
382-
position!(builder, val)
383-
call!(builder, exit_ft, exit)
384-
end
385-
unsafe_delete!(LLVM.parent(val), val)
386-
changed = true
387-
end
374+
# rewrite the unreachable terminators
375+
@dispose builder=IRBuilder(ctx) begin
376+
entry_block = first(blocks(f))
377+
for block in unreachable_blocks
378+
inst = terminator(block)
379+
@assert inst isa LLVM.UnreachableInst
380+
381+
position!(builder, inst)
382+
call!(builder, exit_ft, exit)
388383
end
389384
end
390385

391-
end
392-
return changed
386+
return true
393387
end
394388

395389
# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.

0 commit comments

Comments
 (0)