@@ -8056,105 +8056,17 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
8056
8056
} // namespace
8057
8057
8058
8058
namespace {
8059
+ // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d`.
8059
8060
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
8060
- // op.
8061
- class DecomposeAtenAdaptiveMaxPool1dOp
8062
- : public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
8063
- using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
8064
- LogicalResult matchAndRewrite (AtenAdaptiveMaxPool1dOp op,
8065
- PatternRewriter &rewriter) const override {
8066
- Location loc = op->getLoc ();
8067
- MLIRContext *context = op.getContext ();
8068
-
8069
- Value input = op.getSelf ();
8070
- std::optional<unsigned > maybeRank = getTensorRank (input);
8071
- if (!maybeRank) {
8072
- return rewriter.notifyMatchFailure (op, " expected input to have a rank" );
8073
- }
8074
- unsigned rank = *maybeRank;
8075
- Value sizeDim = rewriter.create <Torch::ConstantIntOp>(
8076
- loc, rewriter.getI64IntegerAttr (rank - 1 ));
8077
- Value inputSize = rewriter.create <AtenSizeIntOp>(loc, input, sizeDim);
8078
-
8079
- Value outputShape = op.getOutputSize ();
8080
- SmallVector<Value> outputShapeSizesTorchInt;
8081
- getListConstructElements (outputShape, outputShapeSizesTorchInt);
8082
- Value outputSize = outputShapeSizesTorchInt[0 ];
8083
-
8084
- Value constantOne = rewriter.create <Torch::ConstantIntOp>(
8085
- loc, rewriter.getI64IntegerAttr (1 ));
8086
- Value constantZero = rewriter.create <Torch::ConstantIntOp>(
8087
- loc, rewriter.getI64IntegerAttr (0 ));
8088
- Value constantFalse = rewriter.create <Torch::ConstantBoolOp>(loc, false );
8089
-
8090
- int64_t outputSizeInt;
8091
- if (!matchPattern (outputSize, m_TorchConstantInt (&outputSizeInt))) {
8092
- return rewriter.notifyMatchFailure (
8093
- op, " the output size of adaptive_max_pool1d must be a constant int" );
8094
- }
8095
-
8096
- SmallVector<Value, 1 > kernelSize;
8097
- if (outputSizeInt == 1 ) {
8098
- BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType ());
8099
- ArrayRef<int64_t > inputShape = inputTensorType.getSizes ();
8100
- kernelSize.push_back (
8101
- inputShape[rank - 1 ] == kUnknownSize
8102
- ? inputSize
8103
- : rewriter.create <Torch::ConstantIntOp>(
8104
- loc, rewriter.getI64IntegerAttr (inputShape[rank - 1 ])));
8105
- } else {
8106
- if (!isAssumingStrictSymbolicShapes (rewriter)) {
8107
- Value cond = rewriter.create <AtenEqIntOp>(loc, inputSize, outputSize);
8108
- rewriter.create <RuntimeAssertOp>(
8109
- loc, cond,
8110
- " unimplemented: only support cases where input and output size are "
8111
- " equal for non-unit output size" );
8112
- }
8113
- kernelSize.push_back (constantOne);
8114
- }
8115
-
8116
- Value kernelSizeList = rewriter.create <PrimListConstructOp>(
8117
- loc, Torch::ListType::get (Torch::IntType::get (context)), kernelSize);
8118
- Value strideList = rewriter.create <PrimListConstructOp>(
8119
- loc, Torch::ListType::get (Torch::IntType::get (context)),
8120
- ValueRange{constantOne});
8121
- Value paddingSizeList = rewriter.create <PrimListConstructOp>(
8122
- loc, Torch::ListType::get (Torch::IntType::get (context)),
8123
- ValueRange{constantZero});
8124
- Value dialationList = rewriter.create <PrimListConstructOp>(
8125
- loc, Torch::ListType::get (Torch::IntType::get (context)),
8126
- ValueRange{constantOne});
8127
-
8128
- if (op.getResult (1 ).use_empty ()) {
8129
- auto maxPool = rewriter.create <AtenMaxPool1dOp>(
8130
- loc, op.getType (0 ), input, kernelSizeList, strideList,
8131
- paddingSizeList, dialationList,
8132
- /* ceil_mode=*/ constantFalse);
8133
- rewriter.replaceOp (op, {maxPool.getResult (), Value ()});
8134
- } else {
8135
- auto maxPool = rewriter.create <AtenMaxPool1dWithIndicesOp>(
8136
- loc, op.getType (0 ), op.getType (1 ), input, kernelSizeList, strideList,
8137
- paddingSizeList, dialationList,
8138
- /* ceil_mode=*/ constantFalse);
8139
- rewriter.replaceOp (op, maxPool.getResults ());
8140
- }
8141
- return success ();
8142
- }
8143
- };
8144
- } // namespace
8145
-
8146
- namespace {
8147
- // Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
8148
-
8149
- // The logic of this decomposition is totally same with
8150
- // the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two
8151
- // cases are supported:
8061
+ // or `aten.max_pool1d`.
8062
+ //
8063
+ // Only following two cases are supported:
8152
8064
// 1. inputSize = outputSize
8153
8065
// 2. outputSize = 1
8154
- class DecomposeAtenAdaptiveAvgPool1dOp
8155
- : public OpRewritePattern<AtenAdaptiveAvgPool1dOp > {
8156
- using OpRewritePattern<AtenAdaptiveAvgPool1dOp >::OpRewritePattern;
8157
- LogicalResult matchAndRewrite (AtenAdaptiveAvgPool1dOp op,
8066
+ template < typename AtenOpT>
8067
+ class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern <AtenOpT > {
8068
+ using OpRewritePattern<AtenOpT >::OpRewritePattern;
8069
+ LogicalResult matchAndRewrite (AtenOpT op,
8158
8070
PatternRewriter &rewriter) const override {
8159
8071
Location loc = op->getLoc ();
8160
8072
MLIRContext *context = op.getContext ();
@@ -8167,11 +8079,10 @@ class DecomposeAtenAdaptiveAvgPool1dOp
8167
8079
unsigned rank = *maybeRank;
8168
8080
Value sizeDim = rewriter.create <Torch::ConstantIntOp>(
8169
8081
loc, rewriter.getI64IntegerAttr (rank - 1 ));
8170
- Value inputSize = rewriter.create <AtenSizeIntOp>(loc, input, sizeDim);
8082
+ Value inputSize = rewriter.createOrFold <AtenSizeIntOp>(loc, input, sizeDim);
8171
8083
8172
- Value outputShape = op.getOutputSize ();
8173
8084
SmallVector<Value> outputShapeSizesTorchInt;
8174
- getListConstructElements (outputShape , outputShapeSizesTorchInt);
8085
+ getListConstructElements (op. getOutputSize () , outputShapeSizesTorchInt);
8175
8086
Value outputSize = outputShapeSizesTorchInt[0 ];
8176
8087
8177
8088
Value constantOne = rewriter.create <Torch::ConstantIntOp>(
@@ -8184,18 +8095,12 @@ class DecomposeAtenAdaptiveAvgPool1dOp
8184
8095
int64_t outputSizeInt;
8185
8096
if (!matchPattern (outputSize, m_TorchConstantInt (&outputSizeInt))) {
8186
8097
return rewriter.notifyMatchFailure (
8187
- op, " the output size of adaptive_pool_1d must be a constant int" );
8098
+ op, " the output size of adaptive pool1d must be a constant int" );
8188
8099
}
8189
8100
8190
8101
SmallVector<Value, 1 > kernelSize;
8191
8102
if (outputSizeInt == 1 ) {
8192
- BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType ());
8193
- ArrayRef<int64_t > inputShape = inputTensorType.getSizes ();
8194
- kernelSize.push_back (
8195
- inputShape[rank - 1 ] == kUnknownSize
8196
- ? inputSize
8197
- : rewriter.create <Torch::ConstantIntOp>(
8198
- loc, rewriter.getI64IntegerAttr (inputShape[rank - 1 ])));
8103
+ kernelSize.push_back (inputSize);
8199
8104
} else {
8200
8105
if (!isAssumingStrictSymbolicShapes (rewriter)) {
8201
8106
Value cond = rewriter.create <AtenEqIntOp>(loc, inputSize, outputSize);
@@ -8216,16 +8121,40 @@ class DecomposeAtenAdaptiveAvgPool1dOp
8216
8121
loc, Torch::ListType::get (Torch::IntType::get (context)),
8217
8122
ValueRange{constantZero});
8218
8123
8219
- rewriter.replaceOpWithNewOp <AtenAvgPool1dOp>(
8220
- op, op.getType (), input, kernelSizeList, strideList, paddingSizeList,
8221
- /* ceil_mode=*/ constantFalse, /* count_include_pad=*/ constantTrue);
8222
- return success ();
8124
+ if constexpr (std::is_same_v<AtenAdaptiveAvgPool1dOp, AtenOpT>) {
8125
+ rewriter.replaceOpWithNewOp <AtenAvgPool1dOp>(
8126
+ op, op.getType (), input, kernelSizeList, strideList, paddingSizeList,
8127
+ /* ceil_mode=*/ constantFalse, /* count_include_pad=*/ constantTrue);
8128
+ return success ();
8129
+ } else if constexpr (std::is_same_v<AtenAdaptiveMaxPool1dOp, AtenOpT>) {
8130
+ Value dilationList = rewriter.create <PrimListConstructOp>(
8131
+ loc, Torch::ListType::get (Torch::IntType::get (context)),
8132
+ ValueRange{constantOne});
8133
+ if (op.getResult (1 ).use_empty ()) {
8134
+ auto maxPool = rewriter.create <AtenMaxPool1dOp>(
8135
+ loc, op.getType (0 ), input, kernelSizeList, strideList,
8136
+ paddingSizeList, dilationList,
8137
+ /* ceil_mode=*/ constantFalse);
8138
+ rewriter.replaceOp (op, {maxPool.getResult (), Value ()});
8139
+ } else {
8140
+ auto maxPool = rewriter.create <AtenMaxPool1dWithIndicesOp>(
8141
+ loc, op.getType (0 ), op.getType (1 ), input, kernelSizeList,
8142
+ strideList, paddingSizeList, dilationList,
8143
+ /* ceil_mode=*/ constantFalse);
8144
+ rewriter.replaceOp (op, maxPool.getResults ());
8145
+ }
8146
+ return success ();
8147
+ }
8148
+ return rewriter.notifyMatchFailure (
8149
+ op, " unimplemented: unsupported template op" );
8223
8150
}
8224
8151
};
8225
8152
} // namespace
8226
8153
8227
8154
namespace {
8228
- // Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
8155
+ // Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
8156
+ // Decompose `aten.adaptive_max_pool2d` op into `aten.max_pool2d` or
8157
+ // `aten.max_pool2d_with_indices` op.
8229
8158
//
8230
8159
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
8231
8160
// output size the kernelSize, stride and padding is calculated as follows:
@@ -8235,10 +8164,10 @@ namespace {
8235
8164
// kernelW = inW - [(outW - 1) * strideW] = strideW
8236
8165
// paddingH = 0, paddingW = 0
8237
8166
//
8238
- class DecomposeAtenAdaptiveAvgPool2dOp
8239
- : public OpRewritePattern<AtenAdaptiveAvgPool2dOp > {
8240
- using OpRewritePattern::OpRewritePattern;
8241
- LogicalResult matchAndRewrite (AtenAdaptiveAvgPool2dOp op,
8167
+ template < typename AtenOpT>
8168
+ class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern <AtenOpT > {
8169
+ using OpRewritePattern<AtenOpT> ::OpRewritePattern;
8170
+ LogicalResult matchAndRewrite (AtenOpT op,
8242
8171
PatternRewriter &rewriter) const override {
8243
8172
8244
8173
Location loc = op.getLoc ();
@@ -8254,15 +8183,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
8254
8183
Value dimH = rewriter.create <Torch::ConstantIntOp>(
8255
8184
loc, rewriter.getI64IntegerAttr (rank - 2 ));
8256
8185
inputHW.push_back (
8257
- /* inH=*/ rewriter.create <AtenSizeIntOp>(loc, input, dimH));
8186
+ /* inH=*/ rewriter.createOrFold <AtenSizeIntOp>(loc, input, dimH));
8258
8187
Value dimW = rewriter.create <Torch::ConstantIntOp>(
8259
8188
loc, rewriter.getI64IntegerAttr (rank - 1 ));
8260
8189
inputHW.push_back (
8261
- /* inW=*/ rewriter.create <AtenSizeIntOp>(loc, input, dimW));
8190
+ /* inW=*/ rewriter.createOrFold <AtenSizeIntOp>(loc, input, dimW));
8262
8191
8263
- Value outputShape = op.getOutputSize ();
8264
8192
SmallVector<Value> outputShapeSizesTorchInt;
8265
- getListConstructElements (outputShape , outputShapeSizesTorchInt);
8193
+ getListConstructElements (op. getOutputSize () , outputShapeSizesTorchInt);
8266
8194
8267
8195
// TODO: Add support for cases other than:
8268
8196
// inH % outH != 0 or inW % outW != 0 where
@@ -8343,11 +8271,32 @@ class DecomposeAtenAdaptiveAvgPool2dOp
8343
8271
loc, Torch::ListType::get (Torch::IntType::get (context)),
8344
8272
ValueRange{constantZero, constantZero});
8345
8273
8346
- rewriter.replaceOpWithNewOp <AtenAvgPool2dOp>(
8347
- op, op.getType (), input, kernelSizeList, strideList, paddingSizeList,
8348
- /* ceilMode=*/ constantFalse, /* countIncludePad=*/ constantTrue,
8349
- /* divisorOverride=*/ constantNone);
8350
- return success ();
8274
+ if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool2dOp>) {
8275
+ rewriter.replaceOpWithNewOp <AtenAvgPool2dOp>(
8276
+ op, op.getType (), input, kernelSizeList, strideList, paddingSizeList,
8277
+ /* ceilMode=*/ constantFalse, /* countIncludePad=*/ constantTrue,
8278
+ /* divisorOverride=*/ constantNone);
8279
+ return success ();
8280
+ } else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool2dOp>) {
8281
+ Value dilationList = rewriter.create <PrimListConstructOp>(
8282
+ loc, Torch::ListType::get (Torch::IntType::get (context)),
8283
+ ValueRange{constantOne, constantOne});
8284
+ if (op.getResult (1 ).use_empty ()) {
8285
+ auto maxPool = rewriter.create <AtenMaxPool2dOp>(
8286
+ loc, op.getType (0 ), input, kernelSizeList, strideList,
8287
+ paddingSizeList, dilationList, /* ceil_mode=*/ constantFalse);
8288
+ rewriter.replaceOp (op, {maxPool.getResult (), Value ()});
8289
+ } else {
8290
+ auto maxPool = rewriter.create <AtenMaxPool2dWithIndicesOp>(
8291
+ loc, op.getType (0 ), op.getType (1 ), input, kernelSizeList,
8292
+ strideList, paddingSizeList, dilationList,
8293
+ /* ceil_mode=*/ constantFalse);
8294
+ rewriter.replaceOp (op, maxPool.getResults ());
8295
+ }
8296
+ return success ();
8297
+ }
8298
+ return rewriter.notifyMatchFailure (
8299
+ op, " unimplemented: unsupported template op" );
8351
8300
}
8352
8301
};
8353
8302
} // namespace
@@ -11778,9 +11727,14 @@ class DecomposeComplexOpsPass
11778
11727
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
11779
11728
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
11780
11729
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
11781
- addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
11782
- addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
11783
- addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
11730
+ addPatternIfTargetOpIsIllegal<
11731
+ DecomposeAtenAdaptivePool1dOp<AtenAdaptiveMaxPool1dOp>>(patterns);
11732
+ addPatternIfTargetOpIsIllegal<
11733
+ DecomposeAtenAdaptivePool1dOp<AtenAdaptiveAvgPool1dOp>>(patterns);
11734
+ addPatternIfTargetOpIsIllegal<
11735
+ DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
11736
+ addPatternIfTargetOpIsIllegal<
11737
+ DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
11784
11738
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
11785
11739
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
11786
11740
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);
0 commit comments