Skip to content

Commit 7783eb9

Browse files
committed
feat(mod_arith): impl NegateOp
1 parent c4e1f90 commit 7783eb9

File tree

4 files changed

+71
-9
lines changed

4 files changed

+71
-9
lines changed

tests/Dialect/ModArith/mod_arith_to_arith.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,28 @@ func.func @test_lower_constant() -> !mod_arith.int<3 : i5> {
1313
return %res: !mod_arith.int<3 : i5>
1414
}
1515

16+
// CHECK-LABEL: @test_lower_negate
17+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
18+
func.func @test_lower_negate(%lhs : !Zp) -> !Zp {
19+
// CHECK-NOT: mod_arith.negate
20+
// CHECK: %[[CMOD:.*]] = arith.constant 65537 : [[T]]
21+
// CHECK: %[[SUB:.*]] = arith.subi %[[CMOD]], %[[LHS]] : [[T]]
22+
// CHECK: return %[[SUB]] : [[T]]
23+
%res = mod_arith.negate %lhs: !Zp
24+
return %res : !Zp
25+
}
26+
27+
// CHECK-LABEL: @test_lower_negate_vec
28+
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
29+
func.func @test_lower_negate_vec(%lhs : !Zpv) -> !Zpv {
30+
// CHECK-NOT: mod_arith.negate
31+
// CHECK: %[[CMOD:.*]] = arith.constant dense<65537> : [[T]]
32+
// CHECK: %[[SUB:.*]] = arith.subi %[[CMOD]], %[[LHS]] : [[T]]
33+
// CHECK: return %[[SUB]] : [[T]]
34+
%res = mod_arith.negate %lhs: !Zpv
35+
return %res : !Zpv
36+
}
37+
1638
// CHECK-LABEL: @test_lower_encapsulate
1739
// CHECK-SAME: (%[[LHS:.*]]: [[T:.*]]) -> [[T]] {
1840
func.func @test_lower_encapsulate(%lhs : i32) -> !Zp {

zkir/Dialect/ModArith/Conversions/ModArithToArith/ModArithToArith.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,24 @@ struct ConvertConstant : public OpConversionPattern<ConstantOp> {
128128
}
129129
};
130130

131+
struct ConvertNegate : public OpConversionPattern<NegateOp> {
132+
explicit ConvertNegate(mlir::MLIRContext *context)
133+
: OpConversionPattern<NegateOp>(context) {}
134+
135+
using OpConversionPattern::OpConversionPattern;
136+
137+
LogicalResult matchAndRewrite(
138+
NegateOp op, OpAdaptor adaptor,
139+
ConversionPatternRewriter &rewriter) const override {
140+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
141+
142+
auto cmod = b.create<arith::ConstantOp>(modulusAttr(op));
143+
auto sub = b.create<arith::SubIOp>(cmod, adaptor.getOperands()[0]);
144+
rewriter.replaceOp(op, sub);
145+
return success();
146+
}
147+
};
148+
131149
struct ConvertReduce : public OpConversionPattern<ReduceOp> {
132150
explicit ConvertReduce(mlir::MLIRContext *context)
133151
: OpConversionPattern<ReduceOp>(context) {}
@@ -530,15 +548,15 @@ void ModArithToArith::runOnOperation() {
530548

531549
RewritePatternSet patterns(context);
532550
rewrites::populateWithGenerated(patterns);
533-
patterns
534-
.add<ConvertEncapsulate, ConvertExtract, ConvertReduce, ConvertMontReduce,
535-
ConvertToMont, ConvertFromMont, ConvertAdd, ConvertSub, ConvertMul,
536-
ConvertMontMul, ConvertMac, ConvertConstant, ConvertInverse,
537-
ConvertAny<affine::AffineForOp>, ConvertAny<affine::AffineYieldOp>,
538-
ConvertAny<linalg::GenericOp>, ConvertAny<linalg::YieldOp>,
539-
ConvertAny<tensor::CastOp>, ConvertAny<tensor::ExtractOp>,
540-
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::InsertOp>>(
541-
typeConverter, context);
551+
patterns.add<
552+
ConvertNegate, ConvertEncapsulate, ConvertExtract, ConvertReduce,
553+
ConvertMontReduce, ConvertToMont, ConvertFromMont, ConvertAdd, ConvertSub,
554+
ConvertMul, ConvertMontMul, ConvertMac, ConvertConstant, ConvertInverse,
555+
ConvertAny<affine::AffineForOp>, ConvertAny<affine::AffineYieldOp>,
556+
ConvertAny<linalg::GenericOp>, ConvertAny<linalg::YieldOp>,
557+
ConvertAny<tensor::CastOp>, ConvertAny<tensor::ExtractOp>,
558+
ConvertAny<tensor::FromElementsOp>, ConvertAny<tensor::InsertOp>>(
559+
typeConverter, context);
542560

543561
addStructuralConversionPatterns(typeConverter, patterns, target);
544562

zkir/Dialect/ModArith/IR/ModArithDialect.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ LogicalResult verifySameWidth(OpType op, ModArithType modArithType,
9595
return success();
9696
}
9797

98+
LogicalResult NegateOp::verify() {
99+
return verifyModArithType(*this, getResultModArithType(*this));
100+
}
101+
98102
LogicalResult ExtractOp::verify() {
99103
auto modArithType = getOperandModArithType(*this);
100104
auto integerType = getResultIntegerType(*this);

zkir/Dialect/ModArith/IR/ModArithOps.td

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,24 @@ def ModArith_ConstantOp : Op<ModArith_Dialect, "constant",
9797
let hasCustomAssemblyFormat = 1;
9898
}
9999

100+
def ModArith_NegateOp : ModArith_Op<"negate", [Pure, ElementwiseMappable, SameOperandsAndResultType, Involution]> {
101+
let summary = "negate a mod arith element";
102+
103+
let description = [{
104+
`mod_arith.negate x` computes $-x \mod q$
105+
Examples:
106+
```
107+
%1 = mod_arith.negate %0 : mod_arith.int<65537: i32>
108+
```
109+
}];
110+
111+
let arguments = (ins
112+
ModArithLike:$input
113+
);
114+
let results = (outs ModArithLike:$output);
115+
let hasVerifier = 1;
116+
let assemblyFormat = "operands attr-dict `:` type($output)";
117+
}
100118

101119
def ModArith_ReduceOp : ModArith_Op<"reduce", [Pure, ElementwiseMappable, SameOperandsAndResultType]> {
102120
let summary = "reduce the mod arith type to its canonical representative";

0 commit comments

Comments
 (0)