Skip to content

Commit ad6e390

Browse files
ashjeongchokobole
authored andcommitted
feat(elliptic_curve): lower MSMOp
1 parent c80232a commit ad6e390

File tree

2 files changed

+84
-4
lines changed

2 files changed

+84
-4
lines changed

tests/Dialect/EllipticCurve/elliptic_curve_to_field.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,3 +192,14 @@ func.func @test_point_set() {
192192
%doubled = elliptic_curve.double %point1 : !affine -> !jacobian
193193
return
194194
}
195+
196+
func.func @test_msm() {
197+
%var1 = field.pf.constant 1 : !PF
198+
%var5 = field.pf.constant 5 : !PF
199+
200+
%scalars = tensor.from_elements %var1, %var5, %var5 : tensor<3x!PF>
201+
%affine1 = elliptic_curve.point %var1, %var5 : !PF -> !affine
202+
%points = elliptic_curve.point_set.from_elements %affine1, %affine1, %affine1 : tensor<3x!affine>
203+
%msm_result = elliptic_curve.msm %scalars, %points : tensor<3x!PF>, tensor<3x!affine> -> !jacobian
204+
return
205+
}

zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp

Lines changed: 73 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -613,6 +613,74 @@ struct ConvertScalarMul : public OpConversionPattern<ScalarMulOp> {
613613
}
614614
};
615615

616+
struct ConvertMSM : public OpConversionPattern<MSMOp> {
617+
explicit ConvertMSM(mlir::MLIRContext *context)
618+
: OpConversionPattern<MSMOp>(context) {}
619+
620+
using OpConversionPattern::OpConversionPattern;
621+
622+
LogicalResult matchAndRewrite(
623+
MSMOp op, OpAdaptor adaptor,
624+
ConversionPatternRewriter &rewriter) const override {
625+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
626+
627+
// 2D tensor of prime field values, e.g.:
628+
// |x1, y1, z1|
629+
// |x2, y2, z2|
630+
Value loweredPointSet = adaptor.getPoints();
631+
// 1d tensor of PF, e.g.:
632+
// | s1 , s2 |
633+
Value scalars = op.getScalars();
634+
RankedTensorType loweredPointSetType =
635+
cast<RankedTensorType>(loweredPointSet.getType());
636+
field::PrimeFieldType baseFieldType =
637+
cast<field::PrimeFieldType>(loweredPointSetType.getElementType());
638+
unsigned numScalarMuls = loweredPointSetType.getShape()[0];
639+
unsigned numCoords = loweredPointSetType.getShape()[1];
640+
641+
Type inputPointType =
642+
cast<RankedTensorType>(op.getPoints().getType()).getElementType();
643+
Type outputPointType = op.getOutput().getType();
644+
RankedTensorType loweredOutputPointType =
645+
RankedTensorType::get({numCoords}, baseFieldType);
646+
647+
Value accumulator;
648+
auto zero = b.create<arith::ConstantIndexOp>(0);
649+
auto one = b.create<arith::ConstantIndexOp>(1);
650+
auto sz = b.create<arith::ConstantIndexOp>(numCoords);
651+
SmallVector<Value> sizes{one, sz};
652+
SmallVector<Value> strides{one, one};
653+
for (size_t i = 0; i < numScalarMuls; ++i) {
654+
auto idx = b.create<arith::ConstantIndexOp>(i);
655+
656+
// scalar
657+
auto scalar = b.create<tensor::ExtractOp>(scalars, ValueRange{idx});
658+
659+
// point
660+
// - extract point = tensor<2x3x!PF> -> tensor<1x3x!PF>
661+
SmallVector<Value> offsets{idx, zero};
662+
auto higherRankedPoint = b.create<tensor::ExtractSliceOp>(
663+
loweredPointSet, offsets, sizes, strides);
664+
// - reshape point = tensor<1x3x!PF> -> tensor<3x!PF>
665+
SmallVector<Value> _outputShape{sz};
666+
auto outputShape = b.create<tensor::FromElementsOp>(_outputShape);
667+
auto point = b.create<tensor::ReshapeOp>(loweredOutputPointType,
668+
higherRankedPoint, outputShape);
669+
670+
Value adder = convertScalarMulImpl(point, scalar, inputPointType,
671+
outputPointType, b);
672+
if (i != 0) {
673+
accumulator = convertAddImpl(accumulator, adder, outputPointType,
674+
outputPointType, outputPointType, b);
675+
} else {
676+
accumulator = adder;
677+
}
678+
}
679+
rewriter.replaceOp(op, accumulator);
680+
return success();
681+
}
682+
};
683+
616684
namespace rewrites {
617685
// In an inner namespace to avoid conflicts with canonicalization patterns
618686
#include "zkir/Dialect/EllipticCurve/Conversions/EllipticCurveToField/EllipticCurveToField.cpp.inc"
@@ -636,10 +704,11 @@ void EllipticCurveToField::runOnOperation() {
636704

637705
RewritePatternSet patterns(context);
638706
rewrites::populateWithGenerated(patterns);
639-
patterns.add<ConvertPoint, ConvertPointSet, ConvertPointSetExtract,
640-
ConvertExtract, ConvertConvertPointType, ConvertAdd,
641-
ConvertDouble, ConvertNegate, ConvertSub, ConvertScalarMul>(
642-
typeConverter, context);
707+
patterns
708+
.add<ConvertPoint, ConvertPointSet, ConvertPointSetExtract,
709+
ConvertExtract, ConvertConvertPointType, ConvertAdd, ConvertDouble,
710+
ConvertNegate, ConvertSub, ConvertScalarMul, ConvertMSM>(
711+
typeConverter, context);
643712

644713
addStructuralConversionPatterns(typeConverter, patterns, target);
645714

0 commit comments

Comments
 (0)