Skip to content

Commit 7f4c0da

Browse files
committed
feat(mod_arith): disallow overflow for add/sub
This had no performance impact but it should help us debug unwanted overflows.
1 parent ac2f281 commit 7f4c0da

File tree

1 file changed

+33
-9
lines changed

1 file changed

+33
-9
lines changed

zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp

Lines changed: 33 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -158,11 +158,15 @@ struct ConvertReduce : public OpConversionPattern<ReduceOp> {
158158
ReduceOp op, OpAdaptor adaptor,
159159
ConversionPatternRewriter &rewriter) const override {
160160
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
161+
arith::IntegerOverflowFlags overflowFlag(arith::IntegerOverflowFlags::nuw &
162+
arith::IntegerOverflowFlags::nsw);
163+
auto noOverflow =
164+
arith::IntegerOverflowFlagsAttr::get(b.getContext(), overflowFlag);
161165

162166
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
163167
// ModArithType ensures cmod can be correctly interpreted as a signed number
164168
auto rems = b.create<arith::RemSIOp>(adaptor.getOperands()[0], cmod);
165-
auto add = b.create<arith::AddIOp>(rems, cmod);
169+
auto add = b.create<arith::AddIOp>(rems, cmod, noOverflow);
166170
// TODO(google/heir #710): better with a subifge
167171
auto remu = b.create<arith::RemUIOp>(add, cmod);
168172
rewriter.replaceOp(op, remu);
@@ -207,6 +211,11 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
207211
TypedAttr limbShiftAttr = b.getIntegerAttr(getElementTypeOrSelf(tLow),
208212
(numLimbs - 1) * limbWidth);
209213

214+
arith::IntegerOverflowFlags overflowFlag(arith::IntegerOverflowFlags::nuw &
215+
arith::IntegerOverflowFlags::nsw);
216+
auto noOverflow =
217+
arith::IntegerOverflowFlagsAttr::get(b.getContext(), overflowFlag);
218+
210219
// Splat the attributes to match the shape of `tLow`.
211220
if (auto shapedType = dyn_cast<ShapedType>(tLow.getType())) {
212221
limbType = shapedType.cloneWith(std::nullopt, limbType);
@@ -238,11 +247,11 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
238247
// Add the product to `T`.
239248
auto sum = b.create<arith::AddUIExtendedOp>(tLow, mN.getLow());
240249
tLow = sum.getSum();
241-
tHigh = b.create<arith::AddIOp>(tHigh, mN.getHigh());
250+
tHigh = b.create<arith::AddIOp>(tHigh, mN.getHigh(), noOverflow);
242251
// Add carry from the `sum` to `tHigh`.
243252
auto carryExt =
244253
b.create<arith::ExtUIOp>(tHigh.getType(), sum.getOverflow());
245-
tHigh = b.create<arith::AddIOp>(tHigh, carryExt);
254+
tHigh = b.create<arith::AddIOp>(tHigh, carryExt, noOverflow);
246255
// Shift right by `limbWidth` to discard the zeroed limb.
247256
tLow = b.create<arith::ShRUIOp>(tLow, limbWidthConst);
248257
// copy the lowest limb of `tHigh` to the highest limb of `tLow`
@@ -257,7 +266,7 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
257266
// `modulus`.
258267
auto cmp =
259268
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, modConst);
260-
auto sub = b.create<arith::SubIOp>(tLow, modConst);
269+
auto sub = b.create<arith::SubIOp>(tLow, modConst, noOverflow);
261270
auto result = b.create<arith::SelectOp>(cmp, sub, tLow);
262271

263272
rewriter.replaceOp(op, result);
@@ -438,10 +447,16 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
438447
ConversionPatternRewriter &rewriter) const override {
439448
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
440449

450+
arith::IntegerOverflowFlags overflowFlag(arith::IntegerOverflowFlags::nuw &
451+
arith::IntegerOverflowFlags::nsw);
452+
auto noOverflow =
453+
arith::IntegerOverflowFlagsAttr::get(b.getContext(), overflowFlag);
454+
441455
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
442-
auto add = b.create<arith::AddIOp>(adaptor.getLhs(), adaptor.getRhs());
456+
auto add =
457+
b.create<arith::AddIOp>(adaptor.getLhs(), adaptor.getRhs(), noOverflow);
443458
auto ifge = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, add, cmod);
444-
auto sub = b.create<arith::SubIOp>(add, cmod);
459+
auto sub = b.create<arith::SubIOp>(add, cmod, noOverflow);
445460
auto select = b.create<arith::SelectOp>(ifge, sub, add);
446461

447462
rewriter.replaceOp(op, select);
@@ -459,10 +474,15 @@ struct ConvertSub : public OpConversionPattern<SubOp> {
459474
SubOp op, OpAdaptor adaptor,
460475
ConversionPatternRewriter &rewriter) const override {
461476
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
477+
arith::IntegerOverflowFlags overflowFlag(arith::IntegerOverflowFlags::nuw &
478+
arith::IntegerOverflowFlags::nsw);
479+
auto noOverflow =
480+
arith::IntegerOverflowFlagsAttr::get(b.getContext(), overflowFlag);
462481

463482
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
464-
auto sub = b.create<arith::SubIOp>(adaptor.getLhs(), adaptor.getRhs());
465-
auto add = b.create<arith::AddIOp>(sub, cmod);
483+
auto sub =
484+
b.create<arith::SubIOp>(adaptor.getLhs(), adaptor.getRhs(), noOverflow);
485+
auto add = b.create<arith::AddIOp>(sub, cmod, noOverflow);
466486
auto ifge = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge,
467487
adaptor.getLhs(), adaptor.getRhs());
468488
auto select = b.create<arith::SelectOp>(ifge, sub, add);
@@ -507,6 +527,10 @@ struct ConvertMac : public OpConversionPattern<MacOp> {
507527
MacOp op, OpAdaptor adaptor,
508528
ConversionPatternRewriter &rewriter) const override {
509529
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
530+
arith::IntegerOverflowFlags overflowFlag(arith::IntegerOverflowFlags::nuw &
531+
arith::IntegerOverflowFlags::nsw);
532+
auto noOverflow =
533+
arith::IntegerOverflowFlagsAttr::get(b.getContext(), overflowFlag);
510534

511535
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op, true));
512536
auto x = b.create<arith::ExtUIOp>(modulusType(op, true),
@@ -516,7 +540,7 @@ struct ConvertMac : public OpConversionPattern<MacOp> {
516540
auto acc = b.create<arith::ExtUIOp>(modulusType(op, true),
517541
adaptor.getOperands()[2]);
518542
auto mul = b.create<arith::MulIOp>(x, y);
519-
auto add = b.create<arith::AddIOp>(mul, acc);
543+
auto add = b.create<arith::AddIOp>(mul, acc, noOverflow);
520544
auto remu = b.create<arith::RemUIOp>(add, cmod);
521545
auto trunc = b.create<arith::TruncIOp>(modulusType(op), remu);
522546

0 commit comments

Comments
 (0)