@@ -163,7 +163,8 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
163163
164164 // `T` is the operand (e.g. the result of a multiplication, twice the
165165 // bitwidth of modulus).
166- Value T = adaptor.getOperands ()[0 ];
166+ Value tLow = adaptor.getOperands ()[0 ];
167+ Value tHigh = adaptor.getOperands ()[1 ];
167168
168169 // Extract Montgomery constants: `nPrime` and `modulus`.
169170 IntegerAttr nPrimeAttr = op.getMontgomeryAttr ().getNPrime ();
@@ -178,41 +179,49 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
178179 const unsigned limbWidth = APInt::APINT_BITS_PER_WORD;
179180 unsigned numLimbs = (modBitWidth + limbWidth - 1 ) / limbWidth;
180181
181- // Arith operations require the operands to be of same bit width
182- Value modExt = b.create <arith::ExtUIOp>(T.getType (), mod);
183-
184182 // Prepare constants for limb operations.
185- Value limbWidthConst =
186- b.create <arith::ConstantOp>(b.getIntegerAttr (T.getType (), limbWidth));
187-
188- // Because the number of limbs (numLimbs) is known at compile time, we can
189- // unroll the loop as a straight-line chain of operations. Let `u` be the
190- // current working value, initially `T`.
191- Value u = T;
183+ auto limbWidthConst = b.create <arith::ConstantOp>(
184+ b.getIntegerAttr (tLow.getType (), limbWidth));
185+ auto lowLimbMask = b.create <arith::ConstantOp>(b.getIntegerAttr (
186+ tLow.getType (), APInt::getAllOnes (limbWidth).zext (modBitWidth)));
187+ auto lowLimbShift = b.create <arith::ConstantOp>(
188+ b.getIntegerAttr (tLow.getType (), (numLimbs - 1 ) * limbWidth));
189+
190+ // Because the number of limbs (`numLimbs`) is known at compile time, we can
191+ // unroll the loop as a straight-line chain of operations.
192192 for (unsigned i = 0 ; i < numLimbs; ++i) {
193- // Extract the current lowest limb: `u ` (mod `base`)
194- Value lowerLimb = b.create <arith::TruncIOp>(nPrimeAttr.getType (), u );
193+ // Extract the current lowest limb: `tLow ` (mod `base`)
194+ auto lowerLimb = b.create <arith::TruncIOp>(nPrimeAttr.getType (), tLow );
195195 // Compute `m` = `lowerLimb` * `nPrime` (mod `base`)
196- Value m = b.create <arith::MulIOp>(lowerLimb, nPrime);
196+ auto m = b.create <arith::MulIOp>(lowerLimb, nPrime);
197197 // Compute `m` * `N` , where `N` is modulus
198- Value mExt = b.create <arith::ExtUIOp>(T.getType (), m);
199- Value mN = b.create <arith::MulIOp>(modExt, mExt );
200- // Add the product to `u`.
201- Value sum = b.create <arith::AddIOp>(u, mN );
198+ auto mExt = b.create <arith::ExtUIOp>(mod.getType (), m);
199+ auto mN = b.create <arith::MulUIExtendedOp>(mod, mExt );
200+ // Add the product to `T`.
201+ auto sum = b.create <arith::AddUIExtendedOp>(tLow, mN .getLow ());
202+ tLow = sum.getSum ();
203+ tHigh = b.create <arith::AddIOp>(tHigh, mN .getHigh ());
204+ // Add carry from the `sum` to `tHigh`.
205+ auto carryExt =
206+ b.create <arith::ExtUIOp>(tHigh.getType (), sum.getOverflow ());
207+ tHigh = b.create <arith::AddIOp>(tHigh, carryExt);
202208 // Shift right by `limbWidth` to discard the zeroed limb.
203- u = b.create <arith::ShRUIOp>(sum, limbWidthConst);
209+ tLow = b.create <arith::ShRUIOp>(tLow, limbWidthConst);
210+ // copy the lowest limb of `tHigh` to the highest limb of `tLow`
211+ Value tHighLimb = b.create <arith::AndIOp>(tHigh, lowLimbMask);
212+ tHighLimb = b.create <arith::ShLIOp>(tHighLimb, lowLimbShift);
213+ tLow = b.create <arith::OrIOp>(tLow, tHighLimb);
214+ // Shift right `tHigh` by `limbWidth`.
215+ tHigh = b.create <arith::ShRUIOp>(tHigh, limbWidthConst);
204216 }
205217
206- // Final conditional subtraction: if (`u_final ` >= modulus) then subtract
207- // modulus.
208- Value cmp = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, u, modExt );
209- Value sub = b.create <arith::SubIOp>(u, modExt );
210- Value result = b.create <arith::SelectOp>(cmp, sub, u );
218+ // Final conditional subtraction: if (`tLow ` >= ` modulus` ) then subtract
219+ // ` modulus` .
220+ auto cmp = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, mod );
221+ auto sub = b.create <arith::SubIOp>(tLow, mod );
222+ auto result = b.create <arith::SelectOp>(cmp, sub, tLow );
211223
212- // Truncate the result to the bitwidth of the modulus.
213- Value truncated = b.create <arith::TruncIOp>(mod.getType (), result);
214-
215- rewriter.replaceOp (op, truncated);
224+ rewriter.replaceOp (op, result);
216225 return success ();
217226 }
218227};
@@ -227,20 +236,15 @@ struct ConvertToMont : public OpConversionPattern<ToMontOp> {
227236 ToMontOp op, OpAdaptor adaptor,
228237 ConversionPatternRewriter &rewriter) const override {
229238 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
230- IntegerAttr rSquaredAttr = op.getMontgomery ().getRSquared ();
231239
232240 // x * R = REDC(x * rSquared)
233241 auto rSquared =
234242 b.create <arith::ConstantOp>(op.getMontgomery ().getRSquared ());
235- auto extended = b.create <arith::ExtUIOp>(rSquaredAttr.getType (),
236- adaptor.getOperands ()[0 ]);
237-
238- // TODO(batzor): Use extended multiplication to avoid full length
239- // multiplication. Now we extend both operands to 2x the bitwidth of the
240- // modulus to avoid the truncation in multiplication.
241- auto product = b.create <arith::MulIOp>(extended, rSquared);
242- auto reduced = b.create <MontReduceOp>(op.getResult ().getType (), product,
243- op.getMontgomery ());
243+ auto product =
244+ b.create <arith::MulUIExtendedOp>(adaptor.getOperands ()[0 ], rSquared);
245+ auto reduced =
246+ b.create <MontReduceOp>(op.getResult ().getType (), product.getLow (),
247+ product.getHigh (), op.getMontgomery ());
244248 rewriter.replaceOp (op, reduced);
245249 return success ();
246250 }
@@ -258,10 +262,11 @@ struct ConvertFromMont : public OpConversionPattern<FromMontOp> {
258262 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
259263
260264 // x * R⁻¹ = REDC(x)
261- auto extended = b.create <arith::ExtUIOp>(
262- op.getMontgomery ().getRSquared ().getType (), adaptor.getOperands ()[0 ]);
263- auto reduced = b.create <MontReduceOp>(op.getResult ().getType (), extended,
264- op.getMontgomery ());
265+ auto zeroHighConst = b.create <arith::ConstantOp>(
266+ IntegerAttr::get (op.getMontgomery ().getRSquared ().getType (), 0 ));
267+ auto reduced = b.create <MontReduceOp>(op.getResult ().getType (),
268+ adaptor.getOperands ()[0 ],
269+ zeroHighConst, op.getMontgomery ());
265270 rewriter.replaceOp (op, reduced);
266271 return success ();
267272 }
@@ -492,13 +497,11 @@ struct ConvertMontMul : public OpConversionPattern<MontMulOp> {
492497 ConversionPatternRewriter &rewriter) const override {
493498 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
494499
495- auto lhs =
496- b.create <arith::ExtUIOp>(modulusType (op, true ), adaptor.getLhs ());
497- auto rhs =
498- b.create <arith::ExtUIOp>(modulusType (op, true ), adaptor.getRhs ());
499- auto mul = b.create <arith::MulIOp>(lhs, rhs);
500+ auto mul =
501+ b.create <arith::MulUIExtendedOp>(adaptor.getLhs (), adaptor.getRhs ());
500502 auto reduced = b.create <mod_arith::MontReduceOp>(
501- getResultModArithType (op), mul.getResult (), op.getMontgomery ());
503+ getResultModArithType (op), mul.getLow (), mul.getHigh (),
504+ op.getMontgomery ());
502505
503506 rewriter.replaceOp (op, reduced);
504507 return success ();
0 commit comments