Skip to content

Commit e7913b1

Browse files
committed
1 parent 72e2cb6 commit e7913b1

File tree

2 files changed

+64
-17
lines changed

2 files changed

+64
-17
lines changed

tests/Dialect/EllipticCurve/elliptic_curve_to_field_runner.mlir

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ func.func @test_ops_in_order() {
114114
// CHECK_TEST_OPS_IN_ORDER: [1, 10, 1]
115115
// CHECK_TEST_OPS_IN_ORDER: [1, 10, 1, 1]
116116
// CHECK_TEST_OPS_IN_ORDER: [1, 10]
117-
// CHECK_TEST_OPS_IN_ORDER: [0, 0, 0]
118-
// CHECK_TEST_OPS_IN_ORDER: [1, 1]
119-
// CHECK_TEST_OPS_IN_ORDER: [4, 3, 0, 0]
117+
// CHECK_TEST_OPS_IN_ORDER: [10, 5, 4]
118+
// CHECK_TEST_OPS_IN_ORDER: [2, 3]
119+
// CHECK_TEST_OPS_IN_ORDER: [2, 8, 1, 1]
120120

121121

122122
// CHECK-LABEL: @test_msm
@@ -164,5 +164,5 @@ func.func @test_msm() {
164164
return
165165
}
166166

167-
// CHECK_TEST_MSM: [0, 0, 0]
168-
// CHECK_TEST_MSM: [0, 0, 0]
167+
// CHECK_TEST_MSM: [0, 3, 7]
168+
// CHECK_TEST_MSM: [0, 3, 7]

zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp

Lines changed: 59 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -344,20 +344,67 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
344344
ConversionPatternRewriter &rewriter) const override {
345345
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
346346

347+
Value p1 = op.getLhs();
348+
Value p2 = op.getRhs();
349+
ValueRange p1Coords = adaptor.getLhs();
350+
ValueRange p2Coords = adaptor.getRhs();
351+
Type p1Type = p1.getType();
352+
Type p2Type = p2.getType();
347353
Type outputType = op.getOutput().getType();
348-
ValueRange p1 = adaptor.getLhs();
349-
ValueRange p2 = adaptor.getRhs();
350-
SmallVector<Value> sum;
351-
352-
if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
353-
sum = xyzzAdd(p1, p2, xyzzType.getCurve(), b);
354-
} else if (auto jacobianType = dyn_cast<JacobianType>(outputType)) {
355-
sum = jacobianAdd(p1, p2, jacobianType.getCurve(), b);
356-
} else {
357-
assert(false && "Unsupported point types for addition");
358-
}
359354

360-
rewriter.replaceOpWithMultiple(op, {sum});
355+
// check p1 == zero point
356+
Value p1isZeroCmp = b.create<elliptic_curve::IsZeroOp>(p1);
357+
auto p1IsZeroOp = b.create<scf::IfOp>(
358+
p1isZeroCmp,
359+
/*thenBuilder=*/
360+
[&](OpBuilder &builder, Location loc) {
361+
ImplicitLocOpBuilder b(loc, builder);
362+
ValueRange retP2 = p2Coords;
363+
if (isa<AffineType>(p2Type)) {
364+
retP2 = {
365+
b.create<elliptic_curve::ConvertPointTypeOp>(outputType, p2)};
366+
}
367+
368+
b.create<scf::YieldOp>(retP2);
369+
},
370+
/*elseBuilder=*/
371+
[&](OpBuilder &builder, Location loc) {
372+
ImplicitLocOpBuilder b(loc, builder);
373+
374+
// check p2 == zero point
375+
Value p2isZeroCmp = b.create<elliptic_curve::IsZeroOp>(p2);
376+
auto p2IsZeroOp = b.create<scf::IfOp>(
377+
p2isZeroCmp,
378+
/*thenBuilder=*/
379+
[&](OpBuilder &builder, Location loc) {
380+
ImplicitLocOpBuilder b(loc, builder);
381+
ValueRange retP1 = p1Coords;
382+
if (isa<AffineType>(p1Type)) {
383+
retP1 = {b.create<elliptic_curve::ConvertPointTypeOp>(
384+
outputType, p1)};
385+
}
386+
387+
b.create<scf::YieldOp>(retP1);
388+
},
389+
/*elseBuilder=*/
390+
[&](OpBuilder &builder, Location loc) {
391+
ImplicitLocOpBuilder b(loc, builder);
392+
// run default add
393+
SmallVector<Value> sum;
394+
if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
395+
sum = xyzzAdd(p1Coords, p2Coords, xyzzType.getCurve(), b);
396+
} else if (auto jacobianType =
397+
dyn_cast<JacobianType>(outputType)) {
398+
sum = jacobianAdd(p1Coords, p2Coords, jacobianType.getCurve(),
399+
b);
400+
} else {
401+
assert(false && "Unsupported point types for addition");
402+
}
403+
b.create<scf::YieldOp>(loc, sum);
404+
});
405+
b.create<scf::YieldOp>(p2IsZeroOp.getResults());
406+
});
407+
rewriter.replaceOpWithMultiple(op, {p1IsZeroOp.getResults()});
361408
return success();
362409
}
363410
};

0 commit comments

Comments
 (0)