Skip to content

Commit c606314

Browse files
authored
Merge pull request #13 from a41-official/feat/poly-ntt
feat: introduce poly NTT and INTT
2 parents 182fcc2 + abac84a commit c606314

26 files changed

+1014
-12
lines changed

tests/poly_canonicalization.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// RUN: zkir-opt %s -canonicalize | FileCheck %s
2+
3+
!coeff_ty = !field.pf<7681:i32>
4+
#elem = #field.pf_elem<3383:i32> : !coeff_ty
5+
#root = #poly.primitive_root<root=#elem, degree=4 :i32>
6+
!poly_ty = !poly.polynomial<!coeff_ty, 3>
7+
!tensor_ty = tensor<4x!coeff_ty>
8+
9+
// CHECK-LABEL: @test_canonicalize_intt_after_ntt
10+
// CHECK: (%[[P:.*]]: [[T:.*]]) -> [[T]]
11+
func.func @test_canonicalize_intt_after_ntt(%p0 : !poly_ty) -> !poly_ty {
12+
// CHECK-NOT: poly.ntt
13+
// CHECK-NOT: poly.intt
14+
// CHECK: %[[RESULT:.*]] = poly.add %[[P]], %[[P]] : [[T]]
15+
%t0 = poly.ntt %p0 {root=#root} : !poly_ty -> !tensor_ty
16+
%p1 = poly.intt %t0 {root=#root} : !tensor_ty -> !poly_ty
17+
%p2 = poly.add %p1, %p1 : !poly_ty
18+
// CHECK: return %[[RESULT]] : [[T]]
19+
return %p2 : !poly_ty
20+
}
21+
22+
// CHECK-LABEL: @test_canonicalize_ntt_after_intt
23+
// CHECK: (%[[X:.*]]: [[T:.*]]) -> [[T]]
24+
func.func @test_canonicalize_ntt_after_intt(%t0 : !tensor_ty) -> !tensor_ty {
25+
// CHECK-NOT: poly.intt
26+
// CHECK-NOT: poly.ntt
27+
// CHECK: %[[RESULT:.*]] = field.pf.add %[[X]], %[[X]] : [[T]]
28+
%p0 = poly.intt %t0 {root=#root} : !tensor_ty -> !poly_ty
29+
%t1 = poly.ntt %p0 {root=#root} : !poly_ty -> !tensor_ty
30+
%t2 = field.pf.add %t1, %t1 : !tensor_ty
31+
// CHECK: return %[[RESULT]] : [[T]]
32+
return %t2 : !tensor_ty
33+
}

tests/poly_ntt_runner.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// RUN: zkir-opt %s -poly-to-llvm \
2+
// RUN: | mlir-runner -e test_poly_ntt -entry-point-result=void \
3+
// RUN: --shared-libs="%mlir_lib_dir/libmlir_runner_utils%shlibext" > %t
4+
// RUN: FileCheck %s --check-prefix=CHECK_TEST_POLY_NTT < %t
5+
6+
!coeff_ty = !field.pf<7681:i32>
7+
#elem = #field.pf_elem<3383:i32> : !coeff_ty
8+
#inv_elem = #field.pf_elem<4298:i32> : !coeff_ty
9+
#root = #poly.primitive_root<root=#elem, degree=4 :i32>
10+
!poly_ty = !poly.polynomial<!coeff_ty, 3>
11+
12+
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }
13+
14+
func.func @test_poly_ntt() {
15+
%coeffsRaw = arith.constant dense<[1,2,3,4]> : tensor<4xi32>
16+
%coeffs = field.pf.encapsulate %coeffsRaw : tensor<4xi32> -> tensor<4x!coeff_ty>
17+
%poly = poly.from_tensor %coeffs : tensor<4x!coeff_ty> -> !poly_ty
18+
%res = poly.ntt %poly {root=#root} : !poly_ty -> tensor<4x!coeff_ty>
19+
20+
%extract = field.pf.extract %res : tensor<4x!coeff_ty> -> tensor<4xi32>
21+
%1 = bufferization.to_memref %extract : tensor<4xi32> to memref<4xi32>
22+
%U = memref.cast %1 : memref<4xi32> to memref<*xi32>
23+
func.call @printMemrefI32(%U) : (memref<*xi32>) -> ()
24+
25+
%intt = poly.intt %res {root=#root} : tensor<4x!coeff_ty> -> !poly_ty
26+
%res2 = poly.to_tensor %intt : !poly_ty -> tensor<4x!coeff_ty>
27+
%extract2 = field.pf.extract %res2 : tensor<4x!coeff_ty> -> tensor<4xi32>
28+
%2= bufferization.to_memref %extract2 : tensor<4xi32> to memref<4xi32>
29+
%U2 = memref.cast %2 : memref<4xi32> to memref<*xi32>
30+
func.call @printMemrefI32(%U2) : (memref<*xi32>) -> ()
31+
return
32+
}
33+
// CHECK_TEST_POLY_NTT: [10, 913, 7679, 6764]
34+
// CHECK_TEST_POLY_NTT: [1, 2, 3, 4]

tests/poly_to_field.mlir

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
!PF1 = !field.pf<7:i255>
44
!poly_ty1 = !poly.polynomial<!PF1, 3>
55
!poly_ty2 = !poly.polynomial<!PF1, 4>
6+
#elem = #field.pf_elem<2:i255> : !PF1
7+
#root = #poly.primitive_root<root=#elem, degree=4>
68

79
// FIXME(batzor): without this line, the test will fail with the following error:
810
// LLVM ERROR: can't create Attribute 'mlir::polynomial::IntPolynomialAttr' because storage uniquer isn't initialized: the dialect was likely not loaded, or the attribute wasn't added with addAttributes<...>() in the Dialect::initialize() method.
@@ -63,3 +65,19 @@ func.func @test_lower_from_tensor(%t : tensor<4x!PF1>) -> !poly_ty1 {
6365
%res = poly.from_tensor %t : tensor<4x!PF1> -> !poly_ty1
6466
return %res : !poly_ty1
6567
}
68+
69+
// CHECK-LABEL: @test_lower_ntt
70+
// CHECK-SAME: (%[[INPUT:.*]]: [[P:.*]]) -> [[T:.*]] {
71+
func.func @test_lower_ntt(%input : !poly_ty1) -> tensor<4x!PF1> {
72+
// CHECK-NOT: poly.ntt
73+
%res = poly.ntt %input {root=#root} : !poly_ty1 -> tensor<4x!PF1>
74+
return %res: tensor<4x!PF1>
75+
}
76+
77+
// CHECK-LABEL: @test_lower_intt
78+
// CHECK-SAME: (%[[INPUT:.*]]: [[T:.*]]) -> [[P:.*]] {
79+
func.func @test_lower_intt(%input : tensor<4x!PF1>) -> !poly_ty1 {
80+
// CHECK-NOT: poly.intt
81+
%res = poly.intt %input {root=#root} : tensor<4x!PF1> -> !poly_ty1
82+
return %res: !poly_ty1
83+
}

tests/polynomial_syntax.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
!poly_ty1 = !poly.polynomial<!PF1, 32>
66
!poly_ty2 = !poly.polynomial<!PF2, 32>
77
#uni_poly = #poly.univariate_polynomial<x**6 + 1> : !poly_ty2
8-
#elem = #field.pf_elem<2> : !PF1
8+
#elem = #field.pf_elem<2:i32> : !PF1
99
#root = #poly.primitive_root<root=#elem, degree=3>
1010

1111
// CHECK-LABEL: @test_poly_syntax

tests/prime_field_to_mod_arith.mlir

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
// RUN: zkir-opt -prime-field-to-mod-arith --split-input-file %s | FileCheck %s --enable-var-scope
22
!PF1 = !field.pf<3:i32>
33
!PFv = tensor<4x!PF1>
4-
#elem = #field.pf_elem<31> : !PF1
4+
#elem = #field.pf_elem<31:i32> : !PF1
55

66
// CHECK-LABEL: @test_lower_constant
77
// CHECK-SAME: () -> [[T:.*]] {
@@ -32,6 +32,26 @@ func.func @test_lower_encapsulate_vec(%lhs : tensor<4xi32>) -> tensor<4x!PF1> {
3232
return %res : tensor<4x!PF1>
3333
}
3434

35+
// CHECK-LABEL: @test_lower_extract
36+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[F:.*]] {
37+
func.func @test_lower_extract(%lhs : !PF1) -> i32 {
38+
// CHECK-NOT: field.pf.extract
39+
// CHECK: %[[RES:.*]] = mod_arith.extract %[[LHS]] : [[T]] -> [[F]]
40+
%res = field.pf.extract %lhs : !PF1 -> i32
41+
// CHECK: return %[[RES]] : [[F]]
42+
return %res : i32
43+
}
44+
45+
// CHECK-LABEL: @test_lower_extract_vec
46+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[TF:.*]] {
47+
func.func @test_lower_extract_vec(%lhs : tensor<4x!PF1>) -> tensor<4xi32> {
48+
// CHECK-NOT: field.pf.extract
49+
// CHECK: %[[RES:.*]] = mod_arith.extract %[[LHS]] : [[T]] -> [[TF]]
50+
%res = field.pf.extract %lhs : tensor<4x!PF1> -> tensor<4xi32>
51+
// CHECK: return %[[RES]] : [[TF]]
52+
return %res : tensor<4xi32>
53+
}
54+
3555
// CHECK-LABEL: @test_lower_inverse
3656
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
3757
func.func @test_lower_inverse(%lhs : !PF1) -> !PF1 {

tools/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ cc_binary(
2828
"//zkir/Dialect/ModArith/IR:ModArith",
2929
"//zkir/Dialect/Poly/Conversions/PolyToField",
3030
"//zkir/Dialect/Poly/IR:Poly",
31+
"//zkir/Pipelines:PipelineRegistration",
3132
"@llvm-project//llvm:Support",
3233
"@llvm-project//mlir:AllExtensions",
3334
"@llvm-project//mlir:AllPassesAndDialects",

tools/zkir-opt.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#include "zkir/Dialect/ModArith/IR/ModArithDialect.h"
1111
#include "zkir/Dialect/Poly/Conversions/PolyToField/PolyToField.h"
1212
#include "zkir/Dialect/Poly/IR/PolyDialect.h"
13+
#include "zkir/Pipelines/PipelineRegistration.h"
1314

1415
int main(int argc, char **argv) {
1516
mlir::DialectRegistry registry;
@@ -27,5 +28,9 @@ int main(int argc, char **argv) {
2728
mlir::zkir::field::registerFieldToModArithPasses();
2829
mlir::zkir::poly::registerPolyToFieldPasses();
2930

31+
mlir::PassPipelineRegistration<>(
32+
"poly-to-llvm", "Run passes to lower the polynomial dialect to LLVM",
33+
mlir::zkir::pipelines::polyToLLVMPipelineBuilder);
34+
3035
return failed(mlir::MlirOptMain(argc, argv, "ZKIR optimizer\n", registry));
3136
}

zkir/Dialect/Field/Conversions/FieldToModArith/FieldToModArith.cpp

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,25 @@ struct ConvertEncapsulate : public OpConversionPattern<EncapsulateOp> {
9696
}
9797
};
9898

99+
struct ConvertExtract : public OpConversionPattern<ExtractOp> {
100+
explicit ConvertExtract(mlir::MLIRContext *context)
101+
: OpConversionPattern<ExtractOp>(context) {}
102+
103+
using OpConversionPattern::OpConversionPattern;
104+
105+
LogicalResult matchAndRewrite(
106+
ExtractOp op, OpAdaptor adaptor,
107+
ConversionPatternRewriter &rewriter) const override {
108+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
109+
110+
auto resultType = typeConverter->convertType(op.getResult().getType());
111+
auto extracted =
112+
b.create<mod_arith::ExtractOp>(resultType, adaptor.getOperands()[0]);
113+
rewriter.replaceOp(op, extracted);
114+
return success();
115+
}
116+
};
117+
99118
struct ConvertInverse : public OpConversionPattern<InverseOp> {
100119
explicit ConvertInverse(mlir::MLIRContext *context)
101120
: OpConversionPattern<InverseOp>(context) {}
@@ -187,13 +206,20 @@ void PrimeFieldToModArith::runOnOperation() {
187206

188207
RewritePatternSet patterns(context);
189208
rewrites::populateWithGenerated(patterns);
190-
patterns.add<ConvertConstant, ConvertEncapsulate, ConvertInverse, ConvertAdd,
191-
ConvertSub, ConvertMul, ConvertAny<tensor::FromElementsOp>>(
192-
typeConverter, context);
209+
patterns
210+
.add<ConvertConstant, ConvertEncapsulate, ConvertExtract, ConvertInverse,
211+
ConvertAdd, ConvertSub, ConvertMul, ConvertAny<affine::AffineForOp>,
212+
ConvertAny<affine::AffineYieldOp>, ConvertAny<linalg::GenericOp>,
213+
ConvertAny<linalg::YieldOp>, ConvertAny<tensor::CastOp>,
214+
ConvertAny<tensor::ExtractOp>, ConvertAny<tensor::FromElementsOp>,
215+
ConvertAny<tensor::InsertOp>>(typeConverter, context);
193216

194217
addStructuralConversionPatterns(typeConverter, patterns, target);
195218

196-
target.addDynamicallyLegalOp<tensor::FromElementsOp>(
219+
target.addDynamicallyLegalOp<affine::AffineForOp, affine::AffineYieldOp,
220+
linalg::GenericOp, linalg::YieldOp,
221+
tensor::CastOp, tensor::ExtractOp,
222+
tensor::FromElementsOp, tensor::InsertOp>(
197223
[&](auto op) { return typeConverter.isLegal(op); });
198224

199225
if (failed(applyPartialConversion(module, target, std::move(patterns)))) {

zkir/Dialect/Field/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package(
88
cc_library(
99
name = "Field",
1010
srcs = [
11+
"FieldAttributes.cpp",
1112
"FieldDialect.cpp",
1213
],
1314
hdrs = [
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#include "zkir/Dialect/Field/IR/FieldAttributes.h"
2+
3+
#include "mlir/IR/BuiltinAttributes.h"
4+
#include "mlir/IR/BuiltinTypes.h"
5+
6+
namespace mlir::zkir::field {
7+
8+
LogicalResult PrimeFieldAttr::verify(
9+
::llvm::function_ref<::mlir::InFlightDiagnostic()> emitError,
10+
PrimeFieldType type, IntegerAttr value) {
11+
if (type.getModulus().getValue().getBitWidth() !=
12+
value.getValue().getBitWidth()) {
13+
emitError()
14+
<< "prime field modulus bitwidth does not match the value bitwidth";
15+
return failure();
16+
}
17+
return success();
18+
}
19+
20+
} // namespace mlir::zkir::field

0 commit comments

Comments
 (0)