@@ -147,6 +147,37 @@ struct ConvertPoint : public OpConversionPattern<PointOp> {
147147 }
148148};
149149
150+ struct ConvertIsZero : public OpConversionPattern <IsZeroOp> {
151+ explicit ConvertIsZero (MLIRContext *context)
152+ : OpConversionPattern<IsZeroOp>(context) {}
153+
154+ using OpConversionPattern::OpConversionPattern;
155+
156+ LogicalResult matchAndRewrite (
157+ IsZeroOp op, OneToNOpAdaptor adaptor,
158+ ConversionPatternRewriter &rewriter) const override {
159+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
160+
161+ ValueRange coords = adaptor.getInput ();
162+ field::PrimeFieldType baseField =
163+ cast<field::PrimeFieldType>(coords[0 ].getType ());
164+ Value zeroPF = b.create <field::ConstantOp>(baseField, 0 );
165+
166+ Value cmp;
167+ if (isa<AffineType>(op.getInput ().getType ())) {
168+ Value xIsZero =
169+ b.create <field::CmpOp>(arith::CmpIPredicate::eq, coords[0 ], zeroPF);
170+ Value yIsZero =
171+ b.create <field::CmpOp>(arith::CmpIPredicate::eq, coords[1 ], zeroPF);
172+ cmp = b.create <arith::AndIOp>(xIsZero, yIsZero);
173+ } else {
174+ cmp = b.create <field::CmpOp>(arith::CmpIPredicate::eq, coords[2 ], zeroPF);
175+ }
176+ rewriter.replaceOp (op, cmp);
177+ return success ();
178+ }
179+ };
180+
150181struct ConvertExtract : public OpConversionPattern <ExtractOp> {
151182 explicit ConvertExtract (MLIRContext *context)
152183 : OpConversionPattern<ExtractOp>(context) {}
@@ -313,20 +344,67 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
313344 ConversionPatternRewriter &rewriter) const override {
314345 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
315346
347+ Value p1 = op.getLhs ();
348+ Value p2 = op.getRhs ();
349+ ValueRange p1Coords = adaptor.getLhs ();
350+ ValueRange p2Coords = adaptor.getRhs ();
351+ Type p1Type = p1.getType ();
352+ Type p2Type = p2.getType ();
316353 Type outputType = op.getOutput ().getType ();
317- ValueRange p1 = adaptor.getLhs ();
318- ValueRange p2 = adaptor.getRhs ();
319- SmallVector<Value> sum;
320354
321- if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
322- sum = xyzzAdd (p1, p2, xyzzType.getCurve (), b);
323- } else if (auto jacobianType = dyn_cast<JacobianType>(outputType)) {
324- sum = jacobianAdd (p1, p2, jacobianType.getCurve (), b);
325- } else {
326- assert (false && " Unsupported point types for addition" );
327- }
328-
329- rewriter.replaceOpWithMultiple (op, {sum});
355+ // check p1 == zero point
356+ Value p1isZeroCmp = b.create <elliptic_curve::IsZeroOp>(p1);
357+ auto p1IsZeroOp = b.create <scf::IfOp>(
358+ p1isZeroCmp,
359+ /* thenBuilder=*/
360+ [&](OpBuilder &builder, Location loc) {
361+ ImplicitLocOpBuilder b (loc, builder);
362+ ValueRange retP2 = p2Coords;
363+ if (isa<AffineType>(p2Type)) {
364+ retP2 = {
365+ b.create <elliptic_curve::ConvertPointTypeOp>(outputType, p2)};
366+ }
367+
368+ b.create <scf::YieldOp>(retP2);
369+ },
370+ /* elseBuilder=*/
371+ [&](OpBuilder &builder, Location loc) {
372+ ImplicitLocOpBuilder b (loc, builder);
373+
374+ // check p2 == zero point
375+ Value p2isZeroCmp = b.create <elliptic_curve::IsZeroOp>(p2);
376+ auto p2IsZeroOp = b.create <scf::IfOp>(
377+ p2isZeroCmp,
378+ /* thenBuilder=*/
379+ [&](OpBuilder &builder, Location loc) {
380+ ImplicitLocOpBuilder b (loc, builder);
381+ ValueRange retP1 = p1Coords;
382+ if (isa<AffineType>(p1Type)) {
383+ retP1 = {b.create <elliptic_curve::ConvertPointTypeOp>(
384+ outputType, p1)};
385+ }
386+
387+ b.create <scf::YieldOp>(retP1);
388+ },
389+ /* elseBuilder=*/
390+ [&](OpBuilder &builder, Location loc) {
391+ ImplicitLocOpBuilder b (loc, builder);
392+ // run default add
393+ SmallVector<Value> sum;
394+ if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
395+ sum = xyzzAdd (p1Coords, p2Coords, xyzzType.getCurve (), b);
396+ } else if (auto jacobianType =
397+ dyn_cast<JacobianType>(outputType)) {
398+ sum = jacobianAdd (p1Coords, p2Coords, jacobianType.getCurve (),
399+ b);
400+ } else {
401+ assert (false && " Unsupported point types for addition" );
402+ }
403+ b.create <scf::YieldOp>(loc, sum);
404+ });
405+ b.create <scf::YieldOp>(p2IsZeroOp.getResults ());
406+ });
407+ rewriter.replaceOpWithMultiple (op, {p1IsZeroOp.getResults ()});
330408 return success ();
331409 }
332410};
@@ -428,12 +506,13 @@ struct ConvertScalarMul : public OpConversionPattern<ScalarMulOp> {
428506 auto scalar = b.create <field::ExtractOp>(signlessIntType, scalarPF);
429507
430508 auto zeroPF = b.create <field::ConstantOp>(baseFieldType, 0 );
509+ auto onePF = b.create <field::ConstantOp>(baseFieldType, 1 );
431510 Value zeroPoint =
432511 isa<XYZZType>(outputType)
433512 ? b.create <elliptic_curve::PointOp>(
434- outputType, ValueRange{zeroPF, zeroPF , zeroPF, zeroPF})
513+ outputType, ValueRange{onePF, onePF , zeroPF, zeroPF})
435514 : b.create <elliptic_curve::PointOp>(
436- outputType, ValueRange{zeroPF, zeroPF , zeroPF});
515+ outputType, ValueRange{onePF, onePF , zeroPF});
437516
438517 Value intialPoint =
439518 isa<AffineType>(pointType)
@@ -554,10 +633,11 @@ void EllipticCurveToField::runOnOperation() {
554633
555634 RewritePatternSet patterns (context);
556635 rewrites::populateWithGenerated (patterns);
557- patterns.add <ConvertPoint, ConvertExtract, ConvertConvertPointType,
558- ConvertAdd, ConvertDouble, ConvertNegate, ConvertSub,
559- ConvertScalarMul, ConvertMSM, ConvertAny<tensor::FromElementsOp>,
560- ConvertAny<tensor::ExtractOp>>(typeConverter, context);
636+ patterns
637+ .add <ConvertPoint, ConvertIsZero, ConvertExtract, ConvertConvertPointType,
638+ ConvertAdd, ConvertDouble, ConvertNegate, ConvertSub,
639+ ConvertScalarMul, ConvertMSM, ConvertAny<tensor::FromElementsOp>,
640+ ConvertAny<tensor::ExtractOp>>(typeConverter, context);
561641 target.addDynamicallyLegalOp <tensor::FromElementsOp, tensor::ExtractOp>(
562642 [&](auto op) { return typeConverter.isLegal (op); });
563643
0 commit comments