Skip to content

Commit d5db0dc

Browse files
authored
Merge pull request #27 from zk-rabbit/fix/elliptic-curve-add
fix(elliptic_curve): fix elliptic curve point add zero case
2 parents 3d947e3 + a6bff8a commit d5db0dc

File tree

6 files changed

+134
-23
lines changed

6 files changed

+134
-23
lines changed

tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ func.func @test_ops_in_order() {
114114
// CHECK_TEST_OPS_IN_ORDER: [1, 10, 1]
115115
// CHECK_TEST_OPS_IN_ORDER: [1, 10, 1, 1]
116116
// CHECK_TEST_OPS_IN_ORDER: [1, 10]
117-
// CHECK_TEST_OPS_IN_ORDER: [0, 0, 0]
118-
// CHECK_TEST_OPS_IN_ORDER: [1, 1]
119-
// CHECK_TEST_OPS_IN_ORDER: [4, 3, 0, 0]
117+
// CHECK_TEST_OPS_IN_ORDER: [10, 5, 4]
118+
// CHECK_TEST_OPS_IN_ORDER: [2, 3]
119+
// CHECK_TEST_OPS_IN_ORDER: [2, 8, 1, 1]
120120

121121

122122
// CHECK-LABEL: @test_msm
@@ -164,5 +164,5 @@ func.func @test_msm() {
164164
return
165165
}
166166

167-
// CHECK_TEST_MSM: [0, 0, 0]
168-
// CHECK_TEST_MSM: [0, 0, 0]
167+
// CHECK_TEST_MSM: [0, 3, 7]
168+
// CHECK_TEST_MSM: [0, 3, 7]

zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp

Lines changed: 98 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,37 @@ struct ConvertPoint : public OpConversionPattern<PointOp> {
147147
}
148148
};
149149

