@@ -5549,6 +5549,26 @@ static LogicalResult getOutputTypeAndPoolingParameters(
5549
5549
std::is_same<AtenOpT, AtenAvgPool1dOp>())
5550
5550
paddingInts.push_back (0 );
5551
5551
5552
+ if constexpr (std::is_same<AtenOpT, AtenAvgPool1dOp>() ||
5553
+ std::is_same<AtenOpT, AtenAvgPool2dOp>()) {
5554
+ // Currently, we can not represent `count_include_pad` with the existing
5555
+ // TOSA AvgPool2d specification. Without the below check, we produce silent
5556
+ // wrong answer (SWA) when the `count_include_pad` value is `true.`
5557
+ //
5558
+ // Note: We need to check for `count_include_pad` only when the `padding`
5559
+ // value is non-zero.
5560
+ bool countIncludePad;
5561
+ if ((paddingInts[0 ] != 0 || paddingInts[1 ] != 0 ) &&
5562
+ (!matchPattern (op.getCountIncludePad (),
5563
+ m_TorchConstantBool (&countIncludePad)) ||
5564
+
5565
+ countIncludePad)) {
5566
+ return rewriter.notifyMatchFailure (
5567
+ op, " Unsupported `count_include_pad` value, for tosa AvgPool "
5568
+ " `count_include_pad` value should be `False`." );
5569
+ }
5570
+ }
5571
+
5552
5572
SmallVector<int64_t , 4 > padArr = {paddingInts[0 ], paddingInts[0 ],
5553
5573
paddingInts[1 ], paddingInts[1 ]};
5554
5574
kernel = rewriter.getDenseI64ArrayAttr (kernelSizeInts);
@@ -5677,18 +5697,6 @@ class ConvertAtenAvgPool2dOp
5677
5697
DenseI64ArrayAttr &stride, DenseI64ArrayAttr &pad,
5678
5698
Type &outputTy) const override {
5679
5699
5680
- // Currently, we can not represent `count_include_pad` with the existing
5681
- // TOSA AvgPool2d specification. Without the below check, we produce silent
5682
- // wrong answers (SWA) when the `count_include_pad` value is `true.`
5683
- bool countIncludePad;
5684
- if (!matchPattern (op.getCountIncludePad (),
5685
- m_TorchConstantBool (&countIncludePad)) ||
5686
- countIncludePad) {
5687
- return rewriter.notifyMatchFailure (
5688
- op, " Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
5689
- " `count_include_pad` value should be `False`." );
5690
- }
5691
-
5692
5700
// Currently, we can not represent `divisor_override` with the existing TOSA
5693
5701
// AvgPool2d specification. Without the below check, we produce silent wrong
5694
5702
// answers (SWA) when the `divisor_override` value is other than `None.`
@@ -5737,7 +5745,7 @@ class ConvertAtenAvgPool1dOp
5737
5745
// Expected a rank 3 input tensor
5738
5746
if (selfTy.getRank () != 3 )
5739
5747
return rewriter.notifyMatchFailure (
5740
- op, " Input tensor for MaxPool1d should have rank 3" );
5748
+ op, " Input tensor for AvgPool1d should have rank 3" );
5741
5749
5742
5750
// Unsqueeze input tensor to rank 4 to be compatible with tosa::AvgPool2dOp
5743
5751
SmallVector<int64_t > rank4Shape (selfShape);
@@ -5748,18 +5756,6 @@ class ConvertAtenAvgPool1dOp
5748
5756
selfTy.getElementType ()),
5749
5757
self, rewriter.getDenseI64ArrayAttr (rank4Shape));
5750
5758
5751
- // Currently, we can not represent `count_include_pad` with the existing
5752
- // TOSA AvgPool2d specification. Without the below check, we produce silent
5753
- // wrong answers (SWA) when the `count_include_pad` value is `true.`
5754
- bool countIncludePad;
5755
- if (!matchPattern (op.getCountIncludePad (),
5756
- m_TorchConstantBool (&countIncludePad)) ||
5757
- countIncludePad) {
5758
- return rewriter.notifyMatchFailure (
5759
- op, " Unsupported `count_include_pad` value, for tosa AvgPool2dOp "
5760
- " `count_include_pad` value should be `False`." );
5761
- }
5762
-
5763
5759
SmallVector<int64_t , 2 > dilationArray{1 , 1 };
5764
5760
if (failed (getOutputTypeAndPoolingParameters<AtenAvgPool1dOp,
5765
5761
tosa::AvgPool2dOp>(
0 commit comments