Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 29 additions & 7 deletions lib/Conversion/PolynomialToStandard/PolynomialToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1236,15 +1236,27 @@ struct ConvertNTT : public OpConversionPattern<NTTOp> {
return failure();
}

if (!op.getRoot()) {
RingAttr ring = polyTy.getRing();
PrimitiveRootAttr root = ring.getPrimitiveRoot();

if (!root) {
op.emitError("missing root attribute");
return failure();
}

RingAttr ring = polyTy.getRing();
APInt rootDegree = root.getDegree().getValue();
auto polyModDegree =
ring.getPolynomialModulus().getPolynomial().getDegree();

// only negacyclic for now
if (rootDegree != 2 * polyModDegree) {
op.emitError("unsupported degree for primitive root");
return failure();
}

auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
auto nttResult = fastNTT<false>(
b, ring, op.getRoot().value(), inputType,
b, ring, root, inputType,
computeReverseBitOrder(b, inputType, adaptor.getInput()));

// Insert the ring encoding here to the input type
Expand Down Expand Up @@ -1276,21 +1288,31 @@ struct ConvertINTT : public OpConversionPattern<INTTOp> {
return failure();
}

if (!op.getRoot()) {
RingAttr ring = polyTy.getRing();
PrimitiveRootAttr root = ring.getPrimitiveRoot();

if (!root) {
op.emitError("missing root attribute");
return failure();
}

RingAttr ring = polyTy.getRing();
APInt rootDegree = root.getDegree().getValue();
auto polyModDegree =
ring.getPolynomialModulus().getPolynomial().getDegree();

// only negacyclic for now
if (rootDegree != 2 * polyModDegree) {
op.emitError("unsupported degree for primitive root");
return failure();
}

auto inputType = dyn_cast<RankedTensorType>(adaptor.getInput().getType());
// Remove the encoded ring from the input tensor type
auto resultType =
RankedTensorType::get(inputType.getShape(), inputType.getElementType());
auto input = b.create<tensor::CastOp>(resultType, adaptor.getInput());

auto nttResult =
fastNTT<true>(b, ring, op.getRoot().value(), resultType, input);
auto nttResult = fastNTT<true>(b, ring, root, resultType, input);

rewriter.replaceOp(op, computeReverseBitOrder(b, resultType, nttResult));

Expand Down
6 changes: 3 additions & 3 deletions lib/Dialect/Polynomial/Transforms/NTTRewrites.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,16 @@ def NTTRewritePolyMul : Pattern<
(Polynomial_MulOp:$mulOp $p1, $p2),
[
// Transform to NTT point-value representation
(Polynomial_NTTOp:$p1NTT $p1, (Nullptr),
(Polynomial_NTTOp:$p1NTT $p1,
(returnType (InputTensorType (GetRingAttr $p1)))),
(Polynomial_NTTOp:$p2NTT $p2, (Nullptr),
(Polynomial_NTTOp:$p2NTT $p2,
(returnType (InputTensorType (GetRingAttr $p2)))),

// Compute elementwise multiplication modulo cmod
(ModArith_MulOp:$mulNTT $p1NTT, $p2NTT, (GetRingModAttr $p1)),

// Compute inverse transform back to coefficient representation
(Polynomial_INTTOp:$res $mulNTT, (Nullptr))
(Polynomial_INTTOp:$res $mulNTT)
],
[
(HasDegreePowerOfTwo $p1)
Expand Down
4 changes: 2 additions & 2 deletions tests/polynomial/lower_intt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// https://doi.org/10.1109/ACCESS.2023.3294446

#cycl = #polynomial.int_polynomial<1 + x**4>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=1925:i32, degree=8:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)>
Expand Down Expand Up @@ -81,6 +81,6 @@

func.func @lower_intt() -> !poly_ty {
%ntt_coeffs = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32, #ring>
%ret = polynomial.intt %ntt_coeffs {root=#root} : tensor<4xi32, #ring> -> !poly_ty
%ret = polynomial.intt %ntt_coeffs : tensor<4xi32, #ring> -> !poly_ty
return %ret : !poly_ty
}
4 changes: 2 additions & 2 deletions tests/polynomial/lower_intt_runner.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

#cycl = #polynomial.int_polynomial<1 + x**4>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=1925:i32, degree=8:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

func.func @test_poly_ntt() {
%coeffs = arith.constant dense<[1467,2807,3471,7621]> : tensor<4xi32>
%ntt_coeffs = tensor.cast %coeffs : tensor<4xi32> to tensor<4xi32, #ring>
%0 = polynomial.intt %ntt_coeffs {root=#root} : tensor<4xi32, #ring> -> !poly_ty
%0 = polynomial.intt %ntt_coeffs : tensor<4xi32, #ring> -> !poly_ty

%1 = polynomial.to_tensor %0 : !poly_ty -> tensor<4xi32>
%2 = bufferization.to_memref %1 : memref<4xi32>
Expand Down
4 changes: 2 additions & 2 deletions tests/polynomial/lower_ntt.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
// https://doi.org/10.1109/ACCESS.2023.3294446

#cycl = #polynomial.int_polynomial<1 + x**4>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=1925:i32, degree=8:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)>
Expand Down Expand Up @@ -76,6 +76,6 @@
func.func @lower_ntt() -> tensor<4xi32, #ring> {
%coeffs = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
%poly = polynomial.from_tensor %coeffs : tensor<4xi32> -> !poly_ty
%ret = polynomial.ntt %poly {root=#root} : !poly_ty -> tensor<4xi32, #ring>
%ret = polynomial.ntt %poly : !poly_ty -> tensor<4xi32, #ring>
return %ret : tensor<4xi32, #ring>
}
4 changes: 2 additions & 2 deletions tests/polynomial/lower_ntt_runner.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,14 @@
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

#cycl = #polynomial.int_polynomial<1 + x**4>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=1925:i32, degree=8:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 7681 : i32, polynomialModulus=#cycl, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

func.func @test_poly_ntt() {
%coeffs = arith.constant dense<[1,2,3,4]> : tensor<4xi32>
%poly = polynomial.from_tensor %coeffs : tensor<4xi32> -> !poly_ty
%0 = polynomial.ntt %poly {root=#root} : !poly_ty -> tensor<4xi32, #ring>
%0 = polynomial.ntt %poly : !poly_ty -> tensor<4xi32, #ring>

%1 = tensor.cast %0 : tensor<4xi32, #ring> to tensor<4xi32>
%2 = bufferization.to_memref %1 : memref<4xi32>
Expand Down
3 changes: 2 additions & 1 deletion tests/polynomial/ntt_rewrites.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
// RUN: heir-opt --convert-polynomial-mul-to-ntt %s | FileCheck --check-prefix=EXT --check-prefix=CHECK %s

#ideal = #polynomial.int_polynomial<1 + x**4>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#ideal>
#root = #polynomial.primitive_root<value=2:i32, degree=8:i32>
#ring = #polynomial.ring<coefficientType=i32, coefficientModulus=17:i32, polynomialModulus=#ideal, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

// CHECK: func.func @rewrite_poly_mul(%[[poly0:.*]]: [[POLY_TY:.*]], %[[poly1:.*]]: [[POLY_TY]]) -> [[POLY_TY]] {
Expand Down
6 changes: 3 additions & 3 deletions tests/polynomial/runner/lower_ntt_perf_runner.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
func.func private @printMemrefI32(memref<*xi32>) attributes { llvm.emit_c_interface }

#cycl = #polynomial.int_polynomial<1 + x**65536>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#cycl, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

func.func @test_poly_ntt() {
Expand All @@ -17,14 +17,14 @@ func.func @test_poly_ntt() {
%insert_rand0 = tensor.insert_slice %rand_coeffs into %full[0] [256] [1] : tensor<256xi32> into tensor<65536xi32>
%insert_rand1 = tensor.insert_slice %rand_coeffs into %insert_rand0[65280] [256] [1] : tensor<256xi32> into tensor<65536xi32>
%poly = polynomial.from_tensor %insert_rand1 : tensor<65536xi32> -> !poly_ty
%0 = polynomial.ntt %poly {root=#root} : !poly_ty -> tensor<65536xi32, #ring>
%0 = polynomial.ntt %poly : !poly_ty -> tensor<65536xi32, #ring>

// Insert casts so that intt(ntt()) does not get folded away during polynomial
// canonicalization
%cast = tensor.cast %0 : tensor<65536xi32, #ring> to tensor<65536xi32>
%cast_back = tensor.cast %cast : tensor<65536xi32> to tensor<65536xi32, #ring>

%1 = polynomial.intt %cast_back {root=#root} : tensor<65536xi32, #ring> -> !poly_ty
%1 = polynomial.intt %cast_back : tensor<65536xi32, #ring> -> !poly_ty

%2 = polynomial.to_tensor %1 : !poly_ty -> tensor<65536xi32>
%4 = bufferization.to_memref %2 : memref<65536xi32>
Expand Down
6 changes: 3 additions & 3 deletions tests/polynomial/runner/ntt_benchmark.mlir
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#cycl = #polynomial.int_polynomial<1 + x**65536>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#cycl>
#root = #polynomial.primitive_root<value=283965:i32, degree=131072:i32>
#ring = #polynomial.ring<coefficientType = i32, coefficientModulus = 786433 : i32, polynomialModulus=#cycl, primitiveRoot=#root>
!poly_ty = !polynomial.polynomial<ring=#ring>

func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } {
Expand All @@ -14,11 +14,11 @@ func.func @input_generation() -> !poly_ty attributes { llvm.emit_c_interface } {
}

func.func @ntt(%arg0 : !poly_ty) -> tensor<65536xi32, #ring> attributes { llvm.emit_c_interface } {
%0 = polynomial.ntt %arg0 {root=#root} : !poly_ty -> tensor<65536xi32, #ring>
%0 = polynomial.ntt %arg0 : !poly_ty -> tensor<65536xi32, #ring>
return %0 : tensor<65536xi32, #ring>
}

func.func @intt(%arg0 : tensor<65536xi32, #ring>) -> !poly_ty attributes { llvm.emit_c_interface } {
%0 = polynomial.intt %arg0 {root=#root} : tensor<65536xi32, #ring> -> !poly_ty
%0 = polynomial.intt %arg0 : tensor<65536xi32, #ring> -> !poly_ty
return %0 :!poly_ty
}