150+
struct ConvertIsZero : public OpConversionPattern<IsZeroOp> {
151+
explicit ConvertIsZero(MLIRContext *context)
152+
: OpConversionPattern<IsZeroOp>(context) {}
153+
154+
using OpConversionPattern::OpConversionPattern;
155+
156+
LogicalResult matchAndRewrite(
157+
IsZeroOp op, OneToNOpAdaptor adaptor,
158+
ConversionPatternRewriter &rewriter) const override {
159+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
160+
161+
ValueRange coords = adaptor.getInput();
162+
field::PrimeFieldType baseField =
163+
cast<field::PrimeFieldType>(coords[0].getType());
164+
Value zeroPF = b.create<field::ConstantOp>(baseField, 0);
165+
166+
Value cmp;
167+
if (isa<AffineType>(op.getInput().getType())) {
168+
Value xIsZero =
169+
b.create<field::CmpOp>(arith::CmpIPredicate::eq, coords[0], zeroPF);
170+
Value yIsZero =
171+
b.create<field::CmpOp>(arith::CmpIPredicate::eq, coords[1], zeroPF);
172+
cmp = b.create<arith::AndIOp>(xIsZero, yIsZero);
173+
} else {
174+
cmp = b.create<field::CmpOp>(arith::CmpIPredicate::eq, coords[2], zeroPF);
175+
}
176+
rewriter.replaceOp(op, cmp);
177+
return success();
178+
}
179+
};
180+
150181
struct ConvertExtract : public OpConversionPattern<ExtractOp> {
151182
explicit ConvertExtract(MLIRContext *context)
152183
: OpConversionPattern<ExtractOp>(context) {}
@@ -313,20 +344,67 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
313344
ConversionPatternRewriter &rewriter) const override {
314345
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
315346

347+
Value p1 = op.getLhs();
348+
Value p2 = op.getRhs();
349+
ValueRange p1Coords = adaptor.getLhs();
350+
ValueRange p2Coords = adaptor.getRhs();
351+
Type p1Type = p1.getType();
352+
Type p2Type = p2.getType();
316353
Type outputType = op.getOutput().getType();
317-
ValueRange p1 = adaptor.getLhs();
318-
ValueRange p2 = adaptor.getRhs();
319-
SmallVector<Value> sum;
320354

321-
if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
322-
sum = xyzzAdd(p1, p2, xyzzType.getCurve(), b);
323-
} else if (auto jacobianType = dyn_cast<JacobianType>(outputType)) {
324-
sum = jacobianAdd(p1, p2, jacobianType.getCurve(), b);
325-
} else {
326-
assert(false && "Unsupported point types for addition");
327-
}
328-
329-
rewriter.replaceOpWithMultiple(op, {sum});
355+
// check p1 == zero point
356+
Value p1isZeroCmp = b.create<elliptic_curve::IsZeroOp>(p1);
357+
auto p1IsZeroOp = b.create<scf::IfOp>(
358+
p1isZeroCmp,
359+
/*thenBuilder=*/
360+
[&](OpBuilder &builder, Location loc) {
361+
ImplicitLocOpBuilder b(loc, builder);
362+
ValueRange retP2 = p2Coords;
363+
if (isa<AffineType>(p2Type)) {
364+
retP2 = {
365+
b.create<elliptic_curve::ConvertPointTypeOp>(outputType, p2)};
366+
}
367+
368+
b.create<scf::YieldOp>(retP2);
369+
},
370+
/*elseBuilder=*/
371+
[&](OpBuilder &builder, Location loc) {
372+
ImplicitLocOpBuilder b(loc, builder);
373+
374+
// check p2 == zero point
375+
Value p2isZeroCmp = b.create<elliptic_curve::IsZeroOp>(p2);
376+
auto p2IsZeroOp = b.create<scf::IfOp>(
377+
p2isZeroCmp,
378+
/*thenBuilder=*/
379+
[&](OpBuilder &builder, Location loc) {
380+
ImplicitLocOpBuilder b(loc, builder);
381+
ValueRange retP1 = p1Coords;
382+
if (isa<AffineType>(p1Type)) {
383+
retP1 = {b.create<elliptic_curve::ConvertPointTypeOp>(
384+
outputType, p1)};
385+
}
386+
387+
b.create<scf::YieldOp>(retP1);
388+
},
389+
/*elseBuilder=*/
390+
[&](OpBuilder &builder, Location loc) {
391+
ImplicitLocOpBuilder b(loc, builder);
392+
// run default add
393+
SmallVector<Value> sum;
394+
if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
395+
sum = xyzzAdd(p1Coords, p2Coords, xyzzType.getCurve(), b);
396+
} else if (auto jacobianType =
397+
dyn_cast<JacobianType>(outputType)) {
398+
sum = jacobianAdd(p1Coords, p2Coords, jacobianType.getCurve(),
399+
b);
400+
} else {
401+
assert(false && "Unsupported point types for addition");
402+
}
403+
b.create<scf::YieldOp>(loc, sum);
404+
});
405+
b.create<scf::YieldOp>(p2IsZeroOp.getResults());
406+
});
407+
rewriter.replaceOpWithMultiple(op, {p1IsZeroOp.getResults()});
330408
return success();
331409
}
332410
};
@@ -428,12 +506,13 @@ struct ConvertScalarMul : public OpConversionPattern<ScalarMulOp> {
428506
auto scalar = b.create<field::ExtractOp>(signlessIntType, scalarPF);
429507

430508
auto zeroPF = b.create<field::ConstantOp>(baseFieldType, 0);
509+
auto onePF = b.create<field::ConstantOp>(baseFieldType, 1);
431510
Value zeroPoint =
432511
isa<XYZZType>(outputType)
433512
? b.create<elliptic_curve::PointOp>(
434-
outputType, ValueRange{zeroPF, zeroPF, zeroPF, zeroPF})
513+
outputType, ValueRange{onePF, onePF, zeroPF, zeroPF})
435514
: b.create<elliptic_curve::PointOp>(
436-
outputType, ValueRange{zeroPF, zeroPF, zeroPF});
515+
outputType, ValueRange{onePF, onePF, zeroPF});
437516

