Commit 92a4fad

Add back barrier after asserts (triton-lang#5043)
Support asserts with a scalar condition, and emit the barrier only for asserts on tensors. Thanks to @peterbell10 for the suggestion.
1 parent edc5c5c
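
For context, a hedged sketch of what this change enables at the Triton Python level. The kernel and argument names below are illustrative, not taken from the commit; the point is that tl.device_assert can now receive a plain scalar condition, which previously had to be splatted into a 1-element tensor.

import triton
import triton.language as tl


@triton.jit
def scaled_copy_kernel(x_ptr, out_ptr, n_elements, scale, BLOCK: tl.constexpr):
    # Scalar boolean condition: after this change it lowers to tt.assert on an
    # i1 value directly, and no barrier is emitted for it.
    tl.device_assert(scale != 0.0, "scale must be non-zero")

    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x * scale, mask=mask)

# device_assert is only active in debug mode, e.g. @triton.jit(debug=True)
# or with TRITON_DEBUG=1 in the environment.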

4 files changed: +11 −5 lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 1 addition & 1 deletion
@@ -891,7 +891,7 @@ def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite<GlobalMemory>]>]> {
     `tt.assert` takes a condition tensor and a message string.
     If the condition is false, the message is printed, and the program is aborted.
   }];
-  let arguments = (ins TT_Tensor:$condition, StrAttr:$message);
+  let arguments = (ins AnyTypeOf<[I1, I1Tensor]>:$condition, StrAttr:$message);
   let assemblyFormat = "$condition `,` $message attr-dict `:` type($condition)";
 }

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 8 additions & 0 deletions
@@ -35,6 +35,14 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
       }
     }
     llAssert(op, condition, adaptor.getMessage(), rewriter);
+    if (isa<RankedTensorType>(op.getCondition().getType())) {
+      // Add a barrier to avoid a race condition in case an assert is followed
+      // by an op that may trap if the assert condition is true. Since the
+      // tensor in those two operations may have different layout we need to
+      // make sure all the threads are done executing the assert before going to
+      // the next op.
+      barrier();
+    }
     rewriter.eraseOp(op);
     return success();
   }
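
To illustrate why the barrier matters (a sketch under assumed names, not code from this commit): a tensor-valued assert is often immediately followed by the operation it guards, and because the condition tensor may use a different layout in the two ops, every thread must finish the assert before any thread can reach an op that might trap.

import triton
import triton.language as tl


@triton.jit
def gather_kernel(idx_ptr, src_ptr, out_ptr, n_src, BLOCK: tl.constexpr):
    offsets = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    idx = tl.load(idx_ptr + offsets)

    # Tensor-valued condition: lowers to tt.assert on a tensor of i1, which
    # this change follows with a barrier in the LLVM lowering.
    tl.device_assert(idx < n_src, "index out of bounds")

    # The guarded op: if an index really is out of bounds this load may trap,
    # so all threads must be done with the assert (and its error reporting)
    # before any thread proceeds here.
    val = tl.load(src_ptr + idx)
    tl.store(out_ptr + offsets, val)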

python/triton/language/semantic.py

Lines changed: 0 additions & 4 deletions
@@ -1724,10 +1724,6 @@ def device_print(prefix: str, args: List[tl.tensor], hex: bool, builder: ir.buil
 def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
     if not builder.options.debug:
         return
-    cond_ty = cond.type
-    if not cond_ty.is_block():
-        cond_ty = tl.block_type(cond_ty.scalar, (1, ))
-        cond = tl.tensor(builder.create_splat(cond.handle, (1, )), cond_ty)
     return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)
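
For readability, here is the resulting device_assert reconstructed from the hunk above; the body is exactly what remains after the deletions (it relies on the module's existing tl and ir imports), and the comment is added here only for explanation.

def device_assert(cond: tl.tensor, msg: str, builder: ir.builder) -> tl.tensor:
    if not builder.options.debug:
        return
    # A scalar condition is no longer splatted into a 1-element tensor here;
    # its handle is forwarded to create_assert as-is, and the lowering decides
    # from the condition type whether a trailing barrier is needed.
    return tl.tensor(builder.create_assert(cond.handle, msg), tl.void)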

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 2 additions & 0 deletions
@@ -1906,6 +1906,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 // CHECK-DAG: llvm.mlir.global internal constant @assertFunc_0("unknown\00") {addr_space = 0 : i32}
 // CHECK-DAG: llvm.mlir.global internal constant @assertFile_0("inner_call\00") {addr_space = 0 : i32}
 // CHECK-DAG: llvm.mlir.global internal constant @assertMessage_0("assert text\00") {addr_space = 0 : i32}
+// CHECK: llvm.call @__assertfail
+// CHECK: nvvm.barrier0
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
   tt.func public @add_kernel(%arg0: tensor<1xi1, #blocked>) {
     tt.assert %arg0, "assert text" : tensor<1xi1, #blocked> loc(#loc5)
