@@ -34,25 +34,22 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
3434 return failure ();
3535 }
3636 }
37- Block *thenBlock = llAssert (op, condition, adaptor.getMessage (), rewriter);
37+ llAssert (op, condition, adaptor.getMessage (), rewriter);
3838 if (isa<RankedTensorType>(op.getCondition ().getType ())) {
3939 // Add a barrier to avoid a race condition in case an assert is followed
4040 // by an op that may trap if the assert condition is true. Since the
4141 // tensor in those two operations may have different layout we need to
4242 // make sure all the threads are done executing the assert before going to
4343 // the next op.
44- rewriter.setInsertionPointToStart (thenBlock);
4544 barrier ();
4645 }
4746 rewriter.eraseOp (op);
4847 return success ();
4948 }
5049 // op: the op at which the assert is inserted. Unlike printf, we need to
5150 // know about the op to split the block.
52- Block *llAssert (Operation *op, Value condition, StringRef message,
53- ConversionPatternRewriter &rewriter) const {
54- ConversionPatternRewriter::InsertionGuard guard (rewriter);
55-
51+ void llAssert (Operation *op, Value condition, StringRef message,
52+ ConversionPatternRewriter &rewriter) const {
5653 auto ctx = rewriter.getContext ();
5754 auto loc = op->getLoc ();
5855
@@ -88,7 +85,7 @@ struct AssertOpConversion : public ConvertOpToLLVMPattern<triton::AssertOp> {
8885 rewriter.create <cf::BranchOp>(loc, thenBlock);
8986 rewriter.setInsertionPointToEnd (prevBlock);
9087 rewriter.create <cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
91- return thenBlock;
88+ rewriter. setInsertionPointToStart ( thenBlock) ;
9289 }
9390
9491protected:
0 commit comments