Commit 26781e4
Revert "Revert "Add back barrier after asserts (#5043)"" (#2657)
Closes #2644 The error (more details: #2644 (comment)) seems to be that the operation is incorrectly inserted into the block. My best guess is that we need to explicitly insert a barrier at the beginning of the `thenBlock`. However I don't know the exact reason why this code works for nvidia (maybe because of the different number of instructions that initially replace `"gpu.barrier"() : () -> ()` however I'm not sure). ```bash python: /home/runner/work/triton/triton/llvm-project/llvm/include/llvm/ADT/ilist_iterator.h:168: llvm::ilist_iterator::reference llvm::ilist_iterator<llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, false, false>::operator*() const [OptionsT = llvm::ilist_detail::node_options<mlir::Operation, true, false, void, false, void>, IsReverse = false, IsConst = false]: Assertion `!NodePtr->isKnownSentinel()' failed. Aborted (core dumped) ``` --------- Signed-off-by: Anatoly Myachev <[email protected]>
1 parent ca95a70 commit 26781e4
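For intuition, the race the barrier guards against looks roughly like this from the user's side: a per-element assert over a tensor, immediately followed by an op that could trap if the asserted condition were violated. The kernel below is a hypothetical sketch (names, shapes, and the surrounding setup are illustrative, not taken from the linked issue):

```python
import triton
import triton.language as tl


@triton.jit
def guarded_load(ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Per-element assert over a tensor condition (only active when the
    # kernel is compiled with debug enabled).
    tl.device_assert(offs < n, "index out of bounds")
    # Without a barrier after the assert, some threads could reach this
    # (potentially trapping) load while others are still executing the
    # assert, since the two ops may use different layouts for the tensor.
    x = tl.load(ptr + offs)
    tl.store(ptr + offs, x + 1)
```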

3 files changed: +11 -6 lines changed

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 9 additions & 2 deletions
```diff
@@ -35,15 +35,21 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
       }
     }
     llAssert(op, condition, adaptor.getMessage(), rewriter);
+    if (isa<RankedTensorType>(op.getCondition().getType())) {
+      // Add a barrier to avoid a race condition in case an assert is followed
+      // by an op that may trap if the assert condition is true. Since the
+      // tensor in those two operations may have different layout we need to
+      // make sure all the threads are done executing the assert before going to
+      // the next op.
+      barrier();
+    }
     rewriter.eraseOp(op);
     return success();
   }
   // op: the op at which the assert is inserted. Unlike printf, we need to
   // know about the op to split the block.
   void llAssert(Operation *op, Value condition, StringRef message,
                 ConversionPatternRewriter &rewriter) const {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-
     auto ctx = rewriter.getContext();
     auto loc = op->getLoc();
 
@@ -79,6 +85,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
     rewriter.create<cf::BranchOp>(loc, thenBlock);
     rewriter.setInsertionPointToEnd(prevBlock);
     rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
+    rewriter.setInsertionPointToStart(thenBlock);
   }
 
 protected:
```
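The `isa<RankedTensorType>` guard above means the trailing `barrier()` is emitted only for tensor-valued assert conditions. A minimal sketch of the two condition shapes from the Python side (an illustrative kernel, assuming scalar conditions are accepted end-to-end after this change):

```python
import triton
import triton.language as tl


@triton.jit
def assert_shapes(ptr, n, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    # Tensor condition: lowers to a RankedTensorType, so the conversion
    # pattern appends a barrier right after the assert.
    tl.device_assert(offs < n, "offs out of range")
    # Scalar condition: not a RankedTensorType, so no trailing barrier.
    pid = tl.program_id(0)
    tl.device_assert(pid >= 0, "pid must be non-negative")
    tl.store(ptr + offs, offs)
```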

python/triton/language/semantic.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -1729,10 +1729,6 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.builder)
 def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
     if not builder.options.debug:
         return
-    cond_ty = cond.type
-    if not cond_ty.is_block():
-        cond_ty = tl.block_type(cond_ty.scalar, (1, ))
-        cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
     return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
 
 
```
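With the splat fallback deleted, a scalar condition now reaches `builder.create_assert` unchanged instead of being promoted to a 1-element block, matching the type check in the C++ lowering above. A hedged sketch of the user-visible call (the kernel and its arguments are illustrative):

```python
import triton
import triton.language as tl


@triton.jit
def scalar_assert(n):
    pid = tl.program_id(0)
    # The scalar i1 condition is no longer splatted into a <1xi1> tensor
    # inside device_assert; it is forwarded to create_assert as-is.
    tl.device_assert(pid < n, "program id exceeds n")
```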

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 0 deletions
```diff
@@ -1842,6 +1842,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
 // CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
 // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
+// CHECK: llvm.call @__assertfail
+// CHECK: nvvm.barrier0
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
   tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
     tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)
```
