62 changes: 20 additions & 42 deletions backends/cadence/aot/replace_ops.py
@@ -699,43 +699,27 @@ def call_operator(self, op, args, kwargs, meta):
         # graph operation (in this case a transpose_copy op) to be an explicit
         # ProxyValue as well. If not, the view op can be done directly on the
         # tensor.
-        transposed_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int,
-                (
-                    weight,
-                    0,
-                    1,
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.transpose(0, 1)
+        transposed_weight = super().call_operator(
+            exir_ops.edge.aten.transpose_copy.int,
+            (
+                weight,
+                0,
+                1,
+            ),
+            kwargs,
+            meta,
         )

-        flipped_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.flip.default,
-                (
-                    transposed_weight,
-                    [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(transposed_weight, ProxyValue)
-            else (
-                transposed_weight.flip(-1)
-                if transposed_weight.dim() == 3
-                else transposed_weight.flip(-1, -2)
-            )
+        flipped_weight = super().call_operator(
+            exir_ops.edge.aten.flip.default,
+            (
+                transposed_weight,
+                [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
+            ),
+            kwargs,
+            meta,
         )

-        # From the previous checks, if flipped_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(flipped_weight, FakeTensor):
-            flipped_weight.constant = flipped_weight
         new_args = (
             in_tensor,
             flipped_weight,
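Both weight transformations now go through `call_operator` unconditionally, so the result is always a `ProxyValue`, which is presumably why the `FakeTensor` special case is dropped. The transpose-plus-flip itself is the standard identity that rewrites a transposed convolution as a regular convolution. A minimal sketch of that identity in plain torch for the single-group, stride-1, unpadded 1-D case (illustrative only, not the pass's IR ops):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 18)  # (batch, in_channels, length)
w = torch.randn(8, 16, 3)  # conv_transpose layout: (in_channels, out_channels, kernel)

expected = F.conv_transpose1d(x, w)

# Swap the channel dims, flip the kernel, and run a regular convolution
# with "full" padding of kernel_size - 1.
w_conv = w.transpose(0, 1).flip(-1)  # conv layout: (out_channels, in_channels, kernel)
actual = F.conv1d(x, w_conv, padding=w.shape[-1] - 1)

torch.testing.assert_close(actual, expected)
```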
@@ -751,16 +735,10 @@ def call_operator(self, op, args, kwargs, meta):
         # Verify that output_padding is 0.
         assert all(
             x == 0 for x in output_padding
-        ), "Cannot handle padded output in convolution"
+        ), f"Cannot handle padded output in convolution. Got {output_padding=}"

-        # If the innermost dim of output tensor is 1, then the stride
-        # should be 1. Note that the first dimension of output tensor is
-        # channel
-        new_stride = stride.copy()
-        out_shape = meta["val"].shape
-        assert out_shape is not None
-        for i, e in enumerate(out_shape[2:]):
-            new_stride[i] = 1 if e == 1 else stride[i]
+        # Keep the original stride to maintain correct output dimensions
+        new_stride = stride

         new_args = (
             in_tensor,
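On the stride change: the deleted block forced the stride to 1 whenever the corresponding output spatial dim was 1, but that can silently change the output length computed for the replacement op. The new "Stride of 2 needed" test case below hits exactly this. A quick check with the standard conv1d output-length formula (the helper name is ours; the formula is from the PyTorch `Conv1d` docs):

```python
def conv1d_out_len(
    l_in: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    # L_out = floor((L_in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1
    return (l_in + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1

# Test case [(1, 8, 3), 8, 8, 48, 2, 23]: input length 3, kernel 48, stride 2, padding 23.
assert conv1d_out_len(3, kernel=48, stride=2, padding=23) == 1
# Forcing the stride back to 1 (the removed behavior) changes the output length:
assert conv1d_out_len(3, kernel=48, stride=1, padding=23) == 2
```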
191 changes: 190 additions & 1 deletion backends/cadence/aot/tests/test_replace_ops_passes.py
@@ -52,9 +52,10 @@

 from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.passes import dead_code_elimination_pass
 from torch.fx.passes.infra.pass_base import PassResult
+from torch.utils import _pytree as pytree


 class TestReplaceOpsPasses(unittest.TestCase):
Expand Down Expand Up @@ -345,6 +346,194 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(
count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x
)

    def assertTensorMetadataIsSame(
        self, a: Sequence[torch.Tensor], b: Sequence[torch.Tensor]
    ) -> None:
        for i, (_a, _b) in enumerate(zip(a, b)):
            # TODO: actually compare the tensors.
            self.assertTrue(
                _a.shape == _b.shape, f"Tensor {i}: {_a.shape} != {_b.shape}"
            )
            self.assertTrue(
                _a.dtype == _b.dtype, f"Tensor {i}: {_a.dtype} != {_b.dtype}"
            )

    @expand(
        [
            [(1, 8, 18), 8, 16, 3],
            [(1, 8, 18), 8, 16, 5, 2],
            # depthwise + bias
            [(1, 8, 18), 8, 16, 5, 2, 0, 1, True],
            # no bias
            [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],
            # bias (this test always builds transposed=False)
            [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],
            # Stride of 2 needed.
            [(1, 8, 3), 8, 8, 48, 2, 23],
        ]
    )
    @torch.no_grad()
    def test_replace_aten_conv_with_cadence_conv(
        self,
        shape: Tuple[int, ...],
        in_channels: int,
        out_channels: int,
        kernel: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        depthwise: bool = False,
        bias_enabled: bool = True,
        output_padding: Optional[int] = None,
    ) -> None:
        groups = in_channels if depthwise else 1
        builder = GraphBuilder()
        x_tensor = torch.randn(*shape, dtype=torch.float32)
        x = builder.placeholder("x", x_tensor)
        weights_tensor = torch.randn(
            [out_channels, in_channels // groups, kernel], dtype=torch.float32
        )
        weights = builder.placeholder("weights", weights_tensor)
        bias: Optional[ProxyValue] = None
        bias_tensor: Optional[torch.Tensor] = None
        if bias_enabled:
            bias_tensor = torch.randn([out_channels], dtype=torch.float32)
            bias = builder.placeholder("bias", bias_tensor)
        convolution = builder.call_operator(
            op=exir_ops.edge.aten.convolution.default,
            args=(
                x,
                weights,
                bias,
                [stride],
                [padding],
                [dilation],
                False,
                [output_padding] if output_padding else [0],
                groups,
            ),
        )
        builder.output([convolution])
        original_gm = builder.get_graph_module()

        replacement_pass_result = (
            ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)
        )
        self.assertIsNotNone(replacement_pass_result)
        graph_after_passes = replacement_pass_result.graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
            0,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
            1,
        )
        self.assertEqual(
            count_node(
                graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default
            ),
            0,
        )

        inputs = (x.to_tensor(), weights.to_tensor())
        if bias is not None:
            inputs += (bias.to_tensor(),)
        self.assertTensorMetadataIsSame(
            pytree.tree_flatten(original_gm.forward(*inputs))[0],
            pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],
        )
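The transposed-convolution test below builds its weights with the other channel layout: aten convolution weights are `(out_channels, in_channels // groups, kernel)`, while transposed-convolution weights are `(in_channels, out_channels // groups, kernel)`. A quick check of that convention with stock torch modules (illustrative, not part of the PR):

```python
import torch

conv = torch.nn.Conv1d(8, 16, kernel_size=3, groups=8)
tconv = torch.nn.ConvTranspose1d(8, 16, kernel_size=3, groups=8)

assert conv.weight.shape == torch.Size([16, 1, 3])  # (out, in // groups, kernel)
assert tconv.weight.shape == torch.Size([8, 2, 3])  # (in, out // groups, kernel)
```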

    @expand(
        [
            [(1, 8, 18), 8, 16, 3],
            [(1, 8, 18), 8, 16, 5, 2],
            # depthwise + bias
            [(1, 8, 18), 8, 16, 5, 2, 0, 1, True, True],
            # no bias
            [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],
            # depthwise + no bias
            [(1, 8, 18), 8, 16, 3, 1, 0, 1, True, False],
            # bias
            [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],
        ]
    )
    @torch.no_grad()
    def test_replace_aten_transposed_conv_with_cadence_transposed_conv(
        self,
        shape: Tuple[int, ...],
        in_channels: int,
        out_channels: int,
        kernel: int,
        stride: int = 1,
        padding: int = 0,
        dilation: int = 1,
        depthwise: bool = False,
        bias_enabled: bool = True,
        output_padding: Optional[int] = None,
    ) -> None:
        groups = in_channels if depthwise else 1
        builder = GraphBuilder()
        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
        weights_shape = [in_channels, out_channels // groups, kernel]
        weights = builder.placeholder(
            "weights",
            torch.randn(weights_shape, dtype=torch.float32),
        )
        bias = (
            builder.placeholder(
                "bias", torch.randn([out_channels], dtype=torch.float32)
            )
            if bias_enabled
            else None
        )
        convolution = builder.call_operator(
            op=exir_ops.edge.aten.convolution.default,
            args=(
                x,
                weights,
                bias,
                [stride],
                [padding],
                [dilation],
                True,
                [output_padding] if output_padding else [0],
                groups,
            ),
        )
        builder.output([convolution])
        original_gm = builder.get_graph_module()

        replacement_pass_result = (
            ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)
        )
        self.assertIsNotNone(replacement_pass_result)
        graph_after_passes = replacement_pass_result.graph_module

        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
            0,
        )
        self.assertEqual(
            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
            0,
        )
        self.assertEqual(
            count_node(
                graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default
            ),
            1,
        )

        inputs = (x.to_tensor(), weights.to_tensor())
        if bias is not None:
            inputs += (bias.to_tensor(),)
        self.assertTensorMetadataIsSame(
            pytree.tree_flatten(original_gm.forward(*inputs))[0],
            pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],
        )
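Both tests flatten the `forward()` outputs with pytree before handing them to `assertTensorMetadataIsSame`. A minimal illustration of the `torch.utils._pytree` helpers they rely on (a private but long-standing torch utility):

```python
import torch
from torch.utils import _pytree as pytree

# tree_flatten reduces an arbitrarily nested container of tensors to a flat
# list of leaves plus a spec that records the original structure.
leaves, spec = pytree.tree_flatten([torch.zeros(2, 3), (torch.ones(4),)])
assert [t.shape for t in leaves] == [torch.Size([2, 3]), torch.Size([4])]

# tree_unflatten inverts the flattening using the recorded spec.
rebuilt = pytree.tree_unflatten(leaves, spec)
```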

    @expand(
        [
            [(1, 8, 33), 8, 16, 3],