@@ -11,25 +11,23 @@ Base.@kwdef struct PTXCompilerTarget <: AbstractCompilerTarget
11
11
# codegen quirks
12
12
# # can we emit debug info in the PTX assembly?
13
13
debuginfo:: Bool = false
14
- # # do we permit unrachable statements, which often result in divergent control flow?
15
- unreachable:: Bool = false
16
- # # can exceptions use `exit` (which doesn't kill the GPU), or should they use `trap`?
17
- exitable:: Bool = false
18
14
19
15
# optional properties
20
16
minthreads:: Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
21
17
maxthreads:: Union{Nothing,Int,NTuple{<:Any,Int}} = nothing
22
18
blocks_per_sm:: Union{Nothing,Int} = nothing
23
19
maxregs:: Union{Nothing,Int} = nothing
20
+
21
+ # deprecated; remove with next major version
22
+ exitable:: Union{Nothing,Bool} = nothing
23
+ unreachable:: Union{Nothing,Bool} = nothing
24
24
end
25
25
26
26
function Base. hash (target:: PTXCompilerTarget , h:: UInt )
27
27
h = hash (target. cap, h)
28
28
h = hash (target. ptx, h)
29
29
30
30
h = hash (target. debuginfo, h)
31
- h = hash (target. unreachable, h)
32
- h = hash (target. exitable, h)
33
31
34
32
h = hash (target. minthreads, h)
35
33
h = hash (target. maxthreads, h)
@@ -92,8 +90,7 @@ isintrinsic(@nospecialize(job::CompilerJob{PTXCompilerTarget}), fn::String) =
92
90
# XXX : the debuginfo part should be handled by GPUCompiler as it applies to all back-ends.
93
91
runtime_slug (@nospecialize (job:: CompilerJob{PTXCompilerTarget} )) =
94
92
" ptx-sm_$(job. config. target. cap. major)$(job. config. target. cap. minor) " *
95
- " -debuginfo=$(Int (llvm_debug_info (job))) " *
96
- " -exitable=$(job. config. target. exitable) "
93
+ " -debuginfo=$(Int (llvm_debug_info (job))) "
97
94
98
95
function finish_module! (@nospecialize (job:: CompilerJob{PTXCompilerTarget} ),
99
96
mod:: LLVM.Module , entry:: LLVM.Function )
@@ -132,14 +129,6 @@ function finish_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
132
129
end
133
130
134
131
@dispose pm= ModulePassManager () begin
135
- # hide `unreachable` from LLVM so that it doesn't introduce divergent control flow
136
- if ! job. config. target. unreachable
137
- add! (pm, FunctionPass (" HideUnreachable" , hide_unreachable!))
138
- end
139
-
140
- # even if we support `unreachable`, we still prefer `exit` to `trap`
141
- add! (pm, ModulePass (" HideTrap" , hide_trap!))
142
-
143
132
# we emit properties (of the device and ptx isa) as private global constants,
144
133
# so run the optimizer so that they are inlined before the rest of the optimizer runs.
145
134
global_optimizer! (pm)
@@ -188,6 +177,13 @@ function finish_ir!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
188
177
mod:: LLVM.Module , entry:: LLVM.Function )
189
178
ctx = context (mod)
190
179
180
+ @dispose pm= ModulePassManager () begin
181
+ add! (pm, ModulePass (" LowerTrap" , lower_trap!))
182
+ add! (pm, FunctionPass (" LowerUnreachable" , lower_unreachable!))
183
+
184
+ run! (pm, mod)
185
+ end
186
+
191
187
if job. config. kernel
192
188
# add metadata annotations for the assembler to the module
193
189
@@ -242,111 +238,29 @@ end
242
238
243
239
# # LLVM passes
244
240
245
- # HACK: this pass removes `unreachable` information from LLVM
246
- #
247
- # `ptxas` is buggy and cannot deal with thread-divergent control flow in the presence of
248
- # shared memory (see JuliaGPU/CUDAnative.jl#4). avoid that by rewriting control flow to fall
249
- # through any other block. this is semantically invalid, but the code is unreachable anyhow
250
- # (and we expect it to be preceded by eg. a noreturn function, or a trap).
251
- #
252
- # TODO : can LLVM do this with structured CFGs? It seems to have some support, but seemingly
253
- # only to prevent introducing non-structureness during optimization (ie. the front-end
254
- # is still responsible for generating structured control flow).
255
- function hide_unreachable! (fun:: LLVM.Function )
241
+ # replace calls to `trap` with inline assembly calling `exit`, which isn't fatal
242
+ function lower_trap! (mod:: LLVM.Module )
256
243
job = current_job:: CompilerJob
257
- ctx = context (fun )
244
+ ctx = context (mod )
258
245
changed = false
259
- @timeit_debug to " hide unreachable " begin
246
+ @timeit_debug to " lower trap " begin
260
247
261
- # remove `noreturn` attributes
262
- #
263
- # when calling a `noreturn` function, LLVM places an `unreachable` after the call.
264
- # this leads to an early `ret` from the function.
265
- attrs = function_attributes (fun)
266
- delete! (attrs, EnumAttribute (" noreturn" , 0 ; ctx))
248
+ if haskey (functions (mod), " llvm.trap" )
249
+ trap = functions (mod)[" llvm.trap" ]
267
250
268
- # build a map of basic block predecessors
269
- predecessors = Dict (bb => Set {LLVM.BasicBlock} () for bb in blocks (fun))
270
- @timeit_debug to " predecessors" for bb in blocks (fun)
271
- insts = instructions (bb)
272
- if ! isempty (insts)
273
- inst = last (insts)
274
- if isterminator (inst)
275
- for bb′ in successors (inst)
276
- push! (predecessors[bb′], bb)
277
- end
278
- end
279
- end
280
- end
251
+ # inline assembly to exit a thread
252
+ exit_ft = LLVM. FunctionType (LLVM. VoidType (ctx))
253
+ exit = InlineAsm (exit_ft, " exit;" , " " , true )
281
254
282
- # scan for unreachable terminators and alternative successors
283
- worklist = Pair{LLVM. BasicBlock, Union{Nothing,LLVM. BasicBlock}}[]
284
- @timeit_debug to " find" for bb in blocks (fun)
285
- unreachable = terminator (bb)
286
- if isa (unreachable, LLVM. UnreachableInst)
287
- unsafe_delete! (bb, unreachable)
288
- changed = true
289
-
290
- try
291
- terminator (bb)
292
- # the basic-block is still terminated properly, nothing to do
293
- # (this can happen with `ret; unreachable`)
294
- # TODO : `unreachable; unreachable`
295
- catch ex
296
- isa (ex, UndefRefError) || rethrow (ex)
255
+ for use in uses (trap)
256
+ val = user (use)
257
+ if isa (val, LLVM. CallInst)
297
258
@dispose builder= IRBuilder (ctx) begin
298
- position! (builder, bb)
299
-
300
- # find the strict predecessors to this block
301
- preds = collect (predecessors[bb])
302
-
303
- # find a fallthrough block: recursively look at predecessors
304
- # and find a successor that branches to any other block
305
- fallthrough = nothing
306
- while ! isempty (preds)
307
- # find an alternative successor
308
- for pred in preds, succ in successors (terminator (pred))
309
- if succ != bb
310
- fallthrough = succ
311
- break
312
- end
313
- end
314
- fallthrough === nothing || break
315
-
316
- # recurse upwards
317
- old_preds = copy (preds)
318
- empty! (preds)
319
- for pred in old_preds
320
- append! (preds, predecessors[pred])
321
- end
322
- end
323
- push! (worklist, bb => fallthrough)
324
- end
325
- end
326
- end
327
- end
328
-
329
- # apply the pending terminator rewrites
330
- @timeit_debug to " replace" if ! isempty (worklist)
331
- let builder = IRBuilder (ctx)
332
- for (bb, fallthrough) in worklist
333
- position! (builder, bb)
334
- if fallthrough != = nothing
335
- br! (builder, fallthrough)
336
- else
337
- # couldn't find any other successor. this happens with functions
338
- # that only contain a single block, or when the block is dead.
339
- ft = function_type (fun)
340
- if return_type (ft) == LLVM. VoidType (ctx)
341
- # even though returning can lead to invalid control flow,
342
- # it mostly happens with functions that just throw,
343
- # and leaving the unreachable there would make the optimizer
344
- # place another after the call.
345
- ret! (builder)
346
- else
347
- unreachable! (builder)
348
- end
259
+ position! (builder, val)
260
+ call! (builder, exit_ft, exit)
349
261
end
262
+ unsafe_delete! (LLVM. parent (val), val)
263
+ changed = true
350
264
end
351
265
end
352
266
end
@@ -355,41 +269,121 @@ function hide_unreachable!(fun::LLVM.Function)
355
269
return changed
356
270
end
357
271
358
- # HACK: this pass removes calls to `trap` and replaces them with inline assembly
272
+ # lower `unreachable` to `exit` so that the emitted PTX has correct control flow
359
273
#
360
- # if LLVM knows we're trapping, code is marked `unreachable` (see `hide_unreachable!`).
361
- function hide_trap! (mod:: LLVM.Module )
362
- job = current_job:: CompilerJob
363
- ctx = context (mod)
364
- changed = false
365
- @timeit_debug to " hide trap" begin
274
+ # During back-end compilation, `ptxas` inserts instructions to manage the hardware's
275
+ # reconvergence stack (SSY and SYNC). In order to do so, it needs to identify
276
+ # divergent regions:
277
+ #
278
+ # entry:
279
+ # // start of divergent region
280
+ # @%p0 bra cont;
281
+ # ...
282
+ # bra.uni cont;
283
+ # cont:
284
+ # // end of divergent region
285
+ # bar.sync 0;
286
+ #
287
+ # Meanwhile, LLVM's branch-folder and block-placement MIR passes will try to optimize
288
+ # the block layout, e.g., by placing unlikely blocks at the end of the function:
289
+ #
290
+ # entry:
291
+ # // start of divergent region
292
+ # @%p0 bra cont;
293
+ # @%p1 bra unlikely;
294
+ # bra.uni cont;
295
+ # cont:
296
+ # // end of divergent region
297
+ # bar.sync 0;
298
+ # unlikely:
299
+ # bra.uni cont;
300
+ #
301
+ # That is not a problem as long as the unlikely block continues back into the
302
+ # divergent region. Crucially, this is not the case with unreachable control flow:
303
+ #
304
+ # entry:
305
+ # // start of divergent region
306
+ # @%p0 bra cont;
307
+ # @%p1 bra throw;
308
+ # bra.uni cont;
309
+ # cont:
310
+ # bar.sync 0;
311
+ # throw:
312
+ # call throw_and_trap();
313
+ # // unreachable
314
+ # exit:
315
+ # // end of divergent region
316
+ # ret;
317
+ #
318
+ # Dynamically, this is fine, because the called function does not return.
319
+ # However, `ptxas` does not know that and adds a successor edge to the `exit`
320
+ # block, widening the divergence range. In this example, that's not allowed, as
321
+ # `bar.sync` cannot be executed divergently on Pascal hardware or earlier.
322
+ #
323
+ # To avoid these fall-through successors that change the control flow,
324
+ # we replace `unreachable` instructions with a call to `exit`. This informs
325
+ # `ptxas` that the thread exits, and allows it to correctly construct a CFG,
326
+ # and consequently correctly determine the divergence regions as intended.
327
+ function lower_unreachable! (f:: LLVM.Function )
328
+ ctx = context (f)
329
+
330
+ # TODO :
331
+ # - if unreachable blocks have been merged, we still may be jumping from different
332
+ # divergent regions, potentially causing the same problem as above:
333
+ # entry:
334
+ # // start of divergent region 1
335
+ # @%p0 bra cont1;
336
+ # @%p1 bra throw;
337
+ # bra.uni cont1;
338
+ # cont1:
339
+ # // end of divergent region 1
340
+ # bar.sync 0; // is this executed divergently?
341
+ # // start of divergent region 2
342
+ # @%p2 bra cont2;
343
+ # @%p3 bra throw;
344
+ # bra.uni cont2;
345
+ # cont2:
346
+ # // end of divergent region 2
347
+ # ...
348
+ # throw:
349
+ # trap;
350
+ # br throw;
351
+ # if this is a problem, we probably need to clone blocks with multiple
352
+ # predecessors so that there's a unique path from each region of
353
+ # divergence to every `unreachable` terminator
354
+
355
+ # remove `noreturn` attributes, to avoid the (minimal) optimization that
356
+ # happens during `prepare_execution!` undoing our work here.
357
+ # this shouldn't be needed when we upstream the pass.
358
+ attrs = function_attributes (f)
359
+ delete! (attrs, EnumAttribute (" noreturn" , 0 ; ctx))
366
360
367
- # inline assembly to exit a thread, hiding control flow from LLVM
368
- exit_ft = LLVM . FunctionType (LLVM . VoidType (ctx))
369
- exit = if job . config . target . exitable
370
- InlineAsm (exit_ft, " exit; " , " " , true )
371
- else
372
- InlineAsm (exit_ft, " trap; " , " " , true )
361
+ # find unreachable blocks
362
+ unreachable_blocks = BasicBlock[]
363
+ for block in blocks (f)
364
+ if terminator (block) isa LLVM . UnreachableInst
365
+ push! (unreachable_blocks, block)
366
+ end
373
367
end
368
+ isempty (unreachable_blocks) && return false
374
369
375
- if haskey (functions (mod), " llvm.trap" )
376
- trap = functions (mod)[" llvm.trap" ]
370
+ # inline assembly to exit a thread
371
+ exit_ft = LLVM. FunctionType (LLVM. VoidType (ctx))
372
+ exit = InlineAsm (exit_ft, " exit;" , " " , true )
377
373
378
- for use in uses (trap)
379
- val = user (use)
380
- if isa (val, LLVM. CallInst)
381
- @dispose builder= IRBuilder (ctx) begin
382
- position! (builder, val)
383
- call! (builder, exit_ft, exit)
384
- end
385
- unsafe_delete! (LLVM. parent (val), val)
386
- changed = true
387
- end
374
+ # rewrite the unreachable terminators
375
+ @dispose builder= IRBuilder (ctx) begin
376
+ entry_block = first (blocks (f))
377
+ for block in unreachable_blocks
378
+ inst = terminator (block)
379
+ @assert inst isa LLVM. UnreachableInst
380
+
381
+ position! (builder, inst)
382
+ call! (builder, exit_ft, exit)
388
383
end
389
384
end
390
385
391
- end
392
- return changed
386
+ return true
393
387
end
394
388
395
389
# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.
0 commit comments