Skip to content

Commit 8b12336

Browse files
committed
feat(poly): impl MontMul based NTT
1 parent a3c600b commit 8b12336

File tree

7 files changed

+134
-27
lines changed

7 files changed

+134
-27
lines changed

benchmark/ntt/ntt_benchmark.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@
66
#root_elem = #field.pf_elem<17220337697351015657950521176323262483320249231368149235373741788599650842711:i256> : !coeff_ty
77
#root = #poly.primitive_root<root=#root_elem, degree=1048576:i256>
88

9+
!mod = !mod_arith.int<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
10+
#mont = #mod_arith.montgomery<!mod>
11+
#root_mont = #poly.primitive_root<root=#root_elem, degree=1048576:i256, montgomery=#mont>
12+
913
func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } {
1014
%c42 = arith.constant 6420 : i256
1115
%full = tensor.splat %c42 : !intt_ty
@@ -25,3 +29,15 @@ func.func @intt(%arg0 : !intt_ty) -> !poly_ty attributes { llvm.emit_c_interface
2529
%1 = poly.intt %0 {root=#root} : !coefft_ty -> !poly_ty
2630
return %1 :!poly_ty
2731
}
32+
33+
// Forward NTT using the Montgomery-annotated primitive root (#root_mont, which
// carries montgomery=#mont). Applies poly.ntt to the input polynomial and then
// converts the !coefft_ty result to !intt_ty with field.pf.extract.
// llvm.emit_c_interface makes this callable from the C++ benchmark as
// _mlir_ciface_ntt_mont.
func.func @ntt_mont(%arg0 : !poly_ty) -> !intt_ty attributes { llvm.emit_c_interface } {
34+
%0 = poly.ntt %arg0 {root=#root_mont} : !poly_ty -> !coefft_ty
35+
%1 = field.pf.extract %0 : !coefft_ty -> !intt_ty
36+
return %1 : !intt_ty
37+
}
38+
39+
// Inverse NTT using the Montgomery-annotated primitive root (#root_mont).
// Wraps the !intt_ty input as !coefft_ty with field.pf.encapsulate and then
// applies poly.intt to obtain a !poly_ty result.
// llvm.emit_c_interface makes this callable from the C++ benchmark as
// _mlir_ciface_intt_mont.
func.func @intt_mont(%arg0 : !intt_ty) -> !poly_ty attributes { llvm.emit_c_interface } {
40+
%0 = field.pf.encapsulate %arg0 : !intt_ty -> !coefft_ty
41+
%1 = poly.intt %0 {root=#root_mont} : !coefft_ty -> !poly_ty
42+
return %1 :!poly_ty
43+
}

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,11 @@ extern "C" void _mlir_ciface_input_generation(Memref<i256> *output);
1717
extern "C" void _mlir_ciface_ntt(Memref<i256> *output, Memref<i256> *input);
1818
extern "C" void _mlir_ciface_intt(Memref<i256> *output, Memref<i256> *input);
1919

20+
extern "C" void _mlir_ciface_ntt_mont(Memref<i256> *output,
21+
Memref<i256> *input);
22+
extern "C" void _mlir_ciface_intt_mont(Memref<i256> *output,
23+
Memref<i256> *input);
24+
2025
void BM_ntt_benchmark(::benchmark::State &state) {
2126
Memref<i256> input(1, DEGREE);
2227
_mlir_ciface_input_generation(&input);
@@ -61,6 +66,50 @@ void BM_intt_benchmark(::benchmark::State &state) {
6166
// modifying the input. But I am not sure why ;(
6267
BENCHMARK(BM_intt_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
6368

69+
// Benchmarks the Montgomery-based forward NTT (_mlir_ciface_ntt_mont) on a
// generated input, then checks that applying the Montgomery-based inverse NTT
// to the result reproduces the original input.
void BM_ntt_mont_benchmark(::benchmark::State &state) {
70+
Memref<i256> input(1, DEGREE);
71+
_mlir_ciface_input_generation(&input);
72+
73+
Memref<i256> ntt(1, DEGREE);
// Timed region: only the forward NTT call runs inside the benchmark loop.
74+
for (auto _ : state) {
75+
_mlir_ciface_ntt_mont(&ntt, &input);
76+
}
77+
78+
// Outside the timed loop: invert the transform so the round trip can be
// compared against the input.
Memref<i256> intt(1, DEGREE);
79+
_mlir_ciface_intt_mont(&intt, &ntt);
80+
81+
// Compare every element limb by limb.
// NOTE(review): the inner bound of 4 presumably matches the number of limbs
// in i256 - confirm against the i256 definition.
for (int i = 0; i < DEGREE; i++) {
82+
for (int j = 0; j < 4; j++) {
83+
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
84+
}
85+
}
86+
}
87+
88+
BENCHMARK(BM_ntt_mont_benchmark)->Unit(::benchmark::kSecond);
89+
90+
// Benchmarks the Montgomery-based inverse NTT (_mlir_ciface_intt_mont). The
// forward transform of a generated input is computed once up front; the timed
// loop then inverts it, and the result is checked against the original input.
void BM_intt_mont_benchmark(::benchmark::State &state) {
91+
Memref<i256> input(1, DEGREE);
92+
_mlir_ciface_input_generation(&input);
93+
94+
// Setup, outside the timed loop: produce the NTT-domain data to invert.
Memref<i256> ntt(1, DEGREE);
95+
_mlir_ciface_ntt_mont(&ntt, &input);
96+
97+
Memref<i256> intt(1, DEGREE);
// Timed region: only the inverse NTT call runs inside the benchmark loop.
98+
for (auto _ : state) {
99+
_mlir_ciface_intt_mont(&intt, &ntt);
100+
}
101+
102+
// Compare every element limb by limb.
// NOTE(review): the inner bound of 4 presumably matches the number of limbs
// in i256 - confirm against the i256 definition.
for (int i = 0; i < DEGREE; i++) {
103+
for (int j = 0; j < 4; j++) {
104+
EXPECT_EQ(intt.pget(0, i)->limbs[j], input.pget(0, i)->limbs[j]);
105+
}
106+
}
107+
}
108+
109+
// FIXME(batzor): This fails when run for more than 1 iteration, so it seems
110+
// the input is being modified. The cause is not yet understood.
111+
BENCHMARK(BM_intt_mont_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
112+
64113
} // namespace
65114
} // namespace zkir
66115

@@ -69,9 +118,11 @@ BENCHMARK(BM_intt_benchmark)->Iterations(1)->Unit(::benchmark::kSecond);
69118
// L1 Data 64 KiB
70119
// L1 Instruction 128 KiB
71120
// L2 Unified 4096 KiB (x14)
72-
// Load Average: 22.54, 38.87, 26.62
73-
// -------------------------------------------------------------------------
74-
// Benchmark Time CPU Iterations
75-
// -------------------------------------------------------------------------
76-
// BM_ntt_benchmark 0.321 s 0.320 s 2
77-
// BM_intt_benchmark/iterations:1 0.475 s 0.473 s 1
121+
// Load Average: 9.50, 8.31, 8.95
122+
// ------------------------------------------------------------------------------
123+
// Benchmark Time CPU Iterations
124+
// ------------------------------------------------------------------------------
125+
// BM_ntt_benchmark 0.339 s 0.333 s 2
126+
// BM_intt_benchmark/iterations:1 0.501 s 0.493 s 1
127+
// BM_ntt_mont_benchmark 0.379 s 0.372 s 2
128+
// BM_intt_mont_benchmark/iterations:1 0.510 s 0.504 s 1

zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.cpp

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -202,19 +202,31 @@ struct ConvertFromTensor : public OpConversionPattern<FromTensorOp> {
202202

203203
// Butterfly : Cooley-Tukey
204204
// Emits the ops for one Cooley-Tukey butterfly on A and B with twiddle `root`.
// The product B * root is built first, then the pair (A + B*root, A - B*root)
// is returned. After this change the product is created as field::MontMulOp
// when `montAttr` is set, and as field::MulOp otherwise.
static std::pair<Value, Value> bflyCT(ImplicitLocOpBuilder &b, Value A, Value B,
205-
Value root) {
206-
auto rootB = b.create<field::MulOp>(B, root);
205+
Value root,
206+
mod_arith::MontgomeryAttr montAttr) {
207+
Value rootB;
208+
// A default-constructed MontgomeryAttr is used as the "not set" value.
if (montAttr != mod_arith::MontgomeryAttr()) {
209+
rootB = b.create<field::MontMulOp>(B, root, montAttr);
210+
} else {
211+
rootB = b.create<field::MulOp>(B, root);
212+
}
207213
auto ctPlus = b.create<field::AddOp>(A, rootB);
208214
auto ctMinus = b.create<field::SubOp>(A, rootB);
209215
return {std::move(ctPlus), std::move(ctMinus)};
210216
}
211217

212218
// Butterfly : Gentleman-Sande
213219
// Emits the ops for one Gentleman-Sande butterfly on A and B with twiddle
// `root`. The sum A + B and difference A - B are built first, and the pair
// (A + B, (A - B) * root) is returned. After this change the product is created
// as field::MontMulOp when `montAttr` is set, and as field::MulOp otherwise.
static std::pair<Value, Value> bflyGS(ImplicitLocOpBuilder &b, Value A, Value B,
214-
Value root) {
220+
Value root,
221+
mod_arith::MontgomeryAttr montAttr) {
215222
auto gsPlus = b.create<field::AddOp>(A, B);
216223
auto gsMinus = b.create<field::SubOp>(A, B);
217-
auto gsMinusRoot = b.create<field::MulOp>(gsMinus, root);
224+
Value gsMinusRoot;
225+
// A default-constructed MontgomeryAttr is used as the "not set" value.
if (montAttr != mod_arith::MontgomeryAttr()) {
226+
gsMinusRoot = b.create<field::MontMulOp>(gsMinus, root, montAttr);
227+
} else {
228+
gsMinusRoot = b.create<field::MulOp>(gsMinus, root);
229+
}
218230
return {std::move(gsPlus), std::move(gsMinusRoot)};
219231
}
220232

@@ -429,8 +441,10 @@ static Value fastNTT(ImplicitLocOpBuilder &b, PrimitiveRootAttr rootAttr,
429441
// (bflyGS) variant, depending on whether we are performing
430442
// an inverse transform.
431443
// ---------------------------------------------------------
432-
auto bflyResult = kInverse ? bflyGS(b, A, B, root)
433-
: bflyCT(b, A, B, root);
444+
auto bflyResult =
445+
kInverse
446+
? bflyGS(b, A, B, root, rootAttr.getMontgomery())
447+
: bflyCT(b, A, B, root, rootAttr.getMontgomery());
434448

435449
// Write the results back into the coefficient array.
436450
// Insert the "plus" result into `indexA` and the "minus"

zkir/Dialect/Poly/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ cc_library(
2525
":ops_inc_gen",
2626
":types_inc_gen",
2727
"//zkir/Dialect/Field/IR:Field",
28+
"//zkir/Dialect/ModArith/IR:ModArith",
2829
"//zkir/Utils:APIntUtils",
2930
"@llvm-project//llvm:Support",
3031
"@llvm-project//mlir:ArithDialect",

zkir/Dialect/Poly/IR/PolyAttributes.cpp

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ namespace mlir::zkir::poly {
1212
// Compute the first degree powers of root modulo mod.
1313
static void precomputeRoots(APInt root, const APInt &mod, unsigned degree,
1414
SmallVector<APInt> &roots,
15-
SmallVector<APInt> &invRoots) {
15+
SmallVector<APInt> &invRoots,
16+
std::optional<IntegerAttr> montgomeryR) {
1617
unsigned kBitWidth = llvm::bit_width(degree);
1718

1819
// Precompute powers-of-two: `powerOfTwo[k]` = `root^(2^k)` mod `mod`.
@@ -22,11 +23,17 @@ static void precomputeRoots(APInt root, const APInt &mod, unsigned degree,
2223
powerOfTwo[k] = mulMod(powerOfTwo[k - 1], powerOfTwo[k - 1], mod);
2324
}
2425

26+
// Coset factor
27+
APInt coset(mod.getBitWidth(), 1);
28+
if (montgomeryR) {
29+
coset = montgomeryR->getValue();
30+
}
31+
2532
// Prepare the result vector.
2633
roots.resize(degree);
2734
invRoots.resize(degree);
28-
roots[0] = APInt(root.getBitWidth(), 1); // Identity element.
29-
invRoots[0] = APInt(root.getBitWidth(), 1); // Identity element.
35+
roots[0] = coset;
36+
invRoots[0] = coset;
3037

3138
llvm::StdThreadPool pool(llvm::hardware_concurrency());
3239

@@ -42,8 +49,9 @@ static void precomputeRoots(APInt root, const APInt &mod, unsigned degree,
4249
exp >>= 1;
4350
bit++;
4451
}
45-
roots[i] = result;
46-
invRoots[degree - i] = result;
52+
53+
roots[i] = mulMod(coset, result, mod);
54+
invRoots[degree - i] = mulMod(coset, result, mod);
4755
});
4856
}
4957

@@ -73,13 +81,18 @@ DenseElementsAttr PrimitiveRootAttr::getInvRoots() const {
7381
return getImpl()->invRoots;
7482
}
7583

84+
// Returns the Montgomery attribute held in this attribute's storage. The
// parameter is declared as an OptionalParameter in the TableGen definition, and
// callers compare the result against a default-constructed MontgomeryAttr to
// detect the case where no Montgomery form was requested.
zkir::mod_arith::MontgomeryAttr PrimitiveRootAttr::getMontgomery() const {
85+
return getImpl()->montgomery;
86+
}
87+
7688
namespace detail {
7789

7890
PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
7991
AttributeStorageAllocator &allocator, KeyTy &&key) {
8092
// Extract the root and degree from the key.
8193
zkir::field::PrimeFieldAttr root = std::get<0>(key);
8294
IntegerAttr degree = std::get<1>(key);
95+
zkir::mod_arith::MontgomeryAttr montgomery = std::get<2>(key);
8396

8497
APInt mod = root.getType().getModulus().getValue();
8598
APInt rootVal = root.getValue().getValue();
@@ -94,8 +107,11 @@ PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
94107

95108
// Compute the exponent table.
96109
SmallVector<APInt> roots, invRoots;
97-
precomputeRoots(rootVal, mod, degree.getInt(), roots, invRoots);
98-
110+
std::optional<IntegerAttr> montgomeryR;
111+
if (montgomery != zkir::mod_arith::MontgomeryAttr()) {
112+
montgomeryR = montgomery.getR();
113+
}
114+
precomputeRoots(rootVal, mod, degree.getInt(), roots, invRoots, montgomeryR);
99115
// Create a ranked tensor type for the exponents attribute.
100116
auto tensorType = RankedTensorType::get(
101117
{degree.getInt()}, root.getType().getModulus().getType());
@@ -106,7 +122,8 @@ PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
106122
return new (allocator.allocate<PrimitiveRootAttrStorage>())
107123
PrimitiveRootAttrStorage(std::move(degree), std::move(invDegree),
108124
std::move(root), std::move(invRoot),
109-
std::move(rootsAttr), std::move(invRootsAttr));
125+
std::move(rootsAttr), std::move(invRootsAttr),
126+
std::move(montgomery));
110127
}
111128

112129
} // namespace detail

zkir/Dialect/Poly/IR/PolyAttributes.h

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,31 +6,38 @@
66

77
#include "mlir/Dialect/Polynomial/IR/PolynomialAttributes.h"
88
#include "zkir/Dialect/Field/IR/FieldAttributes.h"
9+
#include "zkir/Dialect/ModArith/IR/ModArithAttributes.h"
910
#include "zkir/Dialect/Poly/IR/PolyDialect.h"
1011
#include "zkir/Dialect/Poly/IR/PolyTypes.h"
1112

1213
namespace mlir::zkir::poly::detail {
1314

1415
struct PrimitiveRootAttrStorage : public AttributeStorage {
15-
using KeyTy = std::tuple<zkir::field::PrimeFieldAttr, IntegerAttr>;
16+
using KeyTy = std::tuple<zkir::field::PrimeFieldAttr, IntegerAttr,
17+
mod_arith::MontgomeryAttr>;
1618
PrimitiveRootAttrStorage(IntegerAttr degree,
1719
zkir::field::PrimeFieldAttr invDegree,
1820
zkir::field::PrimeFieldAttr root,
1921
zkir::field::PrimeFieldAttr invRoot,
20-
DenseElementsAttr roots, DenseElementsAttr invRoots)
22+
DenseElementsAttr roots, DenseElementsAttr invRoots,
23+
zkir::mod_arith::MontgomeryAttr montgomery)
2124
: degree(std::move(degree)),
2225
invDegree(std::move(invDegree)),
2326
root(std::move(root)),
2427
invRoot(std::move(invRoot)),
2528
roots(std::move(roots)),
26-
invRoots(std::move(invRoots)) {}
29+
invRoots(std::move(invRoots)),
30+
montgomery(std::move(montgomery)) {}
2731

28-
KeyTy getAsKey() const { return KeyTy(root, degree); }
32+
KeyTy getAsKey() const { return KeyTy(root, degree, montgomery); }
2933

30-
bool operator==(const KeyTy &key) const { return key == KeyTy(root, degree); }
34+
bool operator==(const KeyTy &key) const {
35+
return key == KeyTy(root, degree, montgomery);
36+
}
3137

3238
static llvm::hash_code hashKey(const KeyTy &key) {
33-
return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
39+
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
40+
std::get<2>(key));
3441
}
3542

3643
static PrimitiveRootAttrStorage *construct(
@@ -42,6 +49,7 @@ struct PrimitiveRootAttrStorage : public AttributeStorage {
4249
zkir::field::PrimeFieldAttr invRoot;
4350
DenseElementsAttr roots;
4451
DenseElementsAttr invRoots;
52+
zkir::mod_arith::MontgomeryAttr montgomery;
4553
};
4654

4755
} // namespace mlir::zkir::poly::detail

zkir/Dialect/Poly/IR/PolyAttributes.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def Poly_PrimitiveRootAttr: Poly_Attr<"PrimitiveRoot", "primitive_root"> {
5959
DenseElementsAttr getRoots() const;
6060
DenseElementsAttr getInvRoots() const;
6161
}];
62-
let parameters = (ins Field_PrimeFieldAttr:$root, "IntegerAttr":$degree);
62+
let parameters = (ins Field_PrimeFieldAttr:$root, "IntegerAttr":$degree, OptionalParameter<"mod_arith::MontgomeryAttr">:$montgomery);
6363
let assemblyFormat = "`<` struct(params) `>`";
6464
let genStorageClass = 0;
6565
}

0 commit comments

Comments
 (0)