@@ -158,11 +158,15 @@ struct ConvertReduce : public OpConversionPattern<ReduceOp> {
158158 ReduceOp op, OpAdaptor adaptor,
159159 ConversionPatternRewriter &rewriter) const override {
160160 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
// Build the overflow-flags attribute that is attached to the `rems + cmod`
// AddIOp further down in this hunk.
// NOTE(review): the two enumerators are combined with `&`, not `|`. If `nuw`
// and `nsw` are distinct single-bit flags (as in upstream MLIR's arith
// dialect -- confirm), their AND has no bits set, so `noOverflow` would carry
// no flags at all and attaching it changes nothing.
// NOTE(review): do not switch to `|` without checking each flag against the
// use: `rems` comes from a signed remainder and is presumably negative for
// negative inputs, in which case `rems + cmod` wraps in the unsigned sense
// and `nuw` would not hold.
161+ arith::IntegerOverflowFlags overflowFlag (arith::IntegerOverflowFlags::nuw &
162+ arith::IntegerOverflowFlags::nsw);
// Wrap the flag value in an attribute owned by the builder's context.
163+ auto noOverflow =
164+ arith::IntegerOverflowFlagsAttr::get (b.getContext (), overflowFlag);
161165
162166 auto cmod = b.create <arith::ConstantOp>(modulusAttr (op));
163167 // ModArithType ensures cmod can be correctly interpreted as a signed number
164168 auto rems = b.create <arith::RemSIOp>(adaptor.getOperands ()[0 ], cmod);
165- auto add = b.create <arith::AddIOp>(rems, cmod);
169+ auto add = b.create <arith::AddIOp>(rems, cmod, noOverflow );
166170 // TODO(google/heir #710): better with a subifge
167171 auto remu = b.create <arith::RemUIOp>(add, cmod);
168172 rewriter.replaceOp (op, remu);
@@ -207,6 +211,11 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
207211 TypedAttr limbShiftAttr = b.getIntegerAttr (getElementTypeOrSelf (tLow),
208212 (numLimbs - 1 ) * limbWidth);
209213
// Build the overflow-flags attribute that later hunks of this pattern attach
// to the two `tHigh` AddIOps and to the final `tLow - modConst` SubIOp.
// NOTE(review): the two enumerators are combined with `&`, not `|`. If `nuw`
// and `nsw` are distinct single-bit flags (as in upstream MLIR's arith
// dialect -- confirm), their AND has no bits set, so `noOverflow` would carry
// no flags at all and attaching it changes nothing.
// NOTE(review): do not switch to `|` without checking each use: the final
// SubIOp is created unconditionally and only selected when `tLow uge
// modConst`, so on the other branch the subtraction wraps in the unsigned
// sense and `nuw` would not hold for it.
214+ arith::IntegerOverflowFlags overflowFlag (arith::IntegerOverflowFlags::nuw &
215+ arith::IntegerOverflowFlags::nsw);
// Wrap the flag value in an attribute owned by the builder's context.
216+ auto noOverflow =
217+ arith::IntegerOverflowFlagsAttr::get (b.getContext (), overflowFlag);
218+
210219 // Splat the attributes to match the shape of `tLow`.
211220 if (auto shapedType = dyn_cast<ShapedType>(tLow.getType ())) {
212221 limbType = shapedType.cloneWith (std::nullopt , limbType);
@@ -238,11 +247,11 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
238247 // Add the product to `T`.
239248 auto sum = b.create <arith::AddUIExtendedOp>(tLow, mN .getLow ());
240249 tLow = sum.getSum ();
241- tHigh = b.create <arith::AddIOp>(tHigh, mN .getHigh ());
250+ tHigh = b.create <arith::AddIOp>(tHigh, mN .getHigh (), noOverflow );
242251 // Add carry from the `sum` to `tHigh`.
243252 auto carryExt =
244253 b.create <arith::ExtUIOp>(tHigh.getType (), sum.getOverflow ());
245- tHigh = b.create <arith::AddIOp>(tHigh, carryExt);
254+ tHigh = b.create <arith::AddIOp>(tHigh, carryExt, noOverflow );
246255 // Shift right by `limbWidth` to discard the zeroed limb.
247256 tLow = b.create <arith::ShRUIOp>(tLow, limbWidthConst);
248257 // copy the lowest limb of `tHigh` to the highest limb of `tLow`
@@ -257,7 +266,7 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
257266 // `modulus`.
258267 auto cmp =
259268 b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, modConst);
260- auto sub = b.create <arith::SubIOp>(tLow, modConst);
269+ auto sub = b.create <arith::SubIOp>(tLow, modConst, noOverflow );
261270 auto result = b.create <arith::SelectOp>(cmp, sub, tLow);
262271
263272 rewriter.replaceOp (op, result);
@@ -438,10 +447,16 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
438447 ConversionPatternRewriter &rewriter) const override {
439448 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
440449
// Build the overflow-flags attribute that is attached to the `lhs + rhs`
// AddIOp and the `add - cmod` SubIOp below.
// NOTE(review): the two enumerators are combined with `&`, not `|`. If `nuw`
// and `nsw` are distinct single-bit flags (as in upstream MLIR's arith
// dialect -- confirm), their AND has no bits set, so `noOverflow` would carry
// no flags at all and attaching it changes nothing.
// NOTE(review): do not switch to `|` without checking each use: the SubIOp is
// created unconditionally and only selected when `add uge cmod`, so on the
// other branch `add - cmod` wraps in the unsigned sense and `nuw` would not
// hold for it. Whether `nsw` holds for `lhs + rhs` depends on how much
// headroom the storage type leaves above the modulus -- confirm.
450+ arith::IntegerOverflowFlags overflowFlag (arith::IntegerOverflowFlags::nuw &
451+ arith::IntegerOverflowFlags::nsw);
// Wrap the flag value in an attribute owned by the builder's context.
452+ auto noOverflow =
453+ arith::IntegerOverflowFlagsAttr::get (b.getContext (), overflowFlag);
454+
441455 auto cmod = b.create <arith::ConstantOp>(modulusAttr (op));
442- auto add = b.create <arith::AddIOp>(adaptor.getLhs (), adaptor.getRhs ());
456+ auto add =
457+ b.create <arith::AddIOp>(adaptor.getLhs (), adaptor.getRhs (), noOverflow);
443458 auto ifge = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, add, cmod);
444- auto sub = b.create <arith::SubIOp>(add, cmod);
459+ auto sub = b.create <arith::SubIOp>(add, cmod, noOverflow );
445460 auto select = b.create <arith::SelectOp>(ifge, sub, add);
446461
447462 rewriter.replaceOp (op, select);
@@ -459,10 +474,15 @@ struct ConvertSub : public OpConversionPattern<SubOp> {
459474 SubOp op, OpAdaptor adaptor,
460475 ConversionPatternRewriter &rewriter) const override {
461476 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
// Build the overflow-flags attribute that is attached to the `lhs - rhs`
// SubIOp and the `sub + cmod` AddIOp below.
// NOTE(review): the two enumerators are combined with `&`, not `|`. If `nuw`
// and `nsw` are distinct single-bit flags (as in upstream MLIR's arith
// dialect -- confirm), their AND has no bits set, so `noOverflow` would carry
// no flags at all and attaching it changes nothing.
// NOTE(review): do not switch to `|` without checking each use: `lhs - rhs`
// is created unconditionally and `sub + cmod` is the value selected when
// `lhs uge rhs` is false. In that case the subtraction wraps in the unsigned
// sense, so `nuw` would not hold for it, and the selected AddIOp takes that
// result as an operand.
477+ arith::IntegerOverflowFlags overflowFlag (arith::IntegerOverflowFlags::nuw &
478+ arith::IntegerOverflowFlags::nsw);
// Wrap the flag value in an attribute owned by the builder's context.
479+ auto noOverflow =
480+ arith::IntegerOverflowFlagsAttr::get (b.getContext (), overflowFlag);
462481
463482 auto cmod = b.create <arith::ConstantOp>(modulusAttr (op));
464- auto sub = b.create <arith::SubIOp>(adaptor.getLhs (), adaptor.getRhs ());
465- auto add = b.create <arith::AddIOp>(sub, cmod);
483+ auto sub =
484+ b.create <arith::SubIOp>(adaptor.getLhs (), adaptor.getRhs (), noOverflow);
485+ auto add = b.create <arith::AddIOp>(sub, cmod, noOverflow);
466486 auto ifge = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge,
467487 adaptor.getLhs (), adaptor.getRhs ());
468488 auto select = b.create <arith::SelectOp>(ifge, sub, add);
@@ -507,6 +527,10 @@ struct ConvertMac : public OpConversionPattern<MacOp> {
507527 MacOp op, OpAdaptor adaptor,
508528 ConversionPatternRewriter &rewriter) const override {
509529 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
// Build the overflow-flags attribute that is attached to the `mul + acc`
// AddIOp below; the preceding MulIOp is created without it.
// NOTE(review): the two enumerators are combined with `&`, not `|`. If `nuw`
// and `nsw` are distinct single-bit flags (as in upstream MLIR's arith
// dialect -- confirm), their AND has no bits set, so `noOverflow` would carry
// no flags at all and attaching it changes nothing.
// NOTE(review): if `|` was intended, confirm that the extended type produced
// by `modulusType (op, true)` is wide enough that `mul + acc` can wrap in
// neither the signed nor the unsigned sense.
530+ arith::IntegerOverflowFlags overflowFlag (arith::IntegerOverflowFlags::nuw &
531+ arith::IntegerOverflowFlags::nsw);
// Wrap the flag value in an attribute owned by the builder's context.
532+ auto noOverflow =
533+ arith::IntegerOverflowFlagsAttr::get (b.getContext (), overflowFlag);
510534
511535 auto cmod = b.create <arith::ConstantOp>(modulusAttr (op, true ));
512536 auto x = b.create <arith::ExtUIOp>(modulusType (op, true ),
@@ -516,7 +540,7 @@ struct ConvertMac : public OpConversionPattern<MacOp> {
516540 auto acc = b.create <arith::ExtUIOp>(modulusType (op, true ),
517541 adaptor.getOperands ()[2 ]);
518542 auto mul = b.create <arith::MulIOp>(x, y);
519- auto add = b.create <arith::AddIOp>(mul, acc);
543+ auto add = b.create <arith::AddIOp>(mul, acc, noOverflow );
520544 auto remu = b.create <arith::RemUIOp>(add, cmod);
521545 auto trunc = b.create <arith::TruncIOp>(modulusType (op), remu);
522546
0 commit comments