Skip to content

Commit 30c5193

Browse files
Hanumanth04 (Hanumanth Hanumantharayappa) and a co-author authored
Support default padding case for tosa::AvgPool in the presence of count_include_pad (#3868)
Essentially, as part of my earlier [change](7f9f99c), I didn't consider the `padding` value while erroring out for unsupported `count_include_pad` during `torch-to-tosa` lowering for AvgPool2d. The fix captured in this change addresses this. Please see [issue](#3862) for more details. Co-authored-by: Hanumanth Hanumantharayappa <[email protected]>
1 parent cd38ecf commit 30c5193

File tree

3 files changed

+43
-40
lines changed

3 files changed

+43
-40
lines changed

lib/Conversion/TorchToTosa/TorchToTosa.cpp

Lines changed: 21 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5549,6 +5549,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
55495549
std::is_same<AtenOpT, AtenAvgPool1dOp>())
55505550
paddingInts.push_back(0);
55515551

5552+
if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
5553+
std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
5554+
// Currently, we can not represent `count_include_pad` with the existing
5555+
// TOSA AvgPool2d specification. Without the below check, we produce silent
5556+
// wrong answer (SWA) when the `count_include_pad` value is `true.`
5557+
//
5558+
// Note: We need to check for `count_include_pad` only when the `padding`
5559+
// value is non-zero.
5560+
bool countIncludePad;
5561+
if ((paddingInts[0] != 0 || paddingInts[1] != 0) &&
5562+
(!matchPattern(op.getCountIncludePad(),
5563+
m_TorchConstantBool(&countIncludePad)) ||
5564+
5565+
countIncludePad)) {
5566+
return rewriter.notifyMatchFailure(
5567+
op, "Unsupported `count_include_pad` value, for tosa AvgPool "
5568+
"`count_include_pad` value should be `False`.");
5569+
}
5570+
}
5571+
55525572
SmallVector<int64_t, 4> padArr = {paddingInts[0], paddingInts[0],
55535573
paddingInts[1], paddingInts[1]};
55545574
kernel = rewriter.getDenseI64ArrayAttr(kernelSizeInts);
@@ -5677,18 +5697,6 @@ class ConvertAtenAvgPool2dOp
56775697
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
56785698
Type &outputTy) const override {
56795699

5680-
// Currently, we can not represent `count_include_pad` with the existing
5681-
// TOSA AvgPool2d specification. Without the below check, we produce silent
5682-
// wrong answers (SWA) when the `count_include_pad` value is `true.`
5683-
bool countIncludePad;
5684-
if (!matchPattern(op.getCountIncludePad(),
5685-
m_TorchConstantBool(&countIncludePad)) ||
5686-
countIncludePad) {
5687-
return rewriter.notifyMatchFailure(
5688-
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
5689-
"`count_include_pad` value should be `False`.");
5690-
}
5691-
56925700
// Currently, we can not represent `divisor_override` with the existing TOSA
56935701
// AvgPool2d specification. Without the below check, we produce silent wrong
56945702
// answers (SWA) when the `divisor_override` value is other than `None.`
@@ -5737,7 +5745,7 @@ class ConvertAtenAvgPool1dOp
57375745
// Expected a rank 3 input tensor
57385746
if (selfTy.getRank() != 3)
57395747
return rewriter.notifyMatchFailure(
5740-
op, "Input tensor for MaxPool1d should have rank 3");
5748+
op, "Input tensor for AvgPool1d should have rank 3");
57415749

57425750
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
57435751
SmallVector<int64_t> rank4Shape(selfShape);
@@ -5748,18 +5756,6 @@ class ConvertAtenAvgPool1dOp
57485756
selfTy.getElementType()),
57495757
self, rewriter.getDenseI64ArrayAttr(rank4Shape));
57505758

