@@ -11976,6 +11976,54 @@ class DecomposeAten_AssertScalarOp
11976
11976
};
11977
11977
} // namespace
11978
11978
11979
+ namespace {
11980
+ class DecomposeAtenRoundDecimalsOp
11981
+ : public OpRewritePattern<AtenRoundDecimalsOp> {
11982
+ public:
11983
+ using OpRewritePattern<AtenRoundDecimalsOp>::OpRewritePattern;
11984
+ LogicalResult matchAndRewrite (AtenRoundDecimalsOp op,
11985
+ PatternRewriter &rewriter) const override {
11986
+ // AtenRoundDecimalsOp is decomposed as follows if the decimals value is
11987
+ // non-zero: scale = 10 ** decimals return round(x * scale) / scale
11988
+ // otherwise:
11989
+ // return round(x)
11990
+
11991
+ auto loc = op.getLoc ();
11992
+ auto input = op.getSelf ();
11993
+ auto inputType = cast<BaseTensorType>(input.getType ());
11994
+
11995
+ if (!inputType.hasDtype () || !isa<mlir::FloatType>(inputType.getDtype ())) {
11996
+ return rewriter.notifyMatchFailure (
11997
+ op, " unimplemented: non-floating point dtype" );
11998
+ }
11999
+
12000
+ int64_t decimals;
12001
+ if (!matchPattern (op.getDecimals (), m_TorchConstantInt (&decimals))) {
12002
+ return rewriter.notifyMatchFailure (
12003
+ op, " non-constant decimal point is not supported." );
12004
+ }
12005
+
12006
+ Value newOp = op->getOperand (0 );
12007
+ Value scale;
12008
+ if (decimals) {
12009
+ auto scaleVal = pow (10 , decimals);
12010
+ scale = rewriter.create <ConstantFloatOp>(
12011
+ loc, rewriter.getF64FloatAttr (scaleVal));
12012
+ newOp = rewriter.create <AtenMulScalarOp>(loc, op.getType (), input, scale);
12013
+ }
12014
+
12015
+ newOp = rewriter.create <AtenRoundOp>(loc, op.getType (), newOp);
12016
+
12017
+ if (decimals) {
12018
+ newOp = rewriter.create <AtenDivScalarOp>(loc, op.getType (), newOp, scale);
12019
+ }
12020
+
12021
+ rewriter.replaceOp (op, newOp);
12022
+ return success ();
12023
+ }
12024
+ };
12025
+ } // namespace
12026
+
11979
12027
namespace {
11980
12028
class DecomposeComplexOpsPass
11981
12029
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -12291,6 +12339,7 @@ class DecomposeComplexOpsPass
12291
12339
addPatternIfTargetOpIsIllegal<DecomposeAtenConstrainRangeForSizeOp>(
12292
12340
patterns);
12293
12341
addPatternIfTargetOpIsIllegal<DecomposeAten_AssertScalarOp>(patterns);
12342
+ addPatternIfTargetOpIsIllegal<DecomposeAtenRoundDecimalsOp>(patterns);
12294
12343
12295
12344
GreedyRewriteConfig config;
12296
12345
config.setUseTopDownTraversal (true );
0 commit comments