Skip to content

Commit b6c4e87

Browse files
authored
support decomposition of aten.adaptive_max_pool2d (#3954)
* also move `backend_legal_ops` into `fx_importer_backend` and `jit_importer_backend`, so that `onnx/fx/jit` importer tests maintain their own `backend_legal_ops` config.
* TODO: unify different `backend_legal_ops` into one.
1 parent e4a2f86 commit b6c4e87

File tree

8 files changed

+185
-168
lines changed

8 files changed

+185
-168
lines changed

lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp

Lines changed: 82 additions & 128 deletions
Original file line numberDiff line numberDiff line change
@@ -8056,105 +8056,17 @@ class DecomposeAtenToDeviceOp : public OpRewritePattern<AtenToDeviceOp> {
80568056
} // namespace
80578057

80588058
namespace {
8059+
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d`.
80598060
// Decompose `aten.adaptive_max_pool1d` op into `aten.max_pool1d_with_indices`
8060-
// op.
8061-
class DecomposeAtenAdaptiveMaxPool1dOp
8062-
: public OpRewritePattern<AtenAdaptiveMaxPool1dOp> {
8063-
using OpRewritePattern<AtenAdaptiveMaxPool1dOp>::OpRewritePattern;
8064-
LogicalResult matchAndRewrite(AtenAdaptiveMaxPool1dOp op,
8065-
PatternRewriter &rewriter) const override {
8066-
Location loc = op->getLoc();
8067-
MLIRContext *context = op.getContext();
8068-
8069-
Value input = op.getSelf();
8070-
std::optional<unsigned> maybeRank = getTensorRank(input);
8071-
if (!maybeRank) {
8072-
return rewriter.notifyMatchFailure(op, "expected input to have a rank");
8073-
}
8074-
unsigned rank = *maybeRank;
8075-
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
8076-
loc, rewriter.getI64IntegerAttr(rank - 1));
8077-
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
8078-
8079-
Value outputShape = op.getOutputSize();
8080-
SmallVector<Value> outputShapeSizesTorchInt;
8081-
getListConstructElements(outputShape, outputShapeSizesTorchInt);
8082-
Value outputSize = outputShapeSizesTorchInt[0];
8083-
8084-
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
8085-
loc, rewriter.getI64IntegerAttr(1));
8086-
Value constantZero = rewriter.create<Torch::ConstantIntOp>(
8087-
loc, rewriter.getI64IntegerAttr(0));
8088-
Value constantFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
8089-
8090-
int64_t outputSizeInt;
8091-
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
8092-
return rewriter.notifyMatchFailure(
8093-
op, "the output size of adaptive_max_pool1d must be a constant int");
8094-
}
8095-
8096-
SmallVector<Value, 1> kernelSize;
8097-
if (outputSizeInt == 1) {
8098-
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
8099-
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
8100-
kernelSize.push_back(
8101-
inputShape[rank - 1] == kUnknownSize
8102-
? inputSize
8103-
: rewriter.create<Torch::ConstantIntOp>(
8104-
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
8105-
} else {
8106-
if (!isAssumingStrictSymbolicShapes(rewriter)) {
8107-
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
8108-
rewriter.create<RuntimeAssertOp>(
8109-
loc, cond,
8110-
"unimplemented: only support cases where input and output size are "
8111-
"equal for non-unit output size");
8112-
}
8113-
kernelSize.push_back(constantOne);
8114-
}
8115-
8116-
Value kernelSizeList = rewriter.create<PrimListConstructOp>(
8117-
loc, Torch::ListType::get(Torch::IntType::get(context)), kernelSize);
8118-
Value strideList = rewriter.create<PrimListConstructOp>(
8119-
loc, Torch::ListType::get(Torch::IntType::get(context)),
8120-
ValueRange{constantOne});
8121-
Value paddingSizeList = rewriter.create<PrimListConstructOp>(
8122-
loc, Torch::ListType::get(Torch::IntType::get(context)),
8123-
ValueRange{constantZero});
8124-
Value dialationList = rewriter.create<PrimListConstructOp>(
8125-
loc, Torch::ListType::get(Torch::IntType::get(context)),
8126-
ValueRange{constantOne});
8127-
8128-
if (op.getResult(1).use_empty()) {
8129-
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
8130-
loc, op.getType(0), input, kernelSizeList, strideList,
8131-
paddingSizeList, dialationList,
8132-
/*ceil_mode=*/constantFalse);
8133-
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
8134-
} else {
8135-
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
8136-
loc, op.getType(0), op.getType(1), input, kernelSizeList, strideList,
8137-
paddingSizeList, dialationList,
8138-
/*ceil_mode=*/constantFalse);
8139-
rewriter.replaceOp(op, maxPool.getResults());
8140-
}
8141-
return success();
8142-
}
8143-
};
8144-
} // namespace
8145-
8146-
namespace {
8147-
// Decompose `aten.adaptive_avg_pool1d` op into `aten.avg_pool1d` op.
8148-
8149-
// The logic of this decomposition is totally same with
8150-
// the DecomposeAtenAdaptiveAvgPool2dOp, that means currently only following two
8151-
// cases are supported:
8061+
// or `aten.max_pool1d`.
8062+
//
8063+
// Only following two cases are supported:
81528064
// 1. inputSize = outputSize
81538065
// 2. outputSize = 1
8154-
class DecomposeAtenAdaptiveAvgPool1dOp
8155-
: public OpRewritePattern<AtenAdaptiveAvgPool1dOp> {
8156-
using OpRewritePattern<AtenAdaptiveAvgPool1dOp>::OpRewritePattern;
8157-
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool1dOp op,
8066+
template <typename AtenOpT>
8067+
class DecomposeAtenAdaptivePool1dOp : public OpRewritePattern<AtenOpT> {
8068+
using OpRewritePattern<AtenOpT>::OpRewritePattern;
8069+
LogicalResult matchAndRewrite(AtenOpT op,
81588070
PatternRewriter &rewriter) const override {
81598071
Location loc = op->getLoc();
81608072
MLIRContext *context = op.getContext();
@@ -8167,11 +8079,10 @@ class DecomposeAtenAdaptiveAvgPool1dOp
81678079
unsigned rank = *maybeRank;
81688080
Value sizeDim = rewriter.create<Torch::ConstantIntOp>(
81698081
loc, rewriter.getI64IntegerAttr(rank - 1));
8170-
Value inputSize = rewriter.create<AtenSizeIntOp>(loc, input, sizeDim);
8082+
Value inputSize = rewriter.createOrFold<AtenSizeIntOp>(loc, input, sizeDim);
81718083

8172-
Value outputShape = op.getOutputSize();
81738084
SmallVector<Value> outputShapeSizesTorchInt;
8174-
getListConstructElements(outputShape, outputShapeSizesTorchInt);
8085+
getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
81758086
Value outputSize = outputShapeSizesTorchInt[0];
81768087

81778088
Value constantOne = rewriter.create<Torch::ConstantIntOp>(
@@ -8184,18 +8095,12 @@ class DecomposeAtenAdaptiveAvgPool1dOp
81848095
int64_t outputSizeInt;
81858096
if (!matchPattern(outputSize, m_TorchConstantInt(&outputSizeInt))) {
81868097
return rewriter.notifyMatchFailure(
8187-
op, "the output size of adaptive_pool_1d must be a constant int");
8098+
op, "the output size of adaptive pool1d must be a constant int");
81888099
}
81898100

81908101
SmallVector<Value, 1> kernelSize;
81918102
if (outputSizeInt == 1) {
8192-
BaseTensorType inputTensorType = cast<BaseTensorType>(input.getType());
8193-
ArrayRef<int64_t> inputShape = inputTensorType.getSizes();
8194-
kernelSize.push_back(
8195-
inputShape[rank - 1] == kUnknownSize
8196-
? inputSize
8197-
: rewriter.create<Torch::ConstantIntOp>(
8198-
loc, rewriter.getI64IntegerAttr(inputShape[rank - 1])));
8103+
kernelSize.push_back(inputSize);
81998104
} else {
82008105
if (!isAssumingStrictSymbolicShapes(rewriter)) {
82018106
Value cond = rewriter.create<AtenEqIntOp>(loc, inputSize, outputSize);
@@ -8216,16 +8121,40 @@ class DecomposeAtenAdaptiveAvgPool1dOp
82168121
loc, Torch::ListType::get(Torch::IntType::get(context)),
82178122
ValueRange{constantZero});
82188123

8219-
rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
8220-
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
8221-
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
8222-
return success();
8124+
if constexpr (std::is_same_v<AtenAdaptiveAvgPool1dOp, AtenOpT>) {
8125+
rewriter.replaceOpWithNewOp<AtenAvgPool1dOp>(
8126+
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
8127+
/*ceil_mode=*/constantFalse, /*count_include_pad=*/constantTrue);
8128+
return success();
8129+
} else if constexpr (std::is_same_v<AtenAdaptiveMaxPool1dOp, AtenOpT>) {
8130+
Value dilationList = rewriter.create<PrimListConstructOp>(
8131+
loc, Torch::ListType::get(Torch::IntType::get(context)),
8132+
ValueRange{constantOne});
8133+
if (op.getResult(1).use_empty()) {
8134+
auto maxPool = rewriter.create<AtenMaxPool1dOp>(
8135+
loc, op.getType(0), input, kernelSizeList, strideList,
8136+
paddingSizeList, dilationList,
8137+
/*ceil_mode=*/constantFalse);
8138+
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
8139+
} else {
8140+
auto maxPool = rewriter.create<AtenMaxPool1dWithIndicesOp>(
8141+
loc, op.getType(0), op.getType(1), input, kernelSizeList,
8142+
strideList, paddingSizeList, dilationList,
8143+
/*ceil_mode=*/constantFalse);
8144+
rewriter.replaceOp(op, maxPool.getResults());
8145+
}
8146+
return success();
8147+
}
8148+
return rewriter.notifyMatchFailure(
8149+
op, "unimplemented: unsupported template op");
82238150
}
82248151
};
82258152
} // namespace
82268153

82278154
namespace {
8228-
// Decompose `aten.adaptiveAvgPool2d` op into `aten.avgPool2d` op.
8155+
// Decompose `aten.adaptive_avg_pool2d` op into `aten.avg_pool2d` op.
8156+
// Decompose `aten.adaptive_max_pool2d` op into `aten.max_pool2d` or
8157+
// `aten.max_pool2d_with_indices` op.
82298158
//
82308159
// For AdaptiveAvgPool2d op, when the input size is an integer multiple of
82318160
// output size the kernelSize, stride and padding is calculated as follows:
@@ -8235,10 +8164,10 @@ namespace {
82358164
// kernelW = inW - [(outW - 1) * strideW] = strideW
82368165
// paddingH = 0, paddingW = 0
82378166
//
8238-
class DecomposeAtenAdaptiveAvgPool2dOp
8239-
: public OpRewritePattern<AtenAdaptiveAvgPool2dOp> {
8240-
using OpRewritePattern::OpRewritePattern;
8241-
LogicalResult matchAndRewrite(AtenAdaptiveAvgPool2dOp op,
8167+
template <typename AtenOpT>
8168+
class DecomposeAtenAdaptivePool2dOp : public OpRewritePattern<AtenOpT> {
8169+
using OpRewritePattern<AtenOpT>::OpRewritePattern;
8170+
LogicalResult matchAndRewrite(AtenOpT op,
82428171
PatternRewriter &rewriter) const override {
82438172

82448173
Location loc = op.getLoc();
@@ -8254,15 +8183,14 @@ class DecomposeAtenAdaptiveAvgPool2dOp
82548183
Value dimH = rewriter.create<Torch::ConstantIntOp>(
82558184
loc, rewriter.getI64IntegerAttr(rank - 2));
82568185
inputHW.push_back(
8257-
/*inH=*/rewriter.create<AtenSizeIntOp>(loc, input, dimH));
8186+
/*inH=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimH));
82588187
Value dimW = rewriter.create<Torch::ConstantIntOp>(
82598188
loc, rewriter.getI64IntegerAttr(rank - 1));
82608189
inputHW.push_back(
8261-
/*inW=*/rewriter.create<AtenSizeIntOp>(loc, input, dimW));
8190+
/*inW=*/rewriter.createOrFold<AtenSizeIntOp>(loc, input, dimW));
82628191

8263-
Value outputShape = op.getOutputSize();
82648192
SmallVector<Value> outputShapeSizesTorchInt;
8265-
getListConstructElements(outputShape, outputShapeSizesTorchInt);
8193+
getListConstructElements(op.getOutputSize(), outputShapeSizesTorchInt);
82668194

82678195
// TODO: Add support for cases other than:
82688196
// inH % outH != 0 or inW % outW != 0 where
@@ -8343,11 +8271,32 @@ class DecomposeAtenAdaptiveAvgPool2dOp
83438271
loc, Torch::ListType::get(Torch::IntType::get(context)),
83448272
ValueRange{constantZero, constantZero});
83458273

8346-
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
8347-
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
8348-
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
8349-
/*divisorOverride=*/constantNone);
8350-
return success();
8274+
if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveAvgPool2dOp>) {
8275+
rewriter.replaceOpWithNewOp<AtenAvgPool2dOp>(
8276+
op, op.getType(), input, kernelSizeList, strideList, paddingSizeList,
8277+
/*ceilMode=*/constantFalse, /*countIncludePad=*/constantTrue,
8278+
/*divisorOverride=*/constantNone);
8279+
return success();
8280+
} else if constexpr (std::is_same_v<AtenOpT, AtenAdaptiveMaxPool2dOp>) {
8281+
Value dilationList = rewriter.create<PrimListConstructOp>(
8282+
loc, Torch::ListType::get(Torch::IntType::get(context)),
8283+
ValueRange{constantOne, constantOne});
8284+
if (op.getResult(1).use_empty()) {
8285+
auto maxPool = rewriter.create<AtenMaxPool2dOp>(
8286+
loc, op.getType(0), input, kernelSizeList, strideList,
8287+
paddingSizeList, dilationList, /*ceil_mode=*/constantFalse);
8288+
rewriter.replaceOp(op, {maxPool.getResult(), Value()});
8289+
} else {
8290+
auto maxPool = rewriter.create<AtenMaxPool2dWithIndicesOp>(
8291+
loc, op.getType(0), op.getType(1), input, kernelSizeList,
8292+
strideList, paddingSizeList, dilationList,
8293+
/*ceil_mode=*/constantFalse);
8294+
rewriter.replaceOp(op, maxPool.getResults());
8295+
}
8296+
return success();
8297+
}
8298+
return rewriter.notifyMatchFailure(
8299+
op, "unimplemented: unsupported template op");
83518300
}
83528301
};
83538302
} // namespace
@@ -11778,9 +11727,14 @@ class DecomposeComplexOpsPass
1177811727
addPatternIfTargetOpIsIllegal<DecomposeAtenToDtypeLayoutOp>(patterns);
1177911728
addPatternIfTargetOpIsIllegal<DecomposeAtenToDeviceOp>(patterns);
1178011729
addPatternIfTargetOpIsIllegal<DecomposeAtenToPrimDeviceOp>(patterns);
11781-
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveMaxPool1dOp>(patterns);
11782-
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool1dOp>(patterns);
11783-
addPatternIfTargetOpIsIllegal<DecomposeAtenAdaptiveAvgPool2dOp>(patterns);
11730+
addPatternIfTargetOpIsIllegal<
11731+
DecomposeAtenAdaptivePool1dOp<AtenAdaptiveMaxPool1dOp>>(patterns);
11732+
addPatternIfTargetOpIsIllegal<
11733+
DecomposeAtenAdaptivePool1dOp<AtenAdaptiveAvgPool1dOp>>(patterns);
11734+
addPatternIfTargetOpIsIllegal<
11735+
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveMaxPool2dOp>>(patterns);
11736+
addPatternIfTargetOpIsIllegal<
11737+
DecomposeAtenAdaptivePool2dOp<AtenAdaptiveAvgPool2dOp>>(patterns);
1178411738
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinOp>(patterns);
1178511739
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMinTensorOp>(patterns);
1178611740
addPatternIfTargetOpIsIllegal<DecomposeAtenClampMaxOp>(patterns);

lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -509,6 +509,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
509509
target.addIllegalOp<AtenToPrimDeviceOp>();
510510
target.addIllegalOp<AtenAdaptiveAvgPool1dOp>();
511511
target.addIllegalOp<AtenAdaptiveAvgPool2dOp>();
512+
target.addIllegalOp<AtenAdaptiveMaxPool1dOp>();
513+
target.addIllegalOp<AtenAdaptiveMaxPool2dOp>();
512514
target.addIllegalOp<AtenClampMinOp>();
513515
target.addIllegalOp<AtenClampMinTensorOp>();
514516
target.addIllegalOp<AtenClampMaxOp>();

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,6 +2702,7 @@
27022702
"AdaptiveMaxPool2dDynamicNoBatch_basic",
27032703
"AdaptiveMaxPool2dDynamicWithIndices_basic",
27042704
"AdaptiveMaxPool2dDynamic_basic",
2705+
"AdaptiveMaxPool2dFixedKernelStrideSizeStaticModule_basic",
27052706
"AdaptiveMaxPool2dStaticWithIndices_basic",
27062707
"AdaptiveMaxPool2dStatic_basic",
27072708
"AdaptiveMaxPool3dDynamicNoBatch_basic",

projects/pt1/python/torch_mlir/torchscript.py

Lines changed: 1 addition & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -145,34 +145,6 @@ def _get_for_tracing(
145145
return result
146146

147147

148-
# The set of ops that are considered legal for each backend.
149-
# These are currently quite load-bearing, since different backends might be
150-
# missing patterns for decomposed forms of certain ops.
151-
# TODO: Tighten up the definition of these "conditionally legal for backends"
152-
# ops in the backend contract, and move these lists somewhere deeper in the
153-
# compiler where each backend can "own" its set of legal ops.
154-
BACKEND_LEGAL_OPS = {
155-
OutputType.TOSA: [
156-
"aten.flatten.using_ints",
157-
"aten.native_layer_norm",
158-
"aten.linear",
159-
],
160-
OutputType.LINALG_ON_TENSORS: [
161-
"aten.flatten.using_ints",
162-
"aten.adaptive_avg_pool1d",
163-
"aten.adaptive_avg_pool2d",
164-
"aten.unflatten.int",
165-
],
166-
OutputType.STABLEHLO: [
167-
"aten.amax",
168-
"aten.amin",
169-
"aten.randn.generator",
170-
"aten.normal_functional",
171-
"aten.fmod.Tensor",
172-
],
173-
}
174-
175-
176148
def _canon_extra_library(
177149
extra_library, extra_library_file_name="custom_op_extra_library.mlir"
178150
):
@@ -249,19 +221,10 @@ def compile(
249221
if ignore_traced_shapes and not use_tracing:
250222
raise Exception("`ignore_traced_shapes` requires `use_tracing`")
251223

252-
# We only allow `backend_legal_ops` to be specified for the `"torch"`
253-
# output type because the other output types actually invoke their
254-
# respective backends (Linalg, TOSA, or STABLEHLO), and those backends have
255-
# very specific requirements about the ops which are legal.
256-
# See `BACKEND_LEGAL_OPS` for more details.
257224
if backend_legal_ops is not None:
258-
if output_type != OutputType.TORCH:
259-
raise Exception(
260-
"`backend_legal_ops` is only valid with the " "`torch` output type"
261-
)
262225
backend_legal_ops = list(sorted(set(backend_legal_ops)))
263226
else:
264-
backend_legal_ops = BACKEND_LEGAL_OPS.get(output_type, [])
227+
backend_legal_ops = []
265228

266229
# For FX-based models, automatically strip overloads.
267230
if isinstance(model, torch.fx.GraphModule):

0 commit comments

Comments (0)