@@ -34,22 +34,23 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
         return failure();
       }
     }
-    llAssert(op, condition, adaptor.getMessage(), rewriter);
+    Block *thenBlock = llAssert(op, condition, adaptor.getMessage(), rewriter);
     if (isa<RankedTensorType>(op.getCondition().getType())) {
       // Add a barrier to avoid a race condition in case an assert is followed
       // by an op that may trap if the assert condition is true. Since the
       // tensor in those two operations may have different layout we need to
       // make sure all the threads are done executing the assert before going to
       // the next op.
+      rewriter.setInsertionPointToStart(thenBlock);
       barrier();
     }
     rewriter.eraseOp(op);
     return success();
   }
   // op: the op at which the assert is inserted. Unlike printf, we need to
   // know about the op to split the block.
-  void llAssert(Operation *op, Value condition, StringRef message,
-                ConversionPatternRewriter &rewriter) const {
+  Block *llAssert(Operation *op, Value condition, StringRef message,
+                  ConversionPatternRewriter &rewriter) const {
     ConversionPatternRewriter::InsertionGuard guard(rewriter);

     auto ctx = rewriter.getContext();
@@ -87,6 +88,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
     rewriter.create<cf::BranchOp>(loc, thenBlock);
     rewriter.setInsertionPointToEnd(prevBlock);
     rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
+    return thenBlock;
   }

 protected:
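
Why returning the block matters: llAssert splits the block containing the op into prevBlock -> ifBlock (the assert body) -> thenBlock (the continuation), and its InsertionGuard restores the caller's insertion point on exit. That restored point is not guaranteed to sit at the start of the continuation block, so the barrier could be emitted before the conditional branch instead of after the two paths rejoin. Returning thenBlock lets the caller pin the barrier to the very start of the continuation. A minimal sketch of the lowered control flow, with assumed block names and an NVVM barrier standing in for whatever barrier() emits on the target backend:

    ^prevBlock:
      cf.cond_br %condition, ^ifBlock, ^thenBlock
    ^ifBlock:        // reached only when some element failed the assert
      // device-side assert call emitted here
      cf.br ^thenBlock
    ^thenBlock:      // all threads rejoin here; the barrier is now the first op,
      nvvm.barrier0  // so no thread starts the op that followed tt.assert early
      // ...ops that followed the original tt.assert...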