@@ -613,6 +613,74 @@ struct ConvertScalarMul : public OpConversionPattern<ScalarMulOp> {
613613 }
614614};
615615
616+ struct ConvertMSM : public OpConversionPattern <MSMOp> {
617+ explicit ConvertMSM (mlir::MLIRContext *context)
618+ : OpConversionPattern<MSMOp>(context) {}
619+
620+ using OpConversionPattern::OpConversionPattern;
621+
622+ LogicalResult matchAndRewrite (
623+ MSMOp op, OpAdaptor adaptor,
624+ ConversionPatternRewriter &rewriter) const override {
625+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
626+
627+ // 2D tensor of prime field values, e.g.:
628+ // |x1, y1, z1|
629+ // |x2, y2, z2|
630+ Value loweredPointSet = adaptor.getPoints ();
631+ // 1d tensor of PF, e.g.:
632+ // | s1 , s2 |
633+ Value scalars = op.getScalars ();
634+ RankedTensorType loweredPointSetType =
635+ cast<RankedTensorType>(loweredPointSet.getType ());
636+ field::PrimeFieldType baseFieldType =
637+ cast<field::PrimeFieldType>(loweredPointSetType.getElementType ());
638+ unsigned numScalarMuls = loweredPointSetType.getShape ()[0 ];
639+ unsigned numCoords = loweredPointSetType.getShape ()[1 ];
640+
641+ Type inputPointType =
642+ cast<RankedTensorType>(op.getPoints ().getType ()).getElementType ();
643+ Type outputPointType = op.getOutput ().getType ();
644+ RankedTensorType loweredOutputPointType =
645+ RankedTensorType::get ({numCoords}, baseFieldType);
646+
647+ Value accumulator;
648+ auto zero = b.create <arith::ConstantIndexOp>(0 );
649+ auto one = b.create <arith::ConstantIndexOp>(1 );
650+ auto sz = b.create <arith::ConstantIndexOp>(numCoords);
651+ SmallVector<Value> sizes{one, sz};
652+ SmallVector<Value> strides{one, one};
653+ for (size_t i = 0 ; i < numScalarMuls; ++i) {
654+ auto idx = b.create <arith::ConstantIndexOp>(i);
655+
656+ // scalar
657+ auto scalar = b.create <tensor::ExtractOp>(scalars, ValueRange{idx});
658+
659+ // point
660+ // - extract point = tensor<2x3x!PF> -> tensor<1x3x!PF>
661+ SmallVector<Value> offsets{idx, zero};
662+ auto higherRankedPoint = b.create <tensor::ExtractSliceOp>(
663+ loweredPointSet, offsets, sizes, strides);
664+ // - reshape point = tensor<1x3x!PF> -> tensor<3x!PF>
665+ SmallVector<Value> _outputShape{sz};
666+ auto outputShape = b.create <tensor::FromElementsOp>(_outputShape);
667+ auto point = b.create <tensor::ReshapeOp>(loweredOutputPointType,
668+ higherRankedPoint, outputShape);
669+
670+ Value adder = convertScalarMulImpl (point, scalar, inputPointType,
671+ outputPointType, b);
672+ if (i != 0 ) {
673+ accumulator = convertAddImpl (accumulator, adder, outputPointType,
674+ outputPointType, outputPointType, b);
675+ } else {
676+ accumulator = adder;
677+ }
678+ }
679+ rewriter.replaceOp (op, accumulator);
680+ return success ();
681+ }
682+ };
683+
616684namespace rewrites {
617685// In an inner namespace to avoid conflicts with canonicalization patterns
618686#include " zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp.inc"
@@ -636,10 +704,11 @@ void EllipticCurveToField::runOnOperation() {
636704
637705 RewritePatternSet patterns (context);
638706 rewrites::populateWithGenerated (patterns);
639- patterns.add <ConvertPoint, ConvertPointSet, ConvertPointSetExtract,
640- ConvertExtract, ConvertConvertPointType, ConvertAdd,
641- ConvertDouble, ConvertNegate, ConvertSub, ConvertScalarMul>(
642- typeConverter, context);
707+ patterns
708+ .add <ConvertPoint, ConvertPointSet, ConvertPointSetExtract,
709+ ConvertExtract, ConvertConvertPointType, ConvertAdd, ConvertDouble,
710+ ConvertNegate, ConvertSub, ConvertScalarMul, ConvertMSM>(
711+ typeConverter, context);
643712
644713 addStructuralConversionPatterns (typeConverter, patterns, target);
645714
0 commit comments