Skip to content

Commit ebc929f

Browse files
authored
[Stablehlo] Add broadcasting support for Remainder op in torch mlir -> stable hlo (#4039)
1 parent a265d28 commit ebc929f

File tree

1 file changed

+1
-19
lines changed

1 file changed

+1
-19
lines changed

lib/Conversion/TorchToStablehlo/Basic.cpp

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1966,24 +1966,6 @@ LogicalResult ConvertAtenOp<AtenFlipOp>::matchAndRewrite(
19661966
return success();
19671967
}
19681968

1969-
// AtenRemainderTensorOp
1970-
template <>
1971-
LogicalResult ConvertAtenOp<AtenRemainderTensorOp>::matchAndRewrite(
1972-
AtenRemainderTensorOp op, OpAdaptor adaptor,
1973-
ConversionPatternRewriter &rewriter) const {
1974-
Value lhs = adaptor.getSelf();
1975-
Value rhs = adaptor.getOther();
1976-
1977-
auto resultType =
1978-
cast<RankedTensorType>(getTypeConverter()->convertType(op.getType()));
1979-
lhs = hlo::promoteType(rewriter, op->getLoc(), lhs,
1980-
resultType.getElementType());
1981-
rhs = hlo::promoteType(rewriter, op->getLoc(), rhs,
1982-
resultType.getElementType());
1983-
rewriter.replaceOpWithNewOp<stablehlo::RemOp>(op, lhs, rhs);
1984-
return success();
1985-
}
1986-
19871969
// AtenFmodTensorOp
19881970
// torch.fmod(a, b) == a - a.div(b, rounding_mode="trunc") * b
19891971
template <>
@@ -2231,6 +2213,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
22312213
INSERT_BINARY_MULDIV_PATTERN(AtenDivTensorModeOp, chlo::BroadcastDivOp);
22322214
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarOp, chlo::BroadcastDivOp);
22332215
INSERT_BINARY_MULDIV_PATTERN(AtenDivScalarModeOp, chlo::BroadcastDivOp);
2216+
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderTensorOp, chlo::BroadcastRemOp);
22342217
INSERT_BINARY_MULDIV_PATTERN(AtenRemainderScalarOp, chlo::BroadcastRemOp);
22352218
#undef INSERT_BINARY_MULDIV_PATTERN
22362219

@@ -2310,7 +2293,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality(
23102293
INSERT_ATENOP_PATTERN(AtenEmptyMemoryFormatOp);
23112294
INSERT_ATENOP_PATTERN(AtenFillScalarOp);
23122295
INSERT_ATENOP_PATTERN(AtenFlipOp);
2313-
INSERT_ATENOP_PATTERN(AtenRemainderTensorOp);
23142296
INSERT_ATENOP_PATTERN(AtenFmodTensorOp);
23152297
INSERT_ATENOP_PATTERN(AtenBitwiseLeftShiftTensorOp);
23162298
INSERT_ATENOP_PATTERN(AtenBitwiseRightShiftTensorOp);

0 commit comments

Comments
 (0)