Skip to content

Commit c4e1f90

Browse files
committed
refac(mod_arith): use MulUIExtendedOp for MontReduce
With this change, we avoid i512 operations (i.e. when adding `mN`) and instead uses extended operations for multiplying/adding i256 values.
1 parent e9243ad commit c4e1f90

File tree

5 files changed

+75
-70
lines changed

5 files changed

+75
-70
lines changed

tests/Dialect/ModArith/mod_arith_runner.mlir

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,22 +35,26 @@ func.func @test_lower_inverse() {
3535
#Fq_mont = #mod_arith.montgomery<!Fq>
3636

3737
func.func @test_lower_mont_reduce() {
38-
%p = arith.constant 3723 : i512
39-
%p_mont = mod_arith.mont_reduce %p {montgomery=#Fq_mont} : i512 -> !Fq
38+
%p = arith.constant 2188824287183927522224640574525727508854836440041603434369820418657580849561 : i256
39+
%zero = arith.constant 0 : i256
40+
// `pR` is `p` << 256 so just give `p` as `high` and set `low` to 0
41+
%p_mont = mod_arith.mont_reduce %zero, %p {montgomery=#Fq_mont} : i256 -> !Fq
4042

4143
%2 = mod_arith.extract %p_mont : !Fq -> i256
42-
%3 = vector.from_elements %2 : vector<1xi256>
43-
%4 = vector.bitcast %3 : vector<1xi256> to vector<8xi32>
44-
%mem = memref.alloc() : memref<8xi32>
44+
// check if mod_arith.mont_reduce(pR) == p
45+
%true = arith.cmpi eq, %2, %p : i256
46+
%trueExt = arith.extui %true : i1 to i32
47+
%3 = vector.from_elements %trueExt : vector<1xi32>
48+
%mem = memref.alloc() : memref<1xi32>
4549
%idx_0 = arith.constant 0 : index
46-
vector.store %4, %mem[%idx_0] : memref<8xi32>, vector<8xi32>
50+
vector.store %3, %mem[%idx_0] : memref<1xi32>, vector<1xi32>
4751

48-
%U = memref.cast %mem : memref<8xi32> to memref<*xi32>
52+
%U = memref.cast %mem : memref<1xi32> to memref<*xi32>
4953
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
5054
return
5155
}
5256

53-
// CHECK_TEST_MONT_REDUCE: [-1635059004, -1772563805, -2074116324, -156049350, 156881531, -524227392, -1359481138, 438709201]
57+
// CHECK_TEST_MONT_REDUCE: [1]
5458

5559
func.func @test_lower_mont_mul() {
5660
%p = mod_arith.constant 17221657567640823606390383439573883756117969501024189775361 : !Fq

zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp

Lines changed: 51 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,8 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
163163

164164
// `T` is the operand (e.g. the result of a multiplication, twice the
165165
// bitwidth of modulus).
166-
Value T = adaptor.getOperands()[0];
166+
Value tLow = adaptor.getOperands()[0];
167+
Value tHigh = adaptor.getOperands()[1];
167168

168169
// Extract Montgomery constants: `nPrime` and `modulus`.
169170
IntegerAttr nPrimeAttr = op.getMontgomeryAttr().getNPrime();
@@ -178,41 +179,49 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
178179
const unsigned limbWidth = APInt::APINT_BITS_PER_WORD;
179180
unsigned numLimbs = (modBitWidth + limbWidth - 1) / limbWidth;
180181

181-
// Arith operations require the operands to be of same bit width
182-
Value modExt = b.create<arith::ExtUIOp>(T.getType(), mod);
183-
184182
// Prepare constants for limb operations.
185-
Value limbWidthConst =
186-
b.create<arith::ConstantOp>(b.getIntegerAttr(T.getType(), limbWidth));
187-
188-
// Because the number of limbs (numLimbs) is known at compile time, we can
189-
// unroll the loop as a straight-line chain of operations. Let `u` be the
190-
// current working value, initially `T`.
191-
Value u = T;
183+
auto limbWidthConst = b.create<arith::ConstantOp>(
184+
b.getIntegerAttr(tLow.getType(), limbWidth));
185+
auto lowLimbMask = b.create<arith::ConstantOp>(b.getIntegerAttr(
186+
tLow.getType(), APInt::getAllOnes(limbWidth).zext(modBitWidth)));
187+
auto lowLimbShift = b.create<arith::ConstantOp>(
188+
b.getIntegerAttr(tLow.getType(), (numLimbs - 1) * limbWidth));
189+
190+
// Because the number of limbs (`numLimbs`) is known at compile time, we can
191+
// unroll the loop as a straight-line chain of operations.
192192
for (unsigned i = 0; i < numLimbs; ++i) {
193-
// Extract the current lowest limb: `u` (mod `base`)
194-
Value lowerLimb = b.create<arith::TruncIOp>(nPrimeAttr.getType(), u);
193+
// Extract the current lowest limb: `tLow` (mod `base`)
194+
auto lowerLimb = b.create<arith::TruncIOp>(nPrimeAttr.getType(), tLow);
195195
// Compute `m` = `lowerLimb` * `nPrime` (mod `base`)
196-
Value m = b.create<arith::MulIOp>(lowerLimb, nPrime);
196+
auto m = b.create<arith::MulIOp>(lowerLimb, nPrime);
197197
// Compute `m` * `N` , where `N` is modulus
198-
Value mExt = b.create<arith::ExtUIOp>(T.getType(), m);
199-
Value mN = b.create<arith::MulIOp>(modExt, mExt);
200-
// Add the product to `u`.
201-
Value sum = b.create<arith::AddIOp>(u, mN);
198+
auto mExt = b.create<arith::ExtUIOp>(mod.getType(), m);
199+
auto mN = b.create<arith::MulUIExtendedOp>(mod, mExt);
200+
// Add the product to `T`.
201+
auto sum = b.create<arith::AddUIExtendedOp>(tLow, mN.getLow());
202+
tLow = sum.getSum();
203+
tHigh = b.create<arith::AddIOp>(tHigh, mN.getHigh());
204+
// Add carry from the `sum` to `tHigh`.
205+
auto carryExt =
206+
b.create<arith::ExtUIOp>(tHigh.getType(), sum.getOverflow());
207+
tHigh = b.create<arith::AddIOp>(tHigh, carryExt);
202208
// Shift right by `limbWidth` to discard the zeroed limb.
203-
u = b.create<arith::ShRUIOp>(sum, limbWidthConst);
209+
tLow = b.create<arith::ShRUIOp>(tLow, limbWidthConst);
210+
// copy the lowest limb of `tHigh` to the highest limb of `tLow`
211+
Value tHighLimb = b.create<arith::AndIOp>(tHigh, lowLimbMask);
212+
tHighLimb = b.create<arith::ShLIOp>(tHighLimb, lowLimbShift);
213+
tLow = b.create<arith::OrIOp>(tLow, tHighLimb);
214+
// Shift right `tHigh` by `limbWidth`.
215+
tHigh = b.create<arith::ShRUIOp>(tHigh, limbWidthConst);
204216
}
205217

206-
// Final conditional subtraction: if (`u_final` >= modulus) then subtract
207-
// modulus.
208-
Value cmp = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, u, modExt);
209-
Value sub = b.create<arith::SubIOp>(u, modExt);
210-
Value result = b.create<arith::SelectOp>(cmp, sub, u);
218+
// Final conditional subtraction: if (`tLow` >= `modulus`) then subtract
219+
// `modulus`.
220+
auto cmp = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, mod);
221+
auto sub = b.create<arith::SubIOp>(tLow, mod);
222+
auto result = b.create<arith::SelectOp>(cmp, sub, tLow);
211223

212-
// Truncate the result to the bitwidth of the modulus.
213-
Value truncated = b.create<arith::TruncIOp>(mod.getType(), result);
214-
215-
rewriter.replaceOp(op, truncated);
224+
rewriter.replaceOp(op, result);
216225
return success();
217226
}
218227
};
@@ -227,20 +236,15 @@ struct ConvertToMont : public OpConversionPattern<ToMontOp> {
227236
ToMontOp op, OpAdaptor adaptor,
228237
ConversionPatternRewriter &rewriter) const override {
229238
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
230-
IntegerAttr rSquaredAttr = op.getMontgomery().getRSquared();
231239

232240
// x * R = REDC(x * rSquared)
233241
auto rSquared =
234242
b.create<arith::ConstantOp>(op.getMontgomery().getRSquared());
235-
auto extended = b.create<arith::ExtUIOp>(rSquaredAttr.getType(),
236-
adaptor.getOperands()[0]);
237-
238-
// TODO(batzor): Use extended multiplication to avoid full length
239-
// multiplication. Now we extend both operands to 2x the bitwidth of the
240-
// modulus to avoid the truncation in multiplication.
241-
auto product = b.create<arith::MulIOp>(extended, rSquared);
242-
auto reduced = b.create<MontReduceOp>(op.getResult().getType(), product,
243-
op.getMontgomery());
243+
auto product =
244+
b.create<arith::MulUIExtendedOp>(adaptor.getOperands()[0], rSquared);
245+
auto reduced =
246+
b.create<MontReduceOp>(op.getResult().getType(), product.getLow(),
247+
product.getHigh(), op.getMontgomery());
244248
rewriter.replaceOp(op, reduced);
245249
return success();
246250
}
@@ -258,10 +262,11 @@ struct ConvertFromMont : public OpConversionPattern<FromMontOp> {
258262
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
259263

260264
// x * R⁻¹ = REDC(x)
261-
auto extended = b.create<arith::ExtUIOp>(
262-
op.getMontgomery().getRSquared().getType(), adaptor.getOperands()[0]);
263-
auto reduced = b.create<MontReduceOp>(op.getResult().getType(), extended,
264-
op.getMontgomery());
265+
auto zeroHighConst = b.create<arith::ConstantOp>(
266+
IntegerAttr::get(op.getMontgomery().getRSquared().getType(), 0));
267+
auto reduced = b.create<MontReduceOp>(op.getResult().getType(),
268+
adaptor.getOperands()[0],
269+
zeroHighConst, op.getMontgomery());
265270
rewriter.replaceOp(op, reduced);
266271
return success();
267272
}
@@ -492,13 +497,11 @@ struct ConvertMontMul : public OpConversionPattern<MontMulOp> {
492497
ConversionPatternRewriter &rewriter) const override {
493498
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
494499

495-
auto lhs =
496-
b.create<arith::ExtUIOp>(modulusType(op, true), adaptor.getLhs());
497-
auto rhs =
498-
b.create<arith::ExtUIOp>(modulusType(op, true), adaptor.getRhs());
499-
auto mul = b.create<arith::MulIOp>(lhs, rhs);
500+
auto mul =
501+
b.create<arith::MulUIExtendedOp>(adaptor.getLhs(), adaptor.getRhs());
500502
auto reduced = b.create<mod_arith::MontReduceOp>(
501-
getResultModArithType(op), mul.getResult(), op.getMontgomery());
503+
getResultModArithType(op), mul.getLow(), mul.getHigh(),
504+
op.getMontgomery());
502505

