From fe580e4f4a32ba44e032709c2533b727e603fa95 Mon Sep 17 00:00:00 2001 From: Zenithal Date: Fri, 11 Oct 2024 04:10:12 +0000 Subject: [PATCH] Fix ntt/intt tests with upstream polynomial change --- .../PolynomialToStandard.cpp | 36 +++++++++++++++---- .../Polynomial/Transforms/NTTRewrites.td | 6 ++-- tests/polynomial/lower_intt.mlir | 4 +-- tests/polynomial/lower_intt_runner.mlir | 4 +-- tests/polynomial/lower_ntt.mlir | 4 +-- tests/polynomial/lower_ntt_runner.mlir | 4 +-- tests/polynomial/ntt_rewrites.mlir | 3 +- .../runner/lower_ntt_perf_runner.mlir | 6 ++-- tests/polynomial/runner/ntt_benchmark.mlir | 6 ++-- 9 files changed, 48 insertions(+), 25 deletions(-) diff --git a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp index 68aadeeaa1..fbeff48ed0 100644 --- a/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp +++ b/lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp @@ -1236,15 +1236,27 @@ struct ConvertNTT : public OpConversionPattern { return failure(); } - if (!op.getRoot()) { + RingAttr ring = polyTy.getRing(); + PrimitiveRootAttr root = ring.getPrimitiveRoot(); + + if (!root) { op.emitError("missing root attribute"); return failure(); } - RingAttr ring = polyTy.getRing(); + APInt rootDegree = root.getDegree().getValue(); + auto polyModDegree = + ring.getPolynomialModulus().getPolynomial().getDegree(); + + // only negacyclic for now + if (rootDegree != 2 * polyModDegree) { + op.emitError("unsupported degree for primitive root"); + return failure(); + } + auto inputType = dyn_cast(adaptor.getInput().getType()); auto nttResult = fastNTT( - b, ring, op.getRoot().value(), inputType, + b, ring, root, inputType, computeReverseBitOrder(b, inputType, adaptor.getInput())); // Insert the ring encoding here to the input type @@ -1276,12 +1288,23 @@ struct ConvertINTT : public OpConversionPattern { return failure(); } - if (!op.getRoot()) { + RingAttr ring = polyTy.getRing(); + PrimitiveRootAttr root = ring.getPrimitiveRoot(); + + if (!root) { op.emitError("missing root attribute"); return failure(); } - RingAttr ring = polyTy.getRing(); + APInt rootDegree = root.getDegree().getValue(); + auto polyModDegree = + ring.getPolynomialModulus().getPolynomial().getDegree(); + + // only negacyclic for now + if (rootDegree != 2 * polyModDegree) { + op.emitError("unsupported degree for primitive root"); + return failure(); + } auto inputType = dyn_cast(adaptor.getInput().getType()); // Remove the encoded ring from the input tensor type @@ -1289,8 +1312,7 @@ struct ConvertINTT : public OpConversionPattern { RankedTensorType::get(inputType.getShape(), inputType.getElementType()); auto input = b.create(resultType, adaptor.getInput()); - auto nttResult = - fastNTT(b, ring, op.getRoot().value(), resultType, input); + auto nttResult = fastNTT(b, ring, root, resultType, input); rewriter.replaceOp(op, computeReverseBitOrder(b, resultType, nttResult)); diff --git a/lib/Dialect/Polynomial/Transforms/NTTRewrites.td b/lib/Dialect/Polynomial/Transforms/NTTRewrites.td index be915fe3fd..9a4de826ad 100644 --- a/lib/Dialect/Polynomial/Transforms/NTTRewrites.td +++ b/lib/Dialect/Polynomial/Transforms/NTTRewrites.td @@ -36,16 +36,16 @@ def NTTRewritePolyMul : Pattern< (Polynomial_MulOp:$mulOp $p1, $p2), [ // Transform to NTT point-value representation - (Polynomial_NTTOp:$p1NTT $p1, (Nullptr), + (Polynomial_NTTOp:$p1NTT $p1, (returnType (InputTensorType (GetRingAttr $p1)))), - (Polynomial_NTTOp:$p2NTT $p2, (Nullptr), + (Polynomial_NTTOp:$p2NTT $p2, (returnType (InputTensorType (GetRingAttr $p2)))), // Compute elementwise multiplication modulo cmod (ModArith_MulOp:$mulNTT $p1NTT, $p2NTT, (GetRingModAttr $p1)), // Compute inverse transform back to coefficient representation - (Polynomial_INTTOp:$res $mulNTT, (Nullptr)) + (Polynomial_INTTOp:$res $mulNTT) ], [ (HasDegreePowerOfTwo $p1) diff --git a/tests/polynomial/lower_intt.mlir b/tests/polynomial/lower_intt.mlir index 436fe8e642..f07da43db4 100644 --- a/tests/polynomial/lower_intt.mlir +++ b/tests/polynomial/lower_intt.mlir @@ -4,8 +4,8 @@ // https://doi.org/10.1109/ACCESS.2023.3294446 #cycl = #polynomial.int_polynomial<1 + x**4> -#ring = #polynomial.ring #root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)> @@ -81,6 +81,6 @@ func.func @lower_intt() -> !poly_ty { %ntt_coeffs = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32, #ring> - %ret = polynomial.intt %ntt_coeffs {root=#root} : tensor<4xi32, #ring> -> !poly_ty + %ret = polynomial.intt %ntt_coeffs : tensor<4xi32, #ring> -> !poly_ty return %ret : !poly_ty } diff --git a/tests/polynomial/lower_intt_runner.mlir b/tests/polynomial/lower_intt_runner.mlir index 235f037369..9e53cf0d89 100644 --- a/tests/polynomial/lower_intt_runner.mlir +++ b/tests/polynomial/lower_intt_runner.mlir @@ -9,14 +9,14 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } #cycl = #polynomial.int_polynomial<1 + x**4> -#ring = #polynomial.ring #root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial func.func @test_poly_ntt() { %coeffs = arith.constant dense<[1467,2807,3471,7621]> : tensor<4xi32> %ntt_coeffs = tensor.cast %coeffs : tensor<4xi32> to tensor<4xi32, #ring> - %0 = polynomial.intt %ntt_coeffs {root=#root} : tensor<4xi32, #ring> -> !poly_ty + %0 = polynomial.intt %ntt_coeffs : tensor<4xi32, #ring> -> !poly_ty %1 = polynomial.to_tensor %0 : !poly_ty -> tensor<4xi32> %2 = bufferization.to_memref %1 : memref<4xi32> diff --git a/tests/polynomial/lower_ntt.mlir b/tests/polynomial/lower_ntt.mlir index 7ab11d338f..a33663b3a0 100644 --- a/tests/polynomial/lower_ntt.mlir +++ b/tests/polynomial/lower_ntt.mlir @@ -4,8 +4,8 @@ // https://doi.org/10.1109/ACCESS.2023.3294446 #cycl = #polynomial.int_polynomial<1 + x**4> -#ring = #polynomial.ring #root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)> @@ -76,6 +76,6 @@ func.func @lower_ntt() -> tensor<4xi32, #ring> { %coeffs = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32> %poly = polynomial.from_tensor %coeffs : tensor<4xi32> -> !poly_ty - %ret = polynomial.ntt %poly {root=#root} : !poly_ty -> tensor<4xi32, #ring> + %ret = polynomial.ntt %poly : !poly_ty -> tensor<4xi32, #ring> return %ret : tensor<4xi32, #ring> } diff --git a/tests/polynomial/lower_ntt_runner.mlir b/tests/polynomial/lower_ntt_runner.mlir index 0206d11c5c..755bc7b3c5 100644 --- a/tests/polynomial/lower_ntt_runner.mlir +++ b/tests/polynomial/lower_ntt_runner.mlir @@ -9,14 +9,14 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } #cycl = #polynomial.int_polynomial<1 + x**4> -#ring = #polynomial.ring #root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial func.func @test_poly_ntt() { %coeffs = arith.constant dense<[1,2,3,4]> : tensor<4xi32> %poly = polynomial.from_tensor %coeffs : tensor<4xi32> -> !poly_ty - %0 = polynomial.ntt %poly {root=#root} : !poly_ty -> tensor<4xi32, #ring> + %0 = polynomial.ntt %poly : !poly_ty -> tensor<4xi32, #ring> %1 = tensor.cast %0 : tensor<4xi32, #ring> to tensor<4xi32> %2 = bufferization.to_memref %1 : memref<4xi32> diff --git a/tests/polynomial/ntt_rewrites.mlir b/tests/polynomial/ntt_rewrites.mlir index bfbe549639..f06f0f6115 100644 --- a/tests/polynomial/ntt_rewrites.mlir +++ b/tests/polynomial/ntt_rewrites.mlir @@ -2,7 +2,8 @@ // RUN: heir-opt --convert-polynomial-mul-to-ntt %s | FileCheck --check-prefix=EXT --check-prefix=CHECK %s #ideal = #polynomial.int_polynomial<1 + x**4> -#ring = #polynomial.ring +#root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial // CHECK: func.func @rewrite_poly_mul(%[[poly0:.*]]: [[POLY_TY:.*]], %[[poly1:.*]]: [[POLY_TY]]) -> [[POLY_TY]] { diff --git a/tests/polynomial/runner/lower_ntt_perf_runner.mlir b/tests/polynomial/runner/lower_ntt_perf_runner.mlir index d740fa8964..802a6ca01c 100644 --- a/tests/polynomial/runner/lower_ntt_perf_runner.mlir +++ b/tests/polynomial/runner/lower_ntt_perf_runner.mlir @@ -6,8 +6,8 @@ func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface } #cycl = #polynomial.int_polynomial<1 + x**65536> -#ring = #polynomial.ring #root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial func.func @test_poly_ntt() { @@ -17,14 +17,14 @@ func.func @test_poly_ntt() { %insert_rand0 = tensor.insert_slice %rand_coeffs into %full[0] [256] [1] : tensor<256xi32> into tensor<65536xi32> %insert_rand1 = tensor.insert_slice %rand_coeffs into %insert_rand0[65280] [256] [1] : tensor<256xi32> into tensor<65536xi32> %poly = polynomial.from_tensor %insert_rand1 : tensor<65536xi32> -> !poly_ty - %0 = polynomial.ntt %poly {root=#root} : !poly_ty -> tensor<65536xi32, #ring> + %0 = polynomial.ntt %poly : !poly_ty -> tensor<65536xi32, #ring> // Insert casts so that intt(ntt()) does not get folded away during polynomial // canonicalization %cast = tensor.cast %0 : tensor<65536xi32, #ring> to tensor<65536xi32> %cast_back = tensor.cast %cast : tensor<65536xi32> to tensor<65536xi32, #ring> - %1 = polynomial.intt %cast_back {root=#root} : tensor<65536xi32, #ring> -> !poly_ty + %1 = polynomial.intt %cast_back : tensor<65536xi32, #ring> -> !poly_ty %2 = polynomial.to_tensor %1 : !poly_ty -> tensor<65536xi32> %4 = bufferization.to_memref %2 : memref<65536xi32> diff --git a/tests/polynomial/runner/ntt_benchmark.mlir b/tests/polynomial/runner/ntt_benchmark.mlir index 7271aa4e73..cd67cf7dde 100644 --- a/tests/polynomial/runner/ntt_benchmark.mlir +++ b/tests/polynomial/runner/ntt_benchmark.mlir @@ -1,6 +1,6 @@ #cycl = #polynomial.int_polynomial<1 + x**65536> -#ring = #polynomial.ring #root = #polynomial.primitive_root +#ring = #polynomial.ring !poly_ty = !polynomial.polynomial func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } { @@ -14,11 +14,11 @@ func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } { } func.func @ntt(%arg0 : !poly_ty) -> tensor<65536xi32, #ring> attributes { llvm.emit_c_interface } { - %0 = polynomial.ntt %arg0 {root=#root} : !poly_ty -> tensor<65536xi32, #ring> + %0 = polynomial.ntt %arg0 : !poly_ty -> tensor<65536xi32, #ring> return %0 : tensor<65536xi32, #ring> } func.func @intt(%arg0 : tensor<65536xi32, #ring>) -> !poly_ty attributes { llvm.emit_c_interface } { - %0 = polynomial.intt %arg0 {root=#root} : tensor<65536xi32, #ring> -> !poly_ty + %0 = polynomial.intt %arg0 : tensor<65536xi32, #ring> -> !poly_ty return %0 :!poly_ty }