62 changes: 62 additions & 0 deletions backends/cadence/aot/replace_ops.py
@@ -16,6 +16,7 @@

# pyre-unsafe

import logging
import math
import operator
from operator import neg
@@ -2346,6 +2347,66 @@ def resolve_full_arg(self, x_arg, const_arg):
return const_arg


@register_cadence_pass(CadencePassAttribute(opt_level=0))
class ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(ExportPass):
"""
Replace the aten adaptive avg_pool op with the aten avg_pool2d op.
"""

def call_operator(self, op, args, kwargs, meta):
# Only continue for the adaptive avg_pool op
if op not in {exir_ops.edge.aten._adaptive_avg_pool2d.default}:
return super().call_operator(op, args, kwargs, meta)

# Get the input tensor
in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
# Permute NCHW to NHWC for computation
in_tensor_permuted = in_tensor.permute(0, 2, 3, 1)
in_tensor_shape = in_tensor_permuted.shape

output_size = args[1]
num_dims = len(output_size)

# TODO: If in_tensor_shape is not a multiple of output size,
# this pass will not work. T224984800
dim_multiples = [
(in_tensor_shape[i + 1] % output_size[i]) == 0 for i in range(num_dims)
]
if not all(dim_multiples):
logging.info(
f"Unable to replace adaptive average pool with average pool. Input tensor shape of {in_tensor_shape} is not a multiple of output size: {output_size}"
)
return super().call_operator(op, args, kwargs, meta)

# Compute stride and kernel_size, then set default values for other arguments
stride = [(in_tensor_shape[i + 1] // output_size[i]) for i in range(num_dims)]
kernel_size = [
in_tensor_shape[i + 1] - (output_size[i] - 1) * stride[i]
for i in range(num_dims)
]
padding = [0] * num_dims
ceil_mode = False
count_include_pad = True
divisor_override = None

# Create a new avg_pool node with the updated args
new_args = (
args[0],
kernel_size,
stride,
padding,
ceil_mode,
count_include_pad,
divisor_override,
)
return super().call_operator(
exir_ops.edge.aten.avg_pool2d.default,
new_args,
kwargs,
meta,
)


# This class encapsulates all the functions that replace/switch one op in the
# graph with another.
class CadenceReplaceOpsInGraph:
Expand Down Expand Up @@ -2382,6 +2443,7 @@ class CadenceReplaceOpsInGraph:
ReplacePT2QuantWithCadenceQuantPass,
ReplacePT2DequantWithCadenceDequantPass,
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
ReplaceAtenAvgPoolWithJarvisAvgPoolPass,
ReplaceWhereWithFullArgsWithWhereScalar,
ReplaceAtenApproxGeluWithApproxGeluPass,
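For context, the derivation in the pass reduces to stride = in // out and kernel = in - (out - 1) * stride, which both equal in // out when the input divides evenly. A minimal sketch (standard torch.nn.functional only; the shapes mirror the test below and are otherwise illustrative) confirming that the two ops agree in the divisible case:

import torch
import torch.nn.functional as F

x = torch.randn(1, 64, 128, 128)

# stride = in // out; kernel = in - (out - 1) * stride -> both 16 for 128 -> 8
stride = [128 // 8, 128 // 8]
kernel = [128 - (8 - 1) * s for s in stride]

expected = F.adaptive_avg_pool2d(x, (8, 8))
actual = F.avg_pool2d(x, kernel_size=kernel, stride=stride)
assert torch.allclose(expected, actual)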
100 changes: 100 additions & 0 deletions backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -19,6 +19,7 @@
from executorch.backends.cadence.aot.replace_ops import (
ForceChannelLastForConvPass,
MakeSliceAndCatDimOutermostPass,
ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass,
ReplaceAddMMWithLinearPass,
ReplaceAtenApproxGeluWithApproxGeluPass,
ReplaceAtenConvolutionWithJarvisConvolutionPass,
@@ -1936,3 +1937,102 @@ def test_extract_mul_argument_to_full(
},
)
)


class TestReplaceAdaptiveAvgPoolWithAtenAvgPoolPass(unittest.TestCase):
def _get_adaptive_avg_pool_gm(
self, input_shape: Tuple[int, int, int, int], output_shape: Tuple[int, int]
) -> torch.fx.GraphModule:
builder = GraphBuilder()
x = builder.placeholder("x", torch.randn(*input_shape))
adaptive_avg_pool2d = builder.call_operator(
exir_ops.edge.aten._adaptive_avg_pool2d.default, (x, output_shape)
)
builder.output([adaptive_avg_pool2d])
return builder.get_graph_module()

def test_replace_adaptive_avg_pool_with_aten_avg_pool(self) -> None:
gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (8, 8))
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function",
target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
)
),
1,
)
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function",
target=exir_ops.edge.aten.avg_pool2d.default,
)
),
0,
)
p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
updated_gm = p.call(gm).graph_module
self.assertEqual(
len(
updated_gm.graph.find_nodes(
op="call_function",
target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
)
),
0,
)
avg_pool2d_nodes = updated_gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
)
self.assertEqual(
len(avg_pool2d_nodes),
1,
)
avg_pool2d_node = avg_pool2d_nodes[0]

self.assertEqual(avg_pool2d_node.args[1], [16, 16])  # kernel_size is [16, 16]
self.assertEqual(avg_pool2d_node.args[2], [16, 16])  # stride is [16, 16]
self.assertEqual(avg_pool2d_node.args[3], [0, 0]) # padding is 0, 0
self.assertEqual(avg_pool2d_node.args[4], False) # ceil_mode is False
self.assertEqual(avg_pool2d_node.args[5], True) # count_include_pad is True
self.assertEqual(avg_pool2d_node.args[6], None) # divisor_override is None

def test_replace_adaptive_avg_pool_with_aten_avg_pool_irregular(self) -> None:
gm = self._get_adaptive_avg_pool_gm((1, 64, 128, 128), (9, 9))
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function",
target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
)
),
1,
)
self.assertEqual(
len(
gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
)
),
0,
)
# Input spatial dims (128) are not multiples of the output size (9),
# so the pass will not trigger
p = ReplaceAdaptiveAvgPoolWithAtenAvgPoolPass()
updated_gm = p.call(gm).graph_module
self.assertEqual(
len(
updated_gm.graph.find_nodes(
op="call_function",
target=exir_ops.edge.aten._adaptive_avg_pool2d.default,
)
),
1,
)
avg_pool2d_nodes = updated_gm.graph.find_nodes(
op="call_function", target=exir_ops.edge.aten.avg_pool2d.default
)
self.assertEqual(
len(avg_pool2d_nodes),
0,
)
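As background on the divisibility restriction the pass enforces: adaptive_avg_pool2d computes the i-th window along a dimension as [floor(i * in / out), ceil((i + 1) * in / out)); only when in % out == 0 do all windows share one size (in // out) and a fixed stride, which is what avg_pool2d can express. A small sketch (the helper name is hypothetical, not part of the PR) enumerating the windows for the two test shapes:

import math

def adaptive_windows(in_size: int, out_size: int):
    # Window boundaries used by adaptive average pooling along one dimension.
    return [
        (math.floor(i * in_size / out_size), math.ceil((i + 1) * in_size / out_size))
        for i in range(out_size)
    ]

# Divisible case (128 -> 8): uniform 16-wide windows with stride 16.
print(adaptive_windows(128, 8))  # [(0, 16), (16, 32), ..., (112, 128)]
# Irregular case (128 -> 9): window widths vary (15 or 16) and overlap,
# so no single kernel_size/stride pair reproduces them.
print(adaptive_windows(128, 9))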