Skip to content

Commit cfcffbf

Browse files
committed
perf(poly): use MontMul for degree inverse in INTT
1 parent d4fdc47 commit cfcffbf

File tree

3 files changed

Lines changed: 20 additions & 10 deletions

3 files changed

Lines changed: 20 additions & 10 deletions

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -140,12 +140,12 @@ BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
140140
// L1 Data 64 KiB
141141
// L1 Instruction 128 KiB
142142
// L2 Unified 4096 KiB (x14)
143-
// Load Average: 7.37, 7.53, 7.10
143+
// Load Average: 3.97, 3.11, 2.96
144144
// ------------------------------------------------------------------------------
145145
// Benchmark Time CPU Iterations
146146
// ------------------------------------------------------------------------------
147-
// BM_ntt_benchmark 10.4 s 10.3 s 1
148-
// BM_intt_benchmark/iterations:1 12.1 s 11.4 s 1
149-
// BM_ntt_mont_benchmark 0.201 s 0.197 s 3
150-
// BM_intt_mont_benchmark/iterations:1 1.38 s 1.35 s 1
147+
// BM_ntt_benchmark 10.2 s 10.1 s 1
148+
// BM_intt_benchmark/iterations:1 11.1 s 11.1 s 1
149+
// BM_ntt_mont_benchmark 0.190 s 0.190 s 3
150+
// BM_intt_mont_benchmark/iterations:1 0.316 s 0.304 s 1
151151
// NOLINTEND()

zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.cpp

Lines changed: 6 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -494,7 +494,12 @@ static Value fastNTT(ImplicitLocOpBuilder &b, PrimitiveRootAttr rootAttr,
494494
b.create<arith::ConstantOp>(rootAttr.getInvDegree().getValue());
495495
auto degreeInvTensor = b.create<tensor::SplatOp>(invDegreeConst, rootsType);
496496
auto fieldTensor = b.create<field::EncapsulateOp>(modType, degreeInvTensor);
497-
result = b.create<field::MulOp>(result, fieldTensor);
497+
if (rootAttr.getMontgomery() != mod_arith::MontgomeryAttr()) {
498+
result = b.create<field::MontMulOp>(result, fieldTensor,
499+
rootAttr.getMontgomery());
500+
} else {
501+
result = b.create<field::MulOp>(result, fieldTensor);
502+
}
498503
}
499504

500505
return result;

zkir/Dialect/Poly/IR/PolyAttributes.cpp

Lines changed: 9 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -94,23 +94,28 @@ PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
9494
IntegerAttr degree = std::get<1>(key);
9595
zkir::mod_arith::MontgomeryAttr montgomery = std::get<2>(key);
9696

97+
std::optional<IntegerAttr> montgomeryR;
98+
if (montgomery != zkir::mod_arith::MontgomeryAttr()) {
99+
montgomeryR = montgomery.getR();
100+
}
101+
97102
APInt mod = root.getType().getModulus().getValue();
98103
APInt rootVal = root.getValue().getValue();
99104
APInt invRootVal = multiplicativeInverse(rootVal, mod);
100105
APInt invDegreeVal = multiplicativeInverse(
101106
degree.getValue().zextOrTrunc(mod.getBitWidth()), mod);
102107

108+
if (montgomeryR) {
109+
invDegreeVal = mulMod(invDegreeVal, montgomeryR->getValue(), mod);
110+
}
111+
103112
field::PrimeFieldAttr invDegree =
104113
field::PrimeFieldAttr::get(root.getType(), invDegreeVal);
105114
zkir::field::PrimeFieldAttr invRoot =
106115
zkir::field::PrimeFieldAttr::get(root.getType(), invRootVal);
107116

108117
// Compute the exponent table.
109118
SmallVector<APInt> roots, invRoots;
110-
std::optional<IntegerAttr> montgomeryR;
111-
if (montgomery != zkir::mod_arith::MontgomeryAttr()) {
112-
montgomeryR = montgomery.getR();
113-
}
114119
precomputeRoots(rootVal, mod, degree.getInt(), roots, invRoots, montgomeryR);
115120
// Create a ranked tensor type for the exponents attribute.
116121
auto tensorType = RankedTensorType::get(

0 commit comments

Comments
 (0)