5751-
// Currently, we can not represent `count_include_pad` with the existing
5752-
// TOSA AvgPool2d specification. Without the below check, we produce silent
5753-
// wrong answers (SWA) when the `count_include_pad` value is `true.`
5754-
bool countIncludePad;
5755-
if (!matchPattern(op.getCountIncludePad(),
5756-
m_TorchConstantBool(&countIncludePad)) ||
5757-
countIncludePad) {
5758-
return rewriter.notifyMatchFailure(
5759-
op, "Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
5760-
"`count_include_pad` value should be `False`.");
5761-
}
5762-
57635759
SmallVector<int64_t, 2> dilationArray{1, 1};
57645760
if (failed(getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
57655761
tosa::AvgPool2dOp>(

projects/pt1/e2e_testing/xfail_sets.py

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1736,6 +1736,12 @@
17361736
# Write the TOSA set as a "passing" set as it is very early in development
17371737
# and very few tests work yet.
17381738
TOSA_PASS_SET = {
1739+
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
1740+
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
1741+
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
1742+
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
1743+
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
1744+
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
17391745
"ElementwiseAtenLogicalNotOpPromoteModule_basic",
17401746
"ElementwiseCosIntModule_basic",
17411747
"ElementwiseReciprocalIntModule_basic",
@@ -2316,6 +2322,7 @@
23162322
"ReshapeExpandModule_basic",
23172323
"ReturnThreeTensorFloat32_basic",
23182324
"ReturnTwoTensorF32I64_basic",
2325+
"ResNet18StaticModule_basic",
23192326
"RsubFloatModule_basic",
23202327
"RsubFloatModule_noalpha_basic",
23212328
"RsubInt0d_NumToTensor_Module_basic",
@@ -3869,26 +3876,11 @@
38693876
"ViewSizeFromOtherTensor_basic",
38703877
"VisionTransformerModule_basic",
38713878
"ZerosLikeModule_falsePinMemory",
3872-
# count_include_pad and divisor_override check in TOSA AvgPool2d
3873-
"AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic",
3874-
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
3875-
"AdaptiveAvgPool2dOutputSizeDivisibleByInputDynamicModule_basic",
3876-
"AdaptiveAvgPool2dOutputSizeDivisibleByInputStaticModule_basic",
3877-
"AdaptiveAvgPool2dFixedKernelStrideSizeStaticModule_basic",
3878-
"AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic",
3879-
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
3880-
"ResNet18Module_basic",
3881-
"ResNet18StaticModule_basic",
3882-
"MobilenetV3Module_basic",
38833879
# Unexpected failures due to new PyTorch version update
38843880
"AdaptiveAvgPool1dGeneralDynamicNoBatches_basic",
38853881
"AdaptiveAvgPool1dGeneralDynamic_basic",
3886-
"AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic",
3887-
"AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic",
38883882
"AdaptiveAvgPool1dStaticEvenMultiple_basic",
38893883
"AdaptiveAvgPool1dStaticLargerOutput_basic",
3890-
"AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic",
3891-
"AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic",
38923884
"AdaptiveAvgPool2dDynamicNoBatch_basic",
38933885
"AdaptiveAvgPool2dDynamic_basic",
38943886
"CrossEntropyLossModule_basic",

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2424,3 +2424,18 @@ func.func @torch.prims.collapse$basic(%arg0: !torch.vtensor<[2,3,4],f32>) -> !to
24242424
%0 = torch.prims.collapse %arg0, %int1, %int2 : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,12],f32>
24252425
return %0 : !torch.vtensor<[2,12],f32>
24262426
}
2427+
2428+
// -----
2429+
2430+
func.func @torch.aten.avg_pool1d.count_include_pad_unsupported_value(%arg0: !torch.vtensor<[1,512,10],f32>) -> !torch.vtensor<[1,512,10],f32> {
2431+
%int1 = torch.constant.int 1
2432+
%int3 = torch.constant.int 3
2433+
%false = torch.constant.bool false
2434+
%count_include_pad = torch.constant.bool true
2435+
%0 = torch.prim.ListConstruct %int3 : (!torch.int) -> !torch.list<int>
2436+
%1 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
2437+
%2 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list<int>
2438+
// expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool1d' that was explicitly marked illegal}}
2439+
%3 = torch.aten.avg_pool1d %arg0, %0, %1, %2, %false, %count_include_pad : !torch.vtensor<[1,512,10],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[1,512,10],f32>
2440+
return %3 : !torch.vtensor<[1,512,10],f32>
2441+
}

0 commit comments

Comments (0)