438517
Value intialPoint =
439518
isa<AffineType>(pointType)
@@ -554,10 +633,11 @@ void EllipticCurveToField::runOnOperation() {
554633

555634
RewritePatternSet patterns(context);
556635
rewrites::populateWithGenerated(patterns);
557-
patterns.add<ConvertPoint, ConvertExtract, ConvertConvertPointType,
558-
ConvertAdd, ConvertDouble, ConvertNegate, ConvertSub,
559-
ConvertScalarMul, ConvertMSM, ConvertAny<tensor::FromElementsOp>,
560-
ConvertAny<tensor::ExtractOp>>(typeConverter, context);
636+
patterns
637+
.add<ConvertPoint, ConvertIsZero, ConvertExtract, ConvertConvertPointType,
638+
ConvertAdd, ConvertDouble, ConvertNegate, ConvertSub,
639+
ConvertScalarMul, ConvertMSM, ConvertAny<tensor::FromElementsOp>,
640+
ConvertAny<tensor::ExtractOp>>(typeConverter, context);
561641
target.addDynamicallyLegalOp<tensor::FromElementsOp, tensor::ExtractOp>(
562642
[&](auto op) { return typeConverter.isLegal(op); });
563643

zkir/Dialect/EllipticCurve/IR/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ cc_library(
2323
":types_inc_gen",
2424
"//zkir/Dialect/Field/IR:Field",
2525
"//zkir/Dialect/Poly/IR:Poly",
26+
"//zkir/Utils:OpUtils",
2627
"@llvm-project//llvm:Support",
2728
"@llvm-project//mlir:IR",
2829
"@llvm-project//mlir:InferTypeOpInterface",

zkir/Dialect/EllipticCurve/IR/EllipticCurveDialect.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,15 @@ LogicalResult PointOp::verify() {
179179

180180
/////////////// VERIFY OPS /////////////////
181181

182+
LogicalResult IsZeroOp::verify() {
183+
Type inputType = getInput().getType();
184+
if (isa<AffineType>(getElementTypeOrSelf(inputType)) ||
185+
isa<JacobianType>(getElementTypeOrSelf(inputType)) ||
186+
isa<XYZZType>(getElementTypeOrSelf(inputType)))
187+
return success();
188+
return emitError() << "invalid input type";
189+
}
190+
182191
template <typename OpType>
183192
LogicalResult verifyBinaryOp(OpType op) {
184193
Type lhsType = op.getLhs().getType();

zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "zkir/Dialect/EllipticCurve/IR/EllipticCurveTypes.h"
88
#include "zkir/Dialect/Field/IR/FieldAttributes.h"
99
#include "zkir/Dialect/Field/IR/FieldTypes.h"
10+
#include "zkir/Utils/OpUtils.h"
1011

1112
#define GET_OP_CLASSES
1213
#include "zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.h.inc"

zkir/Dialect/EllipticCurve/IR/EllipticCurveOps.td

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,26 @@ def EllipticCurve_PointOp : EllipticCurve_Op<"point"> {
5151
let hasVerifier = 1;
5252
}
5353

54+
////////////// POINT CHECKS //////////////
55+
56+
def EllipticCurve_IsZeroOp : EllipticCurve_Op<"is_zero", [TypesMatchWith<
57+
"result type has i1 element type and same shape as operands",
58+
"input", "output", "getI1SameShape($_self)">]> {
59+
let summary = "Checks whether an elliptic curve point is zero";
60+
let description = [{
61+
Outputs a bool true (1) if the elliptic curve point input is zero, and a bool false (0) otherwise.
62+
63+
Example:
64+
```
65+
%0 = elliptic_curve.is_zero %point : !jacobian
66+
```
67+
}];
68+
let arguments = (ins PointLike:$input);
69+
let results = (outs BoolLike:$output);
70+
let hasVerifier = 1;
71+
let assemblyFormat = "operands attr-dict `:` type($input)";
72+
}
73+
5474
/////////// POINT EXTRACTION //////////////
5575

5676
def EllipticCurve_ExtractOp : EllipticCurve_Op<"extract"> {

0 commit comments

Comments
 (0)