503506
rewriter.replaceOp(op, reduced);
504507
return success();

zkir/Dialect/ModArith/IR/ModArithAttributes.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,12 +61,9 @@ MontgomeryAttrStorage *MontgomeryAttrStorage::construct(
6161
// Construct the `rInvAttr` with the bitwidth of the modulus
6262
IntegerAttr rInvAttr = IntegerAttr::get(modType.getModulus().getType(), rInv);
6363

64-
// Construct the `rSquaredAttr` with 2x the bitwidth of the modulus
65-
// NOTE(batzor): It is currently 2x bitwidth due to how the `ToMontOp` works
66-
// but should be later changed.
67-
IntegerAttr rSquaredAttr = IntegerAttr::get(
68-
IntegerType::get(modType.getContext(), modulus.getBitWidth() * 2),
69-
rSquared.zext(modulus.getBitWidth() * 2));
64+
// Construct the `rSquaredAttr` with the bitwidth of the modulus
65+
IntegerAttr rSquaredAttr =
66+
IntegerAttr::get(modType.getModulus().getType(), rSquared);
7067

7168
// Construct the `nPrimeAttr` with the bitwidth `w`
7269
IntegerAttr nPrimeAttr = IntegerAttr::get(

zkir/Dialect/ModArith/IR/ModArithDialect.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,14 @@ LogicalResult ReduceOp::verify() {
108108
}
109109

110110
LogicalResult MontReduceOp::verify() {
111-
IntegerType integerType = getOperandIntegerType(*this);
111+
IntegerType integerType =
112+
cast<IntegerType>(getElementTypeOrSelf(this->getLow().getType()));
112113
ModArithType modArithType = getResultModArithType(*this);
113114
unsigned intWidth = integerType.getWidth();
114115
unsigned modWidth = modArithType.getModulus().getValue().getBitWidth();
115-
if (intWidth != 2 * modWidth)
116-
return emitOpError() << "Expected operand width to be " << 2 * modWidth
117-
<< ", but got " << intWidth
118-
<< " while modulus width is " << modWidth << ".";
116+
if (intWidth != modWidth)
117+
return emitOpError() << "Expected operand width to be " << modWidth
118+
<< ", but got " << intWidth << " instead.";
119119
return success();
120120
}
121121

zkir/Dialect/ModArith/IR/ModArithOps.td

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOp
123123
let assemblyFormat = "operands attr-dict `:` type($output)";
124124
}
125125

126-
def ModArith_MontReduceOp : ModArith_Op<"mont_reduce", [Pure, ElementwiseMappable]> {
126+
def ModArith_MontReduceOp : ModArith_Op<"mont_reduce", [Pure, ElementwiseMappable, SameTypeOperands]> {
127127
let summary = "applies montgomery reduction to the integer of twice the modulus bitwidth";
128128

129129
let description = [{
@@ -137,12 +137,13 @@ def ModArith_MontReduceOp : ModArith_Op<"mont_reduce", [Pure, ElementwiseMappabl
137137
}];
138138

139139
let arguments = (ins
140-
SignlessIntegerLike:$input,
140+
SignlessIntegerLike:$low,
141+
SignlessIntegerLike:$high,
141142
ModArith_MontgomeryAttr:$montgomery
142143
);
143144
let results = (outs ModArithLike:$output);
144145
let hasVerifier = 1;
145-
let assemblyFormat = "operands attr-dict `:` type($input) `->` type($output)";
146+
let assemblyFormat = "operands attr-dict `:` type($low) `->` type($output)";
146147
}
147148

148149
def ModArith_ToMontOp : ModArith_Op<"to_mont", [Pure, ElementwiseMappable, SameOperandsAndResultType]> {

0 commit comments

Comments
 (0)