Skip to content

Commit 72e2cb6

Browse files
committed
feat: add IsZeroOp
1 parent 3d947e3 commit 72e2cb6

File tree

5 files changed

+67
-4
lines changed

5 files changed

+67
-4
lines changed

zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,37 @@ struct ConvertPoint : public OpConversionPattern<PointOp> {
147147
}
148148
};
149149

150+
struct ConvertIsZero : public OpConversionPattern<IsZeroOp> {
151+
explicit ConvertIsZero(MLIRContext *context)
152+
: OpConversionPattern<IsZeroOp>(context) {}
153+
154+
using OpConversionPattern::OpConversionPattern;
155+
156+
LogicalResult matchAndRewrite(
157+
IsZeroOp op, OneToNOpAdaptor adaptor,
158+
ConversionPatternRewriter &rewriter) const override {
159+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
160+
161+
ValueRange coords = adaptor.getInput();
162+
field::PrimeFieldType baseField =
163+
cast<field::PrimeFieldType>(coords[0].getType());
164+
Value zeroPF = b.create<field::ConstantOp>(baseField, 0);
165+
166+
Value cmp;
167+
if (isa<AffineType>(op.getInput().getType())) {
168+
Value xIsZero =
169+
b.create<field::CmpOp>(arith::CmpIPredicate::eq, coords[0], zeroPF);
170+
Value yIsZero =
171+
b.create<field::CmpOp>(arith::CmpIPredicate::eq, coords[1], zeroPF);
172+
cmp = b.create<arith::AndIOp>(xIsZero, yIsZero);
173+
} else {
174+
cmp = b.create<field::CmpOp>(arith::CmpIPredicate::eq, coords[2], zeroPF);
175+
}
176+
rewriter.replaceOp(op, cmp);
177+
return success();
178+
}
179+
};
180+
150181
struct ConvertExtract : public OpConversionPattern<ExtractOp> {
151182
explicit ConvertExtract(MLIRContext *context)
152183
: OpConversionPattern<ExtractOp>(context) {}
@@ -554,10 +585,11 @@ void EllipticCurveToField::runOnOperation() {
554585

555586
RewritePatternSet patterns(context);
556587
rewrites::populateWithGenerated(patterns);
557-
patterns.add<ConvertPoint, ConvertExtract, ConvertConvertPointType,
558-
ConvertAdd, ConvertDouble, ConvertNegate, ConvertSub,
559-
ConvertScalarMul, ConvertMSM, ConvertAny<tensor::FromElementsOp>,
560-
ConvertAny<tensor::ExtractOp>>(typeConverter, context);
588+
patterns
589+
.add<ConvertPoint, ConvertIsZero, ConvertExtract, ConvertConvertPointType,
590+
ConvertAdd, ConvertDouble, ConvertNegate, ConvertSub,
591+
ConvertScalarMul, ConvertMSM, ConvertAny<tensor::FromElementsOp>,
592+
ConvertAny<tensor::ExtractOp>>(typeConverter, context);
561593
target.addDynamicallyLegalOp<tensor::FromElementsOp, tensor::ExtractOp>(
562594
[&](auto op) { return typeConverter.isLegal(op); });
563595

zkir/Dialect/EllipticCurve/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cc_library(
2323
":types_inc_gen",
2424
"//zkir/Dialect/Field/IR:Field",
2525
"//zkir/Dialect/Poly/IR:Poly",
26+
"//zkir/Utils:OpUtils",
2627
"@llvm-project//llvm:Support",
2728
"@llvm-project//mlir:IR",
2829
"@llvm-project//mlir:InferTypeOpInterface",

zkir/Dialect/EllipticCurve/IR/EllipticCurveDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,15 @@ LogicalResult PointOp::verify() {
179179

180180
/////////////// VERIFY OPS /////////////////
181181

182+
LogicalResult IsZeroOp::verify() {
183+
Type inputType = getInput().getType();
184+
if (isa<AffineType>(getElementTypeOrSelf(inputType)) ||
185+
isa<JacobianType>(getElementTypeOrSelf(inputType)) ||
186+
isa<XYZZType>(getElementTypeOrSelf(inputType)))
187+
return success();
188+
return emitError() << "invalid input type";
189+
}
190+
182191
template <typename OpType>
183192
LogicalResult verifyBinaryOp(OpType op) {
184193
Type lhsType = op.getLhs().getType();

zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "zkir/Dialect/EllipticCurve/IR/EllipticCurveTypes.h"
88
#include "zkir/Dialect/Field/IR/FieldAttributes.h"
99
#include "zkir/Dialect/Field/IR/FieldTypes.h"
10+
#include "zkir/Utils/OpUtils.h"
1011

1112
#define GET_OP_CLASSES
1213
#include "zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.h.inc"

zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ def EllipticCurve_PointOp : EllipticCurve_Op<"point"> {
5151
let hasVerifier = 1;
5252
}
5353

54+
////////////// POINT CHECKS //////////////
55+
56+
def EllipticCurve_IsZeroOp : EllipticCurve_Op<"is_zero", [TypesMatchWith<
57+
"result type has i1 element type and same shape as operands",
58+
"input", "output", "getI1SameShape($_self)">]> {
59+
let summary = "Checks whether an elliptic curve point is zero";
60+
let description = [{
61+
Outputs a bool true (1) if the elliptic curve point input is zero, and a bool false (0) otherwise.
62+
63+
Example:
64+
```
65+
%0 = elliptic_curve.is_zero %point : !jacobian
66+
```
67+
}];
68+
let arguments = (ins PointLike:$input);
69+
let results = (outs BoolLike:$output);
70+
let hasVerifier = 1;
71+
let assemblyFormat = "operands attr-dict `:` type($input)";
72+
}
73+
5474
/////////// POINT EXTRACTION //////////////
5575

5676
def EllipticCurve_ExtractOp : EllipticCurve_Op<"extract"> {

0 commit comments

Comments
 (0)