Skip to content

Commit d4fdc47

Browse files
committed
fix(mod_arith): handle tensor type in MontReduce
1 parent 31f9688 commit d4fdc47

File tree

1 file changed

+36
-17
lines changed

1 file changed

+36
-17
lines changed

zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -185,10 +185,8 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
185185
Value tHigh = adaptor.getOperands()[1];
186186

187187
// Extract Montgomery constants: `nPrime` and `modulus`.
188-
IntegerAttr nPrimeAttr = op.getMontgomeryAttr().getNPrime();
189-
Value nPrime = b.create<arith::ConstantOp>(nPrimeAttr);
188+
TypedAttr nPrimeAttr = op.getMontgomeryAttr().getNPrime();
190189
TypedAttr modAttr = modulusAttr(op);
191-
Value mod = b.create<arith::ConstantOp>(modAttr);
192190

193191
// Retrieve the modulus bitwidth.
194192
unsigned modBitWidth = cast<IntegerType>(modAttr.getType()).getWidth();
@@ -198,23 +196,43 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
198196
unsigned numLimbs = (modBitWidth + limbWidth - 1) / limbWidth;
199197

200198
// Prepare constants for limb operations.
201-
auto limbWidthConst = b.create<arith::ConstantOp>(
202-
b.getIntegerAttr(tLow.getType(), limbWidth));
203-
auto lowLimbMask = b.create<arith::ConstantOp>(b.getIntegerAttr(
204-
tLow.getType(), APInt::getAllOnes(limbWidth).zext(modBitWidth)));
205-
auto lowLimbShift = b.create<arith::ConstantOp>(
206-
b.getIntegerAttr(tLow.getType(), (numLimbs - 1) * limbWidth));
199+
Type limbType = nPrimeAttr.getType();
200+
TypedAttr limbWidthAttr =
201+
b.getIntegerAttr(getElementTypeOrSelf(tLow), limbWidth);
202+
TypedAttr limbMaskAttr =
203+
b.getIntegerAttr(getElementTypeOrSelf(tLow),
204+
APInt::getAllOnes(limbWidth).zext(modBitWidth));
205+
TypedAttr limbShiftAttr = b.getIntegerAttr(getElementTypeOrSelf(tLow),
206+
(numLimbs - 1) * limbWidth);
207+
208+
// Splat the attributes to match the shape of `tLow`.
209+
if (auto shapedType = dyn_cast<ShapedType>(tLow.getType())) {
210+
limbType = shapedType.cloneWith(std::nullopt, limbType);
211+
nPrimeAttr =
212+
SplatElementsAttr::get(cast<ShapedType>(limbType), nPrimeAttr);
213+
limbWidthAttr = SplatElementsAttr::get(shapedType, limbWidthAttr);
214+
limbMaskAttr = SplatElementsAttr::get(shapedType, limbMaskAttr);
215+
limbShiftAttr = SplatElementsAttr::get(shapedType, limbShiftAttr);
216+
modAttr = SplatElementsAttr::get(shapedType, modAttr);
217+
}
218+
219+
// Create constants for the Montgomery reduction.
220+
auto nPrimeConst = b.create<arith::ConstantOp>(nPrimeAttr);
221+
auto limbWidthConst = b.create<arith::ConstantOp>(limbWidthAttr);
222+
auto limbMaskConst = b.create<arith::ConstantOp>(limbMaskAttr);
223+
auto limbShiftConst = b.create<arith::ConstantOp>(limbShiftAttr);
224+
auto modConst = b.create<arith::ConstantOp>(modAttr);
207225

208226
// Because the number of limbs (`numLimbs`) is known at compile time, we can
209227
// unroll the loop as a straight-line chain of operations.
210228
for (unsigned i = 0; i < numLimbs; ++i) {
211229
// Extract the current lowest limb: `tLow` (mod `base`)
212-
auto lowerLimb = b.create<arith::TruncIOp>(nPrimeAttr.getType(), tLow);
230+
auto lowerLimb = b.create<arith::TruncIOp>(limbType, tLow);
213231
// Compute `m` = `lowerLimb` * `nPrime` (mod `base`)
214-
auto m = b.create<arith::MulIOp>(lowerLimb, nPrime);
232+
auto m = b.create<arith::MulIOp>(lowerLimb, nPrimeConst);
215233
// Compute `m` * `N` , where `N` is modulus
216-
auto mExt = b.create<arith::ExtUIOp>(mod.getType(), m);
217-
auto mN = b.create<arith::MulUIExtendedOp>(mod, mExt);
234+
auto mExt = b.create<arith::ExtUIOp>(tLow.getType(), m);
235+
auto mN = b.create<arith::MulUIExtendedOp>(modConst, mExt);
218236
// Add the product to `T`.
219237
auto sum = b.create<arith::AddUIExtendedOp>(tLow, mN.getLow());
220238
tLow = sum.getSum();
@@ -226,17 +244,18 @@ struct ConvertMontReduce : public OpConversionPattern<MontReduceOp> {
226244
// Shift right by `limbWidth` to discard the zeroed limb.
227245
tLow = b.create<arith::ShRUIOp>(tLow, limbWidthConst);
228246
// copy the lowest limb of `tHigh` to the highest limb of `tLow`
229-
Value tHighLimb = b.create<arith::AndIOp>(tHigh, lowLimbMask);
230-
tHighLimb = b.create<arith::ShLIOp>(tHighLimb, lowLimbShift);
247+
Value tHighLimb = b.create<arith::AndIOp>(tHigh, limbMaskConst);
248+
tHighLimb = b.create<arith::ShLIOp>(tHighLimb, limbShiftConst);
231249
tLow = b.create<arith::OrIOp>(tLow, tHighLimb);
232250
// Shift right `tHigh` by `limbWidth`.
233251
tHigh = b.create<arith::ShRUIOp>(tHigh, limbWidthConst);
234252
}
235253

236254
// Final conditional subtraction: if (`tLow` >= `modulus`) then subtract
237255
// `modulus`.
238-
auto cmp = b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, mod);
239-
auto sub = b.create<arith::SubIOp>(tLow, mod);
256+
auto cmp =
257+
b.create<arith::CmpIOp>(arith::CmpIPredicate::uge, tLow, modConst);
258+
auto sub = b.create<arith::SubIOp>(tLow, modConst);
240259
auto result = b.create<arith::SelectOp>(cmp, sub, tLow);
241260

242261
rewriter.replaceOp(op, result);

0 commit comments

Comments
 (0)