
Commit 79612ae

[TOSA] Add legalization for avg_pool2d
Before this patch, the `avg_pool2d` and `avg_pool1d` legalizations lacked support for pooling with count_include_pad=True. This patch introduces that support.

Signed-off-by: Vitalii Shutov <[email protected]>
Change-Id: I73fa26a58379e2c021929ade81c983ff91c59667
1 parent 2c989a2 commit 79612ae
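
Note (illustrative, not part of the commit): with count_include_pad=True, the averaging divisor is always the full kernel size, even where the window overlaps the zero padding. The standalone C++ sketch below, using a made-up avgPool1d helper with stride 1, shows why zero-padding the input explicitly and then pooling with pad = 0 reproduces that behavior.

#include <cstdio>
#include <vector>

// 1-D average pooling with stride 1 and symmetric zero padding `pad`.
// When countIncludePad is true the divisor is always the kernel size k;
// otherwise it is the number of in-bounds elements under the window.
static std::vector<float> avgPool1d(const std::vector<float> &x, int k, int pad,
                                    bool countIncludePad) {
  std::vector<float> out;
  const int n = static_cast<int>(x.size());
  for (int start = -pad; start + k <= n + pad; ++start) {
    float sum = 0.0f;
    int valid = 0;
    for (int i = start; i < start + k; ++i) {
      if (i >= 0 && i < n) {
        sum += x[i];
        ++valid;
      }
    }
    out.push_back(sum / (countIncludePad ? k : valid));
  }
  return out;
}

int main() {
  const std::vector<float> x = {1, 2, 3, 4};

  // Reference semantics: count_include_pad = true, kernel 3, pad 1.
  auto reference = avgPool1d(x, /*k=*/3, /*pad=*/1, /*countIncludePad=*/true);

  // Rewrite used by the patch: pad with explicit zeros, then pool with pad = 0;
  // the divisor is then the full kernel size for every window.
  const std::vector<float> padded = {0, 1, 2, 3, 4, 0};
  auto rewritten = avgPool1d(padded, /*k=*/3, /*pad=*/0, /*countIncludePad=*/true);

  for (size_t i = 0; i < reference.size(); ++i)
    std::printf("window %zu: reference=%.3f rewritten=%.3f\n", i, reference[i],
                rewritten[i]);
  return 0;
}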

File tree

4 files changed: +137 -42 lines changed


include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 6 additions & 0 deletions
@@ -106,6 +106,12 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
                                          Type inputElemTy, Type outputElemTy,
                                          ArrayRef<int64_t> weightShape);
 
+// Emit a TOSA explicit zero padding op for NCHW layout.
+std::pair<Value, RankedTensorType>
+emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
+                        Value input, ArrayRef<int64_t> paddingInts,
+                        Type elemTy);
+
 } // namespace tosa
 } // namespace mlir

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 15 additions & 9 deletions
@@ -6115,21 +6115,27 @@ static LogicalResult getOutputTypeAndPoolingParameters(
 
   if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
                 std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
-    // Currently, we can not represent `count_include_pad` with the existing
-    // TOSA AvgPool2d specification. Without the below check, we produce silent
-    // wrong answer (SWA) when the `count_include_pad` value is `true.`
-    //
-    // Note: We need to check for `count_include_pad` only when the `padding`
-    // value is non-zero.
+    // When count_include_pad=true with non-zero padding, insert an explicit
+    // zero-filled tosa.pad and then call avg_pool2d with pad=[0,0,0,0] so that
+    // the divisor equals the full kernel size.
     bool countIncludePad;
     if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
         (!matchPattern(op.getCountIncludePad(),
                        m_TorchConstantBool(&countIncludePad)) ||
 
         countIncludePad)) {
-      return rewriter.notifyMatchFailure(
-          op, "Unsupported `count_include_pad` value, for tosa AvgPool "
-              "`count_include_pad` value should be `False`.");
+
+      auto elemTy = inputTy.getElementType();
+      auto padResult = tosa::emitExplicitZeroPadNCHW(
+          op.getLoc(), rewriter, op, inputXchw,
+          /*{top,left}*/ {paddingInts[0], paddingInts[1]}, elemTy);
+      if (!padResult.first)
+        return failure();
+
+      inputXchw = padResult.first;
+      inputTy = padResult.second;
+
+      paddingInts.assign(/*Count=*/2, /*Value=*/0);
     }
   }
 
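
Note (illustrative, not part of the commit): pre-padding the input changes only where the padding is applied, not the output size. The hypothetical shape check below, using the 35x35 input with kernel 3, stride 1, and pad 1 from the tests in this commit, confirms that avg_pool2d over the explicitly padded input with pad = 0 keeps the original spatial dimensions.

#include <cassert>

// Standard pooled-size formula: floor((in + 2*pad - kernel) / stride) + 1.
static int pooledDim(int in, int kernel, int stride, int pad) {
  return (in + 2 * pad - kernel) / stride + 1;
}

int main() {
  const int h = 35, kernel = 3, stride = 1, pad = 1;
  // Original op: padding handled inside avg_pool2d.
  const int direct = pooledDim(h, kernel, stride, pad);                    // 35
  // Rewritten op: explicit zero pad first, then pool with pad = 0.
  const int rewritten = pooledDim(h + 2 * pad, kernel, stride, /*pad=*/0); // 35
  assert(direct == rewritten);
  (void)direct;
  (void)rewritten;
  return 0;
}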

lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp

Lines changed: 37 additions & 0 deletions
@@ -595,5 +595,42 @@ FailureOr<Value> getConvBiasForNoneType(Operation *op,
   }
 }
 
+// Emit a TOSA explicit zero padding op for NCHW layout.
+// Emit a `tosa.pad` around `input` (NCHW order) so that a later
+// tosa.avg_pool2d can run with pad = 0 and still reproduce
+// `count_include_pad==true` semantics. `paddingInts` comes in as {pad_top,
+// pad_left}.
+std::pair<Value, RankedTensorType>
+emitExplicitZeroPadNCHW(Location loc, PatternRewriter &rewriter, Operation *op,
+                        Value input, ArrayRef<int64_t> paddingInts,
+                        Type elemTy) {
+  const int64_t padTop = paddingInts[0];
+  const int64_t padLeft = paddingInts[1];
+
+  SmallVector<int64_t> padPairs = {0, 0, 0, 0,
+                                   padTop, padTop, padLeft, padLeft};
+  Value padShape = tosa::getTosaConstShape(rewriter, loc, padPairs);
+
+  Value padConst;
+  if (isa<FloatType>(elemTy)) {
+    padConst = *getConstTensor<float>(rewriter, op, {0.0f}, {1}, elemTy);
+  } else {
+    padConst = *getConstTensor<int32_t>(rewriter, op, {0}, {1}, elemTy);
+  }
+
+  // Create the actual Pad op
+  auto inTy = cast<RankedTensorType>(input.getType());
+  auto outTy = RankedTensorType::get({inTy.getDimSize(0),                // N
+                                      inTy.getDimSize(1),                // C
+                                      inTy.getDimSize(2) + 2 * padTop,   // H
+                                      inTy.getDimSize(3) + 2 * padLeft}, // W
+                                     elemTy);
+
+  Value padded =
+      rewriter.create<tosa::PadOp>(loc, outTy, input, padShape, padConst);
+
+  return {padded, outTy};
+}
+
 } // namespace tosa
 } // namespace mlir

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 79 additions & 33 deletions
@@ -2265,24 +2265,6 @@ func.func @torch.aten.round(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso
 
 // -----
 
-func.func @torch.aten.avg_pool2d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
-  %int0 = torch.constant.int 0
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false= torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %divisor_override = torch.constant.none
-
-  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
-  return %3 : !torch.vtensor<[1,192,35,35],f32>
-}
-
-// -----
-
 func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
   %int0 = torch.constant.int 0
   %int1 = torch.constant.int 1
@@ -2802,21 +2784,6 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
 
 // -----
 
-func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
-  %int1 = torch.constant.int 1
-  %int3 = torch.constant.int 3
-  %false = torch.constant.bool false
-  %count_include_pad = torch.constant.bool true
-  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
-  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
-  // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
-  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
-  return %3 : !torch.vtensor<[1,512,10],f32>
-}
-
-// -----
-
 // CHECK-LABEL: func.func @torch.aten.reflection_pad1d$basic(
 // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,2,4],f32>) -> !torch.vtensor<[1,2,8],f32> {
 // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,2,4],f32> -> tensor<1x2x4xf32>
@@ -4328,3 +4295,82 @@ func.func @torch.aten.linear$f16(%arg0: !torch.vtensor<[2,4],f16>, %arg1: !torch
   %0 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[2,4],f16>, !torch.vtensor<[3,4],f16>, !torch.vtensor<[3],f16> -> !torch.vtensor<[2,3],f16>
   return %0 : !torch.vtensor<[2,3],f16>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,192,35,35],f32> -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 0
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_4:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_6:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_7:.*]] = torch.constant.none
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_4]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_10:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_3]] : (!torch.int, !torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 1, 1]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_1]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x192x35x35xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x192x37x37xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x192x37x37xf32>) -> tensor<1x37x37x192xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 3>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x37x37x192xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x35x35x192xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x35x35x192xf32>) -> tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_19:.*]] = tensor.cast %[[VAL_18]] : tensor<1x192x35x35xf32> to tensor<1x192x35x35xf32>
+// CHECK: %[[VAL_20:.*]] = torch_c.from_builtin_tensor %[[VAL_19]] : tensor<1x192x35x35xf32> -> !torch.vtensor<[1,192,35,35],f32>
+// CHECK: return %[[VAL_20]] : !torch.vtensor<[1,192,35,35],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool2d.count_include_pad(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false= torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %divisor_override = torch.constant.none
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.avg_pool1d.count_include_pad(
+// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,512,10],f32> -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_2:.*]] = torch.constant.int 1
+// CHECK: %[[VAL_3:.*]] = torch.constant.int 3
+// CHECK: %[[VAL_4:.*]] = torch.constant.bool false
+// CHECK: %[[VAL_5:.*]] = torch.constant.bool true
+// CHECK: %[[VAL_6:.*]] = torch.prim.ListConstruct %[[VAL_3]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_7:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[VAL_2]] : (!torch.int) -> !torch.list<int>
+// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[1, 512, 10, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_9]] : (tensor<1x512x10xf32>, !tosa.shape<4>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_13:.*]] = tosa.pad %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] : (tensor<1x512x10x1xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x512x12x1xf32>
+// CHECK: %[[VAL_14:.*]] = tosa.transpose %[[VAL_13]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<1x512x12x1xf32>) -> tensor<1x12x1x512xf32>
+// CHECK: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK: %[[VAL_17:.*]] = tosa.avg_pool2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {acc_type = f32, kernel = array<i64: 3, 1>, pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>} : (tensor<1x12x1x512xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x10x1x512xf32>
+// CHECK: %[[VAL_18:.*]] = tosa.transpose %[[VAL_17]] {perms = array<i32: 0, 3, 1, 2>} : (tensor<1x10x1x512xf32>) -> tensor<1x512x10x1xf32>
+// CHECK: %[[VAL_19:.*]] = tosa.const_shape {values = dense<[1, 512, 10]> : tensor<3xindex>} : () -> !tosa.shape<3>
+// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_19]] : (tensor<1x512x10x1xf32>, !tosa.shape<3>) -> tensor<1x512x10xf32>
+// CHECK: %[[VAL_21:.*]] = tensor.cast %[[VAL_20]] : tensor<1x512x10xf32> to tensor<1x512x10xf32>
+// CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21]] : tensor<1x512x10xf32> -> !torch.vtensor<[1,512,10],f32>
+// CHECK: return %[[VAL_22]] : !torch.vtensor<[1,512,10],f32>
+// CHECK: }
+func.func @torch.aten.avg_pool1d.count_include_pad(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool true
+  %0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
+  return %3 : !torch.vtensor<[1,512,10],f32>
+}
