@@ -344,20 +344,67 @@ struct ConvertAdd : public OpConversionPattern<AddOp> {
344344 ConversionPatternRewriter &rewriter) const override {
345345 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
346346
347+ Value p1 = op.getLhs ();
348+ Value p2 = op.getRhs ();
349+ ValueRange p1Coords = adaptor.getLhs ();
350+ ValueRange p2Coords = adaptor.getRhs ();
351+ Type p1Type = p1.getType ();
352+ Type p2Type = p2.getType ();
347353 Type outputType = op.getOutput ().getType ();
348- ValueRange p1 = adaptor.getLhs ();
349- ValueRange p2 = adaptor.getRhs ();
350- SmallVector<Value> sum;
351-
352- if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
353- sum = xyzzAdd (p1, p2, xyzzType.getCurve (), b);
354- } else if (auto jacobianType = dyn_cast<JacobianType>(outputType)) {
355- sum = jacobianAdd (p1, p2, jacobianType.getCurve (), b);
356- } else {
357- assert (false && " Unsupported point types for addition" );
358- }
359354
360- rewriter.replaceOpWithMultiple (op, {sum});
355+ // check p1 == zero point
356+ Value p1isZeroCmp = b.create <elliptic_curve::IsZeroOp>(p1);
357+ auto p1IsZeroOp = b.create <scf::IfOp>(
358+ p1isZeroCmp,
359+ /* thenBuilder=*/
360+ [&](OpBuilder &builder, Location loc) {
361+ ImplicitLocOpBuilder b (loc, builder);
362+ ValueRange retP2 = p2Coords;
363+ if (isa<AffineType>(p2Type)) {
364+ retP2 = {
365+ b.create <elliptic_curve::ConvertPointTypeOp>(outputType, p2)};
366+ }
367+
368+ b.create <scf::YieldOp>(retP2);
369+ },
370+ /* elseBuilder=*/
371+ [&](OpBuilder &builder, Location loc) {
372+ ImplicitLocOpBuilder b (loc, builder);
373+
374+ // check p2 == zero point
375+ Value p2isZeroCmp = b.create <elliptic_curve::IsZeroOp>(p2);
376+ auto p2IsZeroOp = b.create <scf::IfOp>(
377+ p2isZeroCmp,
378+ /* thenBuilder=*/
379+ [&](OpBuilder &builder, Location loc) {
380+ ImplicitLocOpBuilder b (loc, builder);
381+ ValueRange retP1 = p1Coords;
382+ if (isa<AffineType>(p1Type)) {
383+ retP1 = {b.create <elliptic_curve::ConvertPointTypeOp>(
384+ outputType, p1)};
385+ }
386+
387+ b.create <scf::YieldOp>(retP1);
388+ },
389+ /* elseBuilder=*/
390+ [&](OpBuilder &builder, Location loc) {
391+ ImplicitLocOpBuilder b (loc, builder);
392+ // run default add
393+ SmallVector<Value> sum;
394+ if (auto xyzzType = dyn_cast<XYZZType>(outputType)) {
395+ sum = xyzzAdd (p1Coords, p2Coords, xyzzType.getCurve (), b);
396+ } else if (auto jacobianType =
397+ dyn_cast<JacobianType>(outputType)) {
398+ sum = jacobianAdd (p1Coords, p2Coords, jacobianType.getCurve (),
399+ b);
400+ } else {
401+ assert (false && " Unsupported point types for addition" );
402+ }
403+ b.create <scf::YieldOp>(loc, sum);
404+ });
405+ b.create <scf::YieldOp>(p2IsZeroOp.getResults ());
406+ });
407+ rewriter.replaceOpWithMultiple (op, {p1IsZeroOp.getResults ()});
361408 return success ();
362409 }
363410};
0 commit comments