Skip to content

Commit 238a5db

Browse files
authored
Merge pull request #19 from a41-official/feat/add-root-of-unity
feat: add root of unity attribute
2 parents 4af58a4 + 6a20807 commit 238a5db

File tree

21 files changed

+131
-85
lines changed

21 files changed

+131
-85
lines changed

benchmark/field/mul_benchmark_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
namespace zkir {
66
namespace {
77

8-
using ::zkir::benchmark::Memref;
8+
using benchmark::Memref;
99

1010
struct i256 {
1111
uint64_t limbs[4]; // 4 x 64 = 256 bits

benchmark/ntt/ntt_benchmark.mlir

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,12 @@
33
!memref_ty = memref<1048576x!coeff_ty>
44

55
#root_elem = #field.pf_elem<17220337697351015657950521176323262483320249231368149235373741788599650842711:i256> : !coeff_ty
6-
#root = #poly.primitive_root<root=#root_elem, degree=1048576:i256>
6+
#root_of_unity = #field.root_of_unity<#root_elem, 1048576:i256>
7+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
78

89
!mod = !mod_arith.int<21888242871839275222246405745257275088548364400416034343698204186575808495617 : i256>
910
#mont = #mod_arith.montgomery<!mod>
10-
#root_mont = #poly.primitive_root<root=#root_elem, degree=1048576:i256, montgomery=#mont>
11+
#root_mont = #poly.primitive_root<root_of_unity=#root_of_unity, montgomery=#mont>
1112

1213
func.func @ntt(%arg0 : !memref_ty) attributes { llvm.emit_c_interface } {
1314
%t = bufferization.to_tensor %arg0 restrict writable : !memref_ty to !coefft_ty

benchmark/ntt/ntt_benchmark_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
namespace zkir {
1111
namespace {
1212

13-
using ::zkir::benchmark::Memref;
13+
using benchmark::Memref;
1414

1515
using i256 = benchmark::BigInt<4>;
1616

tests/Dialect/Field/prime_field_to_mod_arith.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
// RUN: zkir-opt -prime-field-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope
22
!PF1 = !field.pf<3:i32>
33
!PFv = tensor<4x!PF1>
4-
#elem = #field.pf_elem<31:i32> : !PF1
4+
#root_elem = #field.pf_elem<2:i32> : !PF1
5+
#root = #field.root_of_unity<#root_elem, 2>
56

67
!mod = !mod_arith.int<3 : i32>
78
#mont = #mod_arith.montgomery<!mod>

tests/Dialect/Poly/poly_canonicalization.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
!coeff_ty = !field.pf<7681:i32>
44
#elem = #field.pf_elem<3383:i32> : !coeff_ty
5-
#root = #poly.primitive_root<root=#elem, degree=4 :i32>
5+
#root_of_unity = #field.root_of_unity<#elem, 4:i32>
6+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
67
!poly_ty = !poly.polynomial<!coeff_ty, 3>
78
!tensor_ty = tensor<4x!coeff_ty>
89

tests/Dialect/Poly/poly_ntt_runner.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
!coeff_ty = !field.pf<7681:i32>
77
#elem = #field.pf_elem<3383:i32> : !coeff_ty
88
#inv_elem = #field.pf_elem<4298:i32> : !coeff_ty
9-
#root = #poly.primitive_root<root=#elem, degree=4 :i32>
9+
#root_of_unity = #field.root_of_unity<#elem, 4:i32>
10+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
1011
!poly_ty = !poly.polynomial<!coeff_ty, 3>
1112

1213
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

tests/Dialect/Poly/poly_syntax.mlir

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
!poly_ty2 = !poly.polynomial<!PF2, 32>
77
#uni_poly = #poly.univariate_polynomial<x**6 + 1> : !poly_ty2
88
#elem = #field.pf_elem<2:i32> : !PF1
9-
#root = #poly.primitive_root<root=#elem, degree=3>
9+
#root_of_unity = #field.root_of_unity<#elem, 3:i32>
10+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
1011

1112
// CHECK-LABEL: @test_poly_syntax
1213
func.func @test_poly_syntax() {

tests/Dialect/Poly/poly_to_field.mlir

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
!PF1 = !field.pf<7:i255>
44
!poly_ty1 = !poly.polynomial<!PF1, 3>
55
!poly_ty2 = !poly.polynomial<!PF1, 4>
6-
#elem = #field.pf_elem<2:i255> : !PF1
7-
#root = #poly.primitive_root<root=#elem, degree=4>
6+
#elem = #field.pf_elem<6:i255> : !PF1
7+
#root_of_unity = #field.root_of_unity<#elem, 2:i255>
8+
#root = #poly.primitive_root<root_of_unity=#root_of_unity>
89

910
// FIXME(batzor): without this line, the test will fail with the following error:
1011
// LLVM ERROR: can't create Attribute 'mlir::polynomial::IntPolynomialAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.
@@ -68,16 +69,16 @@ func.func @test_lower_from_tensor(%t : tensor<4x!PF1>) -> !poly_ty1 {
6869

6970
// CHECK-LABEL: @test_lower_ntt
7071
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[T]] {
71-
func.func @test_lower_ntt(%input : tensor<4x!PF1>) -> tensor<4x!PF1> {
72+
func.func @test_lower_ntt(%input : tensor<2x!PF1>) -> tensor<2x!PF1> {
7273
// CHECK-NOT: poly.ntt
73-
%res = poly.ntt %input {root=#root} : tensor<4x!PF1>
74-
return %res: tensor<4x!PF1>
74+
%res = poly.ntt %input {root=#root} : tensor<2x!PF1>
75+
return %res: tensor<2x!PF1>
7576
}
7677

7778
// CHECK-LABEL: @test_lower_intt
7879
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[P:.*]] {
79-
func.func @test_lower_intt(%input : tensor<4x!PF1>) -> tensor<4x!PF1> {
80+
func.func @test_lower_intt(%input : tensor<2x!PF1>) -> tensor<2x!PF1> {
8081
// CHECK-NOT: poly.intt
81-
%res = poly.intt %input {root=#root} : tensor<4x!PF1>
82-
return %res: tensor<4x!PF1>
82+
%res = poly.intt %input {root=#root} : tensor<2x!PF1>
83+
return %res: tensor<2x!PF1>
8384
}

zkir/Dialect/Arith/Conversions/ArithToModArith/ArithToModArith.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,10 @@ namespace mlir::zkir::arith {
2929
static mod_arith::ModArithType convertArithType(Type type) {
3030
auto modulusBitSize = static_cast<uint64_t>(type.getIntOrFloatBitWidth());
3131
auto modulus = (1L << (modulusBitSize - 1L));
32-
auto newType = mlir::IntegerType::get(type.getContext(), modulusBitSize + 1);
32+
auto newType = IntegerType::get(type.getContext(), modulusBitSize + 1);
3333

3434
return mod_arith::ModArithType::get(type.getContext(),
35-
mlir::IntegerAttr::get(newType, modulus));
35+
IntegerAttr::get(newType, modulus));
3636
}
3737

3838
static Type convertArithLikeType(ShapedType type) {
@@ -129,14 +129,14 @@ struct ConvertExtUI : public OpConversionPattern<mlir::arith::ExtUIOp> {
129129
}
130130
};
131131

132-
struct ConvertLoadOp : public OpConversionPattern<mlir::memref::LoadOp> {
132+
struct ConvertLoadOp : public OpConversionPattern<memref::LoadOp> {
133133
explicit ConvertLoadOp(MLIRContext *context)
134-
: OpConversionPattern<mlir::memref::LoadOp>(context) {}
134+
: OpConversionPattern<memref::LoadOp>(context) {}
135135

136136
using OpConversionPattern::OpConversionPattern;
137137

138138
LogicalResult matchAndRewrite(
139-
::mlir::memref::LoadOp op, OpAdaptor adaptor,
139+
memref::LoadOp op, OpAdaptor adaptor,
140140
ConversionPatternRewriter &rewriter) const override {
141141
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
142142

zkir/Dialect/Field/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cc_library(
2323
":ops_inc_gen",
2424
":types_inc_gen",
2525
"//zkir/Dialect/ModArith/IR:ModArith",
26+
"//zkir/Utils:APIntUtils",
2627
"//zkir/Utils:OpUtils",
2728
"@llvm-project//llvm:Support",
2829
"@llvm-project//mlir:ArithDialect",

0 commit comments

Comments
 (0)