Skip to content

Commit 6a20807

Browse files
committed
refac(poly): use RootOfUnityAttr
1 parent 43bd78c commit 6a20807

File tree

8 files changed

+48
-43
lines changed

8 files changed

+48
-43
lines changed

benchmark/ntt/ntt_benchmark.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
!memref_ty = memref<1048576x!coeff_ty>
44

55
#root_elem = #field.pf_elem<17220337697351015657950521176323262483320249231368149235373741788599650842711:i256> : !coeff_ty
6-
#root = #poly.primitive_root<root=#root_elem, degree=1048576:i256>
6+
#root_of_unity = #field.root_of_unity<#root_elem, 1048576:i256>
7+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
78

89
!mod = !mod_arith.int<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
910
#mont = #mod_arith.montgomery<!mod>
10-
#root_mont = #poly.primitive_root<root=#root_elem, degree=1048576:i256, montgomery=#mont>
11+
#root_mont = #poly.primitive_root<root_of_unity=#root_of_unity, montgomery=#mont>
1112

1213
func.func @ntt(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
1314
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty

tests/Dialect/Poly/poly_canonicalization.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
!coeff_ty = !field.pf<7681:i32>
44
#elem = #field.pf_elem<3383:i32> : !coeff_ty
5-
#root = #poly.primitive_root<root=#elem, degree=4 :i32>
5+
#root_of_unity = #field.root_of_unity<#elem, 4:i32>
6+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
67
!poly_ty = !poly.polynomial<!coeff_ty, 3>
78
!tensor_ty = tensor<4x!coeff_ty>
89

tests/Dialect/Poly/poly_ntt_runner.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
!coeff_ty = !field.pf<7681:i32>
77
#elem = #field.pf_elem<3383:i32> : !coeff_ty
88
#inv_elem = #field.pf_elem<4298:i32> : !coeff_ty
9-
#root = #poly.primitive_root<root=#elem, degree=4 :i32>
9+
#root_of_unity = #field.root_of_unity<#elem, 4:i32>
10+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
1011
!poly_ty = !poly.polynomial<!coeff_ty, 3>
1112

1213
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

tests/Dialect/Poly/poly_syntax.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
!poly_ty2 = !poly.polynomial<!PF2, 32>
77
#uni_poly = #poly.univariate_polynomial<x**6 + 1> : !poly_ty2
88
#elem = #field.pf_elem<2:i32> : !PF1
9-
#root = #poly.primitive_root<root=#elem, degree=3>
9+
#root_of_unity = #field.root_of_unity<#elem, 3:i32>
10+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
1011

1112
// CHECK-LABEL: @test_poly_syntax
1213
func.func @test_poly_syntax() {

tests/Dialect/Poly/poly_to_field.mlir

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
!PF1 = !field.pf<7:i255>
44
!poly_ty1 = !poly.polynomial<!PF1, 3>
55
!poly_ty2 = !poly.polynomial<!PF1, 4>
6-
#elem = #field.pf_elem<2:i255> : !PF1
7-
#root = #poly.primitive_root<root=#elem, degree=4>
6+
#elem = #field.pf_elem<6:i255> : !PF1
7+
#root_of_unity = #field.root_of_unity<#elem, 2:i255>
8+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
89

910
// FIXME(batzor): without this line, the test will fail with the following error:
1011
// LLVM ERROR: can't create Attribute 'mlir::polynomial::IntPolynomialAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.
@@ -68,16 +69,16 @@ func.func @test_lower_from_tensor(%t : tensor<4x!PF1>) -> !poly_ty1 {
6869

6970
// CHECK-LABEL: @test_lower_ntt
7071
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[T]] {
71-
func.func @test_lower_ntt(%input : tensor<4x!PF1>) -> tensor<4x!PF1> {
72+
func.func @test_lower_ntt(%input : tensor<2x!PF1>) -> tensor<2x!PF1> {
7273
// CHECK-NOT: poly.ntt
73-
%res = poly.ntt %input {root=#root} : tensor<4x!PF1>
74-
return %res: tensor<4x!PF1>
74+
%res = poly.ntt %input {root=#root} : tensor<2x!PF1>
75+
return %res: tensor<2x!PF1>
7576
}
7677

7778
// CHECK-LABEL: @test_lower_intt
7879
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[P:.*]] {
79-
func.func @test_lower_intt(%input : tensor<4x!PF1>) -> tensor<4x!PF1> {
80+
func.func @test_lower_intt(%input : tensor<2x!PF1>) -> tensor<2x!PF1> {
8081
// CHECK-NOT: poly.intt
81-
%res = poly.intt %input {root=#root} : tensor<4x!PF1>
82-
return %res: tensor<4x!PF1>
82+
%res = poly.intt %input {root=#root} : tensor<2x!PF1>
83+
return %res: tensor<2x!PF1>
8384
}

zkir/Dialect/Poly/IR/PolyAttributes.cpp

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,21 @@ static void precomputeRoots(APInt root, const APInt &mod, unsigned degree,
5959
pool.wait();
6060
}
6161

62+
field::RootOfUnityAttr PrimitiveRootAttr::getRootOfUnity() const {
63+
return getImpl()->rootOfUnity;
64+
}
65+
6266
field::PrimeFieldAttr PrimitiveRootAttr::getRoot() const {
63-
return getImpl()->root;
67+
return getImpl()->rootOfUnity.getRoot();
6468
}
6569

6670
field::PrimeFieldAttr PrimitiveRootAttr::getInvRoot() const {
6771
return getImpl()->invRoot;
6872
}
6973

70-
IntegerAttr PrimitiveRootAttr::getDegree() const { return getImpl()->degree; }
74+
IntegerAttr PrimitiveRootAttr::getDegree() const {
75+
return getImpl()->rootOfUnity.getDegree();
76+
}
7177

7278
field::PrimeFieldAttr PrimitiveRootAttr::getInvDegree() const {
7379
return getImpl()->invDegree;
@@ -90,9 +96,10 @@ namespace detail {
9096
PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
9197
AttributeStorageAllocator &allocator, KeyTy &&key) {
9298
// Extract the root and degree from the key.
93-
field::PrimeFieldAttr root = std::get<0>(key);
94-
IntegerAttr degree = std::get<1>(key);
95-
mod_arith::MontgomeryAttr montgomery = std::get<2>(key);
99+
field::RootOfUnityAttr rootOfUnity = std::get<0>(key);
100+
field::PrimeFieldAttr root = rootOfUnity.getRoot();
101+
IntegerAttr degree = rootOfUnity.getDegree();
102+
mod_arith::MontgomeryAttr montgomery = std::get<1>(key);
96103

97104
std::optional<IntegerAttr> montgomeryR;
98105
if (montgomery != mod_arith::MontgomeryAttr()) {
@@ -125,10 +132,9 @@ PrimitiveRootAttrStorage *PrimitiveRootAttrStorage::construct(
125132
DenseElementsAttr rootsAttr = DenseElementsAttr::get(tensorType, roots);
126133
DenseElementsAttr invRootsAttr = DenseElementsAttr::get(tensorType, invRoots);
127134
return new (allocator.allocate<PrimitiveRootAttrStorage>())
128-
PrimitiveRootAttrStorage(std::move(degree), std::move(invDegree),
129-
std::move(root), std::move(invRoot),
130-
std::move(rootsAttr), std::move(invRootsAttr),
131-
std::move(montgomery));
135+
PrimitiveRootAttrStorage(std::move(rootOfUnity), std::move(invDegree),
136+
std::move(invRoot), std::move(rootsAttr),
137+
std::move(invRootsAttr), std::move(montgomery));
132138
}
133139

134140
} // namespace detail

zkir/Dialect/Poly/IR/PolyAttributes.h

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -13,39 +13,34 @@
1313
namespace mlir::zkir::poly::detail {
1414

1515
struct PrimitiveRootAttrStorage : public AttributeStorage {
16-
using KeyTy = std::tuple<field::PrimeFieldAttr, IntegerAttr,
17-
mod_arith::MontgomeryAttr>;
18-
PrimitiveRootAttrStorage(IntegerAttr degree,
16+
using KeyTy = std::tuple<field::RootOfUnityAttr, mod_arith::MontgomeryAttr>;
17+
PrimitiveRootAttrStorage(field::RootOfUnityAttr rootOfUnity,
1918
field::PrimeFieldAttr invDegree,
20-
field::PrimeFieldAttr root,
2119
field::PrimeFieldAttr invRoot,
2220
DenseElementsAttr roots, DenseElementsAttr invRoots,
2321
mod_arith::MontgomeryAttr montgomery)
24-
: degree(std::move(degree)),
22+
: rootOfUnity(std::move(rootOfUnity)),
2523
invDegree(std::move(invDegree)),
26-
root(std::move(root)),
2724
invRoot(std::move(invRoot)),
2825
roots(std::move(roots)),
2926
invRoots(std::move(invRoots)),
3027
montgomery(std::move(montgomery)) {}
3128

32-
KeyTy getAsKey() const { return KeyTy(root, degree, montgomery); }
29+
KeyTy getAsKey() const { return KeyTy(rootOfUnity, montgomery); }
3330

3431
bool operator==(const KeyTy &key) const {
35-
return key == KeyTy(root, degree, montgomery);
32+
return key == KeyTy(rootOfUnity, montgomery);
3633
}
3734

3835
static llvm::hash_code hashKey(const KeyTy &key) {
39-
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
40-
std::get<2>(key));
36+
return llvm::hash_combine(std::get<0>(key), std::get<1>(key));
4137
}
4238

4339
static PrimitiveRootAttrStorage *construct(
4440
AttributeStorageAllocator &allocator, KeyTy &&key);
4541

46-
IntegerAttr degree;
42+
field::RootOfUnityAttr rootOfUnity;
4743
field::PrimeFieldAttr invDegree;
48-
field::PrimeFieldAttr root;
4944
field::PrimeFieldAttr invRoot;
5045
DenseElementsAttr roots;
5146
DenseElementsAttr invRoots;

zkir/Dialect/Poly/IR/PolyAttributes.td

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,29 +35,28 @@ def Poly_UnivariatePolyAttr : Poly_Attr<"UnivariatePoly", "univariate_polynomial
3535
}
3636

3737
def Poly_PrimitiveRootAttr: Poly_Attr<"PrimitiveRoot", "primitive_root"> {
38-
let summary = "an attribute containing an integer and its degree as a root of unity";
38+
let summary = "an attribute containing a primitive root of unity and its twiddles";
3939
let description = [{
40-
A primitive root attribute stores an integer root `value` and an integer
41-
`degree`, corresponding to a primitive root of unity of the given degree in
42-
an unspecified ring.
43-
44-
This is used as an attribute on `poly.ntt` and `poly.intt` ops
45-
to specify the root of unity used in lowering the transform.
40+
A primitive root attribute stores a primitive root of unity specified by the `RootOfUnityAttr`
41+
attribute and also some other necessary precomputed values for (I)NTT operations.
4642

4743
Example:
4844
```
4945
!PF1 = !field.pf<7:i32>
5046
#elem = #field.elem<2> : !PF1
51-
#root = #poly.primitive_root<root=%generator, degree=3>
47+
#root_of_unity = #field.root_of_unity<root=#elem, degree=3>
48+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
5249
```
5350
}];
5451
let extraClassDeclaration = [{
52+
IntegerAttr getDegree() const;
53+
field::PrimeFieldAttr getRoot() const;
5554
field::PrimeFieldAttr getInvRoot() const;
5655
field::PrimeFieldAttr getInvDegree() const;
5756
DenseElementsAttr getRoots() const;
5857
DenseElementsAttr getInvRoots() const;
5958
}];
60-
let parameters = (ins Field_PrimeFieldAttr:$root, "IntegerAttr":$degree, OptionalParameter<"mod_arith::MontgomeryAttr">:$montgomery);
59+
let parameters = (ins Field_RootOfUnityAttr:$root_of_unity, OptionalParameter<"mod_arith::MontgomeryAttr">:$montgomery);
6160
let assemblyFormat = "`<` struct(params) `>`";
6261
let genStorageClass = 0;
6362
}

0 commit comments

Comments
 (0)