@@ -185,10 +185,8 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
185185 Value tHigh = adaptor.getOperands ()[1 ];
186186
187187 // Extract Montgomery constants: `nPrime` and `modulus`.
188- IntegerAttr nPrimeAttr = op.getMontgomeryAttr ().getNPrime ();
189- Value nPrime = b.create <arith::ConstantOp>(nPrimeAttr);
188+ TypedAttr nPrimeAttr = op.getMontgomeryAttr ().getNPrime ();
190189 TypedAttr modAttr = modulusAttr (op);
191- Value mod = b.create <arith::ConstantOp>(modAttr);
192190
193191 // Retrieve the modulus bitwidth.
194192 unsigned modBitWidth = cast<IntegerType>(modAttr.getType ()).getWidth ();
@@ -198,23 +196,43 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
198196 unsigned numLimbs = (modBitWidth + limbWidth - 1 ) / limbWidth;
199197
200198 // Prepare constants for limb operations.
201- auto limbWidthConst = b.create <arith::ConstantOp>(
202- b.getIntegerAttr (tLow.getType (), limbWidth));
203- auto lowLimbMask = b.create <arith::ConstantOp>(b.getIntegerAttr (
204- tLow.getType (), APInt::getAllOnes (limbWidth).zext (modBitWidth)));
205- auto lowLimbShift = b.create <arith::ConstantOp>(
206- b.getIntegerAttr (tLow.getType (), (numLimbs - 1 ) * limbWidth));
199+ Type limbType = nPrimeAttr.getType ();
200+ TypedAttr limbWidthAttr =
201+ b.getIntegerAttr (getElementTypeOrSelf (tLow), limbWidth);
202+ TypedAttr limbMaskAttr =
203+ b.getIntegerAttr (getElementTypeOrSelf (tLow),
204+ APInt::getAllOnes (limbWidth).zext (modBitWidth));
205+ TypedAttr limbShiftAttr = b.getIntegerAttr (getElementTypeOrSelf (tLow),
206+ (numLimbs - 1 ) * limbWidth);
207+
208+ // Splat the attributes to match the shape of `tLow`.
209+ if (auto shapedType = dyn_cast<ShapedType>(tLow.getType ())) {
210+ limbType = shapedType.cloneWith (std::nullopt , limbType);
211+ nPrimeAttr =
212+ SplatElementsAttr::get (cast<ShapedType>(limbType), nPrimeAttr);
213+ limbWidthAttr = SplatElementsAttr::get (shapedType, limbWidthAttr);
214+ limbMaskAttr = SplatElementsAttr::get (shapedType, limbMaskAttr);
215+ limbShiftAttr = SplatElementsAttr::get (shapedType, limbShiftAttr);
216+ modAttr = SplatElementsAttr::get (shapedType, modAttr);
217+ }
218+
219+ // Create constants for the Montgomery reduction.
220+ auto nPrimeConst = b.create <arith::ConstantOp>(nPrimeAttr);
221+ auto limbWidthConst = b.create <arith::ConstantOp>(limbWidthAttr);
222+ auto limbMaskConst = b.create <arith::ConstantOp>(limbMaskAttr);
223+ auto limbShiftConst = b.create <arith::ConstantOp>(limbShiftAttr);
224+ auto modConst = b.create <arith::ConstantOp>(modAttr);
207225
208226 // Because the number of limbs (`numLimbs`) is known at compile time, we can
209227 // unroll the loop as a straight-line chain of operations.
210228 for (unsigned i = 0 ; i < numLimbs; ++i) {
211229 // Extract the current lowest limb: `tLow` (mod `base`)
212- auto lowerLimb = b.create <arith::TruncIOp>(nPrimeAttr. getType () , tLow);
230+ auto lowerLimb = b.create <arith::TruncIOp>(limbType , tLow);
213231 // Compute `m` = `lowerLimb` * `nPrime` (mod `base`)
214- auto m = b.create <arith::MulIOp>(lowerLimb, nPrime );
232+ auto m = b.create <arith::MulIOp>(lowerLimb, nPrimeConst );
215233 // Compute `m` * `N`, where `N` is the modulus.
216- auto mExt = b.create <arith::ExtUIOp>(mod .getType (), m);
217- auto mN = b.create <arith::MulUIExtendedOp>(mod , mExt );
234+ auto mExt = b.create <arith::ExtUIOp>(tLow .getType (), m);
235+ auto mN = b.create <arith::MulUIExtendedOp>(modConst , mExt );
218236 // Add the product to `T`.
219237 auto sum = b.create <arith::AddUIExtendedOp>(tLow, mN .getLow ());
220238 tLow = sum.getSum ();
@@ -226,17 +244,18 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
226244 // Shift right by `limbWidth` to discard the zeroed limb.
227245 tLow = b.create <arith::ShRUIOp>(tLow, limbWidthConst);
228246 // Copy the lowest limb of `tHigh` to the highest limb of `tLow`.
229- Value tHighLimb = b.create <arith::AndIOp>(tHigh, lowLimbMask );
230- tHighLimb = b.create <arith::ShLIOp>(tHighLimb, lowLimbShift );
247+ Value tHighLimb = b.create <arith::AndIOp>(tHigh, limbMaskConst );
248+ tHighLimb = b.create <arith::ShLIOp>(tHighLimb, limbShiftConst );
231249 tLow = b.create <arith::OrIOp>(tLow, tHighLimb);
232250 // Shift right `tHigh` by `limbWidth`.
233251 tHigh = b.create <arith::ShRUIOp>(tHigh, limbWidthConst);
234252 }
235253
236254 // Final conditional subtraction: if (`tLow` >= `modulus`) then subtract
237255 // `modulus`.
238- auto cmp = b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, mod);
239- auto sub = b.create <arith::SubIOp>(tLow, mod);
256+ auto cmp =
257+ b.create <arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, modConst);
258+ auto sub = b.create <arith::SubIOp>(tLow, modConst);
240259 auto result = b.create <arith::SelectOp>(cmp, sub, tLow);
241260
242261 rewriter.replaceOp (op, result);
0 commit comments