Skip to content

Commit 4ccb41f

Browse files
committed
[intel] define barrier in AssertOp's last block
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 7b821d0 commit 4ccb41f

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

lib/Conversion/TritonGPUToLLVM/AssertOpToLLVM.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,23 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
3434
return failure();
3535
}
3636
}
37-
llAssert(op, condition, adaptor.getMessage(), rewriter);
37+
Block *thenBlock = llAssert(op, condition, adaptor.getMessage(), rewriter);
3838
if (isa<RankedTensorType>(op.getCondition().getType())) {
3939
// Add a barrier to avoid a race condition in case an assert is followed
4040
// by an op that may trap if the assert condition is true. Since the
4141
// tensors in those two operations may have different layouts, we need to
4242
// make sure all the threads are done executing the assert before going to
4343
// the next op.
44+
rewriter.setInsertionPointToStart(thenBlock);
4445
barrier();
4546
}
4647
rewriter.eraseOp(op);
4748
return success();
4849
}
4950
// op: the op at which the assert is inserted. Unlike printf, we need to
5051
// know about the op to split the block.
51-
void llAssert(Operation *op, Value condition, StringRef message,
52-
ConversionPatternRewriter &rewriter) const {
52+
Block *llAssert(Operation *op, Value condition, StringRef message,
53+
ConversionPatternRewriter &rewriter) const {
5354
ConversionPatternRewriter::InsertionGuard guard(rewriter);
5455

5556
auto ctx = rewriter.getContext();
@@ -87,6 +88,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
8788
rewriter.create<cf::BranchOp>(loc, thenBlock);
8889
rewriter.setInsertionPointToEnd(prevBlock);
8990
rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
91+
return thenBlock;
9092
}
9193

9294
protected:

0 commit comments

Comments
 (0)