Skip to content

Commit 410e65e

Browse files
ThomasRaoux authored and liuyunqi20 committed
Add back barrier after asserts (#5043)
Support asserts with a scalar condition, and only emit a barrier for asserts on tensors. Thanks to @peterbell10 for the suggestion.
1 parent 3d30102 commit 410e65e

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -891,7 +891,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
891891
`tt.assert` takes a condition (a scalar `i1` or a tensor of `i1`) and a message string.
892892
If the condition is false, the message is printed, and the program is aborted.
893893
}];
894-
let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
894+
let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message);
895895
let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
896896
}
897897

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,14 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
3535
}
3636
}
3737
llAssert(op, condition, adaptor.getMessage(), rewriter);
38+
if (isa<RankedTensorType>(op.getCondition().getType())) {
39+
// Add a barrier to avoid a race condition in case an assert is followed
40+
// by an op that may trap when the assertion fails. Since the
41+
// tensor in those two operations may have different layout we need to
42+
// make sure all the threads are done executing the assert before going to
43+
// the next op.
44+
barrier();
45+
}
3846
rewriter.eraseOp(op);
3947
return success();
4048
}

python/triton/language/semantic.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1721,10 +1721,6 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
17211721
def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
17221722
if not builder.options.debug:
17231723
return
1724-
cond_ty = cond.type
1725-
if not cond_ty.is_block():
1726-
cond_ty = tl.block_type(cond_ty.scalar, (1, ))
1727-
cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
17281724
return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
17291725

17301726

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1728,6 +1728,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
17281728
// CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
17291729
// CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
17301730
// CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
1731+
// CHECK: llvm.call @__assertfail
1732+
// CHECK: nvvm.barrier0
17311733
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
17321734
tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
17331735
tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)

0 commit comments

Comments
 (0)