@@ -12,7 +12,8 @@ namespace mlir::zkir::poly {
1212// Compute the first degree powers of root modulo mod.
1313static void precomputeRoots (APInt root, const APInt &mod, unsigned degree,
1414 SmallVector<APInt> &roots,
15- SmallVector<APInt> &invRoots) {
15+ SmallVector<APInt> &invRoots,
16+ std::optional<IntegerAttr> montgomeryR) {
1617 unsigned kBitWidth = llvm::bit_width (degree);
1718
1819 // Precompute powers-of-two: `powerOfTwo[k]` = `root^(2^k)` mod `mod`.
@@ -22,11 +23,17 @@ static void precomputeRoots(APInt root, const APInt &mod, unsigned degree,
2223 powerOfTwo[k] = mulMod (powerOfTwo[k - 1 ], powerOfTwo[k - 1 ], mod);
2324 }
2425
26+ // Coset factor
27+ APInt coset (mod.getBitWidth (), 1 );
28+ if (montgomeryR) {
29+ coset = montgomeryR->getValue ();
30+ }
31+
2532 // Prepare the result vector.
2633 roots.resize (degree);
2734 invRoots.resize (degree);
28- roots[0 ] = APInt (root. getBitWidth (), 1 ); // Identity element.
29- invRoots[0 ] = APInt (root. getBitWidth (), 1 ); // Identity element.
35+ roots[0 ] = coset;
36+ invRoots[0 ] = coset;
3037
3138 llvm::StdThreadPool pool (llvm::hardware_concurrency ());
3239
@@ -42,8 +49,9 @@ static void precomputeRoots(APInt root, const APInt &mod, unsigned degree,
4249 exp >>= 1 ;
4350 bit++;
4451 }
45- roots[i] = result;
46- invRoots[degree - i] = result;
52+
53+ roots[i] = mulMod (coset, result, mod);
54+ invRoots[degree - i] = mulMod (coset, result, mod);
4755 });
4856 }
4957
@@ -73,13 +81,18 @@ DenseElementsAttr PrimitiveRootAttr::getInvRoots() const {
7381 return getImpl ()->invRoots ;
7482}
7583
84+ zkir::mod_arith::MontgomeryAttr PrimitiveRootAttr::getMontgomery () const {
85+ return getImpl ()->montgomery ;
86+ }
87+
7688namespace detail {
7789
7890PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct (
7991 AttributeStorageAllocator &allocator, KeyTy &&key) {
8092 // Extract the root and degree from the key.
8193 zkir::field::PrimeFieldAttr root = std::get<0 >(key);
8294 IntegerAttr degree = std::get<1 >(key);
95+ zkir::mod_arith::MontgomeryAttr montgomery = std::get<2 >(key);
8396
8497 APInt mod = root.getType ().getModulus ().getValue ();
8598 APInt rootVal = root.getValue ().getValue ();
@@ -94,8 +107,11 @@ PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
94107
95108 // Compute the exponent table.
96109 SmallVector<APInt> roots, invRoots;
97- precomputeRoots (rootVal, mod, degree.getInt (), roots, invRoots);
98-
110+ std::optional<IntegerAttr> montgomeryR;
111+ if (montgomery != zkir::mod_arith::MontgomeryAttr ()) {
112+ montgomeryR = montgomery.getR ();
113+ }
114+ precomputeRoots (rootVal, mod, degree.getInt (), roots, invRoots, montgomeryR);
99115 // Create a ranked tensor type for the exponents attribute.
100116 auto tensorType = RankedTensorType::get (
101117 {degree.getInt ()}, root.getType ().getModulus ().getType ());
@@ -106,7 +122,8 @@ PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
106122 return new (allocator.allocate <PrimitiveRootAttrStorage>())
107123 PrimitiveRootAttrStorage (std::move (degree), std::move (invDegree),
108124 std::move (root), std::move (invRoot),
109- std::move (rootsAttr), std::move (invRootsAttr));
125+ std::move (rootsAttr), std::move (invRootsAttr),
126+ std::move (montgomery));
110127}
111128
112129} // namespace detail
0 commit comments