@@ -738,16 +738,20 @@ struct ReduceMeanConverter final
738738 ConversionPatternRewriter &rewriter) const final ;
739739};
740740
741- struct ReduceSumConverter final
742- : public OpConversionPattern<migraphx::ReduceSumOp> {
743- using OpConversionPattern<migraphx::ReduceSumOp>::OpConversionPattern;
741+ namespace {
742+ template <typename MIGraphXOp, typename TosaOp>
743+ struct ReduceConverter final : public OpConversionPattern<MIGraphXOp> {
744+ using OpConversionPattern<MIGraphXOp>::OpConversionPattern;
745+ using OpAdaptor = typename OpConversionPattern<MIGraphXOp>::OpAdaptor;
744746
745747 LogicalResult
746- matchAndRewrite (migraphx::ReduceSumOp op, OpAdaptor adaptor,
748+ matchAndRewrite (MIGraphXOp op, OpAdaptor adaptor,
747749 ConversionPatternRewriter &rewriter) const final ;
748750};
749751} // namespace
750752
753+ } // namespace
754+
751755tosa::ConstOp ReduceMeanConverter::createNumElementsTosaConst (
752756 Location loc, TypedValue<RankedTensorType> inputTensor,
753757 IntegerAttr axisAttr, ConversionPatternRewriter &rewriter) const {
@@ -805,9 +809,10 @@ LogicalResult ReduceMeanConverter::matchAndRewrite(
805809 return success ();
806810}
807811
808- LogicalResult
809- ReduceSumConverter::matchAndRewrite (migraphx::ReduceSumOp op, OpAdaptor adaptor,
810- ConversionPatternRewriter &rewriter) const {
812+ template <typename MIGraphXOp, typename TosaOp>
813+ LogicalResult ReduceConverter<MIGraphXOp, TosaOp>::matchAndRewrite(
814+ MIGraphXOp op, OpAdaptor adaptor,
815+ ConversionPatternRewriter &rewriter) const {
811816 Location loc = op.getLoc ();
812817 ArrayRef<Attribute> axes = op.getAxes ().getValue ();
813818 if (axes.size () != 1 ) {
@@ -817,9 +822,9 @@ ReduceSumConverter::matchAndRewrite(migraphx::ReduceSumOp op, OpAdaptor adaptor,
817822 rewriter.getI32IntegerAttr (cast<IntegerAttr>(axes[0 ]).getInt ());
818823 auto input = cast<TypedValue<RankedTensorType>>(adaptor.getInput ());
819824 Type elementType = input.getType ().getElementType ();
820- auto tosaReduceSum = createOpAndInfer<tosa::ReduceSumOp>(
821- rewriter, loc, elementType, input, axis);
822- rewriter.replaceOp (op, tosaReduceSum );
825+ auto tosaReduce =
826+ createOpAndInfer<TosaOp>( rewriter, loc, elementType, input, axis);
827+ rewriter.replaceOp (op, tosaReduce );
823828 return success ();
824829}
825830
@@ -1493,7 +1498,9 @@ void migraphx::populateMIGraphXToTosaConversionPatterns(
14931498 DotConverter<DotOp>, DotConverter<QuantDotOp>,
14941499 BroadcastConverter, MultiBroadcastConverter, TransposeConverter,
14951500 ReshapeConverter, SliceConverter, ReduceMeanConverter,
1496- ReduceSumConverter, TrivialConverter<AddOp, tosa::AddOp>,
1501+ ReduceConverter<ReduceSumOp, tosa::ReduceSumOp>,
1502+ ReduceConverter<ReduceMaxOp, tosa::ReduceMaxOp>,
1503+ TrivialConverter<AddOp, tosa::AddOp>,
14971504 TrivialConverter<SubOp, tosa::SubOp>,
14981505 TrivialConverter<PowOp, tosa::PowOp>, DivConverter, MulConverter,
14991506 TrivialConverter<AbsOp, tosa::AbsOp>,
0 commit comments