@@ -96,7 +96,7 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy,
     }
   }
 
-  if (isa<AtenProdOp>(op)) {
+  if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
     if (isa<mlir::FloatType>(elementTy)) {
       APFloat one(cast<mlir::FloatType>(elementTy).getFloatSemantics(), 1);
       auto constAttr = DenseElementsAttr::get(constType, one);
@@ -172,7 +172,7 @@ static Value createReduceOpWithSingleRegionOp(Operation *op, Value input,
   } else if (isa<AtenAnyOp, AtenAnyDimOp>(op)) {
     result = rewriter.create<stablehlo::OrOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
-  } else if (isa<AtenProdOp>(op)) {
+  } else if (isa<AtenProdOp, AtenProdDimIntOp>(op)) {
     result = rewriter.create<stablehlo::MulOp>(
         op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument);
   } else {
@@ -689,6 +689,69 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
 }
 } // namespace
 
+// AtenProdDimIntOp
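+// aten.prod.dim_int => stablehlo.reduce(multiply along the given dim)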
+namespace {
+template <>
+LogicalResult ConvertAtenReductionOp<AtenProdDimIntOp>::matchAndRewrite(
+    AtenProdDimIntOp op, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+  Value input = adaptor.getSelf();
+  auto inputTy = dyn_cast<RankedTensorType>(input.getType());
+  auto outTy =
+      dyn_cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
+  if (!inputTy) {
+    return rewriter.notifyMatchFailure(
+        op, "only Tensor types supported in StableHLO");
+  }
+  if (inputTy.getElementType() != outTy.getElementType()) {
+    // Use output element type as computation type.
+    auto dstElemTy = outTy.getElementType();
+    input =
+        rewriter.create<stablehlo::ConvertOp>(op->getLoc(), input, dstElemTy);
+    inputTy = dyn_cast<RankedTensorType>(input.getType());
+  }
+  auto inputElemTy = inputTy.getElementType();
+  if (!inputElemTy.isIntOrFloat()) {
+    return op.emitError(
+        "Only floating-point or integer datatype legalization supported");
+  }
+
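+  // Both `dim` and `keepdim` must be compile-time constants for this pattern.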
+  int64_t dim;
+  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) {
+    return rewriter.notifyMatchFailure(
+        op, "non-const integer `dim` is not supported");
+  }
+  dim = toPositiveDim(dim, inputTy.getRank());
+  SmallVector<int64_t> reduceResultShape =
+      getReduceOutputShape(inputTy.getShape(), {dim});
+
+  bool keepDim = false;
+  if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) {
+    return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported");
+  }
+
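+  // Reduce over `dim`: the reduce region multiplies elements and the initial
+  // value is 1 (see the createInitialValueForReduceOp change above).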
+  Value reduceResult = createReduceOpWithSingleRegionOp(
+      op, input,
+      RankedTensorType::get(reduceResultShape, outTy.getElementType()), dim,
+      rewriter);
+  if (!reduceResult) {
+    return op->emitError("createReduceOpWithSingleRegionOp return nullptr");
+  }
+
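+  // For keepdim=true, reshape the result back to the input rank with a
+  // size-1 dimension at `dim`.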
+  if (keepDim) {
+    auto outShapeInfo = hlo::getDimIndexOfTensor(rewriter, op, input);
+    if (failed(outShapeInfo)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to get dimension sizes of the input");
+    }
+    reduceResult = reshapeReduceResultWhenKeepDim(
+        rewriter, op->getLoc(), reduceResult, *outShapeInfo, outTy, dim);
+  }
+  rewriter.replaceOp(op, reduceResult);
+  return success();
+}
+} // namespace
+
 // AtenFrobeniusNormDimOp
 // aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given
 // dims) + stablehlo.sqrt
@@ -868,6 +931,7 @@ void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality(
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenSumDimIntListOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenFrobeniusNormDimOp);
   INSERT_ATEN_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp);
+  INSERT_ATEN_REDUCTION_OP_PATTERN(AtenProdDimIntOp);
 #undef INSERT_ATEN_REDUCTION_OP_PATTERN
 
 #define INSERT_ATEN_REDUCTION_ALL_DIMS_OP_PATTERN(AtenOp)                      \