
Commit 943e34a

Update conv replacement pass + tests.

Differential Revision: D82577503
Pull Request resolved: pytorch#14363

1 parent 0b17bd2

2 files changed: +210 -43 lines

backends/cadence/aot/replace_ops.py

Lines changed: 20 additions & 42 deletions
@@ -699,43 +699,27 @@ def call_operator(self, op, args, kwargs, meta):
         # graph operation (in this case a transpose_copy op) to be an explicit
         # ProxyValue as well. If not, the view op can be done directly on the
         # tensor.
-        transposed_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.transpose_copy.int,
-                (
-                    weight,
-                    0,
-                    1,
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(weight, ProxyValue)
-            else weight.transpose(0, 1)
+        transposed_weight = super().call_operator(
+            exir_ops.edge.aten.transpose_copy.int,
+            (
+                weight,
+                0,
+                1,
+            ),
+            kwargs,
+            meta,
         )
 
-        flipped_weight = (
-            super().call_operator(
-                exir_ops.edge.aten.flip.default,
-                (
-                    transposed_weight,
-                    [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
-                ),
-                kwargs,
-                meta,
-            )
-            if isinstance(transposed_weight, ProxyValue)
-            else (
-                transposed_weight.flip(-1)
-                if transposed_weight.dim() == 3
-                else transposed_weight.flip(-1, -2)
-            )
+        flipped_weight = super().call_operator(
+            exir_ops.edge.aten.flip.default,
+            (
+                transposed_weight,
+                [-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
+            ),
+            kwargs,
+            meta,
         )
 
-        # From the previous checks, if flipped_weight is a FakeTensor, it has to be
-        # a constant (if not, it would be a ProxyValue). Mark it as such.
-        if isinstance(flipped_weight, FakeTensor):
-            flipped_weight.constant = flipped_weight
         new_args = (
             in_tensor,
             flipped_weight,
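Note on this hunk: the pass now unconditionally emits transpose_copy and flip graph ops through call_operator instead of branching on whether the weight is a ProxyValue. The underlying weight transform is the usual conversion from the ConvTranspose layout (in_channels, out_channels, kernel) to the conv layout (out_channels, in_channels, kernel) with the kernel reversed. A minimal standalone sketch of that identity for the stride=1, dilation=1, groups=1 case (shapes here are illustrative and not taken from the commit):

import torch
import torch.nn.functional as F

x = torch.randn(1, 8, 18)
w = torch.randn(8, 16, 5)  # ConvTranspose1d layout: (in_channels, out_channels, kernel)
padding = 2

expected = F.conv_transpose1d(x, w, padding=padding)
# Swap the channel dims and reverse the kernel, as the pass does with
# transpose_copy + flip, then run a regular convolution.
w_conv = w.transpose(0, 1).flip(-1)  # conv1d layout: (out_channels, in_channels, kernel)
actual = F.conv1d(x, w_conv, padding=w.shape[-1] - 1 - padding)
assert torch.allclose(expected, actual, atol=1e-5)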
@@ -751,16 +735,10 @@ def call_operator(self, op, args, kwargs, meta):
         # Verify that output_padding is 0.
         assert all(
             x == 0 for x in output_padding
-        ), "Cannot handle padded output in convolution"
+        ), f"Cannot handle padded output in convolution. Got {output_padding=}"
 
-        # If the innermost dim of output tensor is 1, then the stride
-        # should be 1. Note that the first dimension of output tensor is
-        # channel
-        new_stride = stride.copy()
-        out_shape = meta["val"].shape
-        assert out_shape is not None
-        for i, e in enumerate(out_shape[2:]):
-            new_stride[i] = 1 if e == 1 else stride[i]
+        # Keep the original stride to maintain correct output dimensions
+        new_stride = stride
 
         new_args = (
             in_tensor,
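Note on this hunk: the deleted heuristic rewrote a stride entry to 1 whenever the corresponding output dim was 1, but stride enters the output-size formula out = floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1, so rewriting it can change the op's actual output shape away from the one recorded in meta["val"]. The new "# Stride of 2 needed." test case below exercises exactly such a shape; a sketch of the mismatch using its parameters (checked against the standard conv1d formula, not against the pass itself):

import torch
import torch.nn.functional as F

# Shapes from the new test case [(1, 8, 3), 8, 8, 48, 2, 23]:
# out = floor((3 + 2*23 - (48 - 1) - 1) / stride) + 1
x = torch.randn(1, 8, 3)
w = torch.randn(8, 8, 48)
print(F.conv1d(x, w, stride=2, padding=23).shape)  # torch.Size([1, 8, 1])
print(F.conv1d(x, w, stride=1, padding=23).shape)  # torch.Size([1, 8, 2]) -- forcing stride 1 changes the shape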

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 190 additions & 1 deletion
@@ -52,9 +52,10 @@
 
 from executorch.backends.cadence.aot.typing_stubs import expand
 from executorch.exir.dialects._ops import ops as exir_ops
-from executorch.exir.pass_base import ExportPass
+from executorch.exir.pass_base import ExportPass, ProxyValue
 from executorch.exir.passes import dead_code_elimination_pass
 from torch.fx.passes.infra.pass_base import PassResult
+from torch.utils import _pytree as pytree
 
 
 class TestReplaceOpsPasses(unittest.TestCase):
@@ -345,6 +346,194 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(
             count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x
         )
 
+    def assertTensorMetadataIsSame(
+        self, a: Sequence[torch.Tensor], b: Sequence[torch.Tensor]
+    ) -> None:
+        for i, (_a, _b) in enumerate(zip(a, b)):
+            # TODO: actually compare the tensors.
+            self.assertTrue(
+                _a.shape == _b.shape, f"Tensor {i}: {_a.shape} != {_b.shape}"
+            )
+            self.assertTrue(
+                _a.dtype == _b.dtype, f"Tensor {i}: {_a.dtype} != {_b.dtype}"
+            )
+
+    @expand(
+        [
+            [(1, 8, 18), 8, 16, 3],
+            [(1, 8, 18), 8, 16, 5, 2],
+            # depthwise + bias
+            [(1, 8, 18), 8, 16, 5, 2, 0, 1, True],
+            # no bias
+            [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],
+            # bias + transposed
+            [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],
+            # Stride of 2 needed.
+            [(1, 8, 3), 8, 8, 48, 2, 23],
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_aten_conv_with_cadence_conv(
+        self,
+        shape: Tuple[int, ...],
+        in_channels: int,
+        out_channels: int,
+        kernel: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        depthwise: bool = False,
+        bias_enabled: bool = True,
+        output_padding: Optional[int] = None,
+    ) -> None:
+        groups = in_channels if depthwise else 1
+        builder = GraphBuilder()
+        x_tensor = torch.randn(*shape, dtype=torch.float32)
+        x = builder.placeholder("x", x_tensor)
+        weights_tensor = torch.randn(
+            [out_channels, in_channels // groups, kernel], dtype=torch.float32
+        )
+        weights = builder.placeholder("weights", weights_tensor)
+        bias: Optional[ProxyValue] = None
+        bias_tensor: Optional[torch.Tensor] = None
+        if bias_enabled:
+            bias_tensor = torch.randn([out_channels], dtype=torch.float32)
+            bias = builder.placeholder("bias", bias_tensor)
+        convolution = builder.call_operator(
+            op=exir_ops.edge.aten.convolution.default,
+            args=(
+                x,
+                weights,
+                bias,
+                [stride],
+                [padding],
+                [dilation],
+                False,
+                [output_padding] if output_padding else [0],
+                groups,
+            ),
+        )
+        builder.output([convolution])
+        original_gm = builder.get_graph_module()
+
+        replacement_pass_result = (
+            ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)
+        )
+        self.assertIsNotNone(replacement_pass_result)
+        graph_after_passes = replacement_pass_result.graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
+            1,
+        )
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default
+            ),
+            0,
+        )
+
+        inputs = (x.to_tensor(), weights.to_tensor())
+        if bias is not None:
+            inputs += (bias.to_tensor(),)
+        self.assertTensorMetadataIsSame(
+            pytree.tree_flatten(original_gm.forward(*inputs))[0],
+            pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],
+        )
+
+    @expand(
+        [
+            [(1, 8, 18), 8, 16, 3],
+            [(1, 8, 18), 8, 16, 5, 2],
+            # depthwise + bias
+            [(1, 8, 18), 8, 16, 5, 2, 0, 1, True, True],
+            # no bias
+            [(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],
+            # depthwise + no bias
+            [(1, 8, 18), 8, 16, 3, 1, 0, 1, True, False],
+            # bias
+            [(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],
+        ]
+    )
+    @torch.no_grad()
+    def test_replace_aten_transposed_conv_with_cadence_transposed_conv(
+        self,
+        shape: Tuple[int, ...],
+        in_channels: int,
+        out_channels: int,
+        kernel: int,
+        stride: int = 1,
+        padding: int = 0,
+        dilation: int = 1,
+        depthwise: bool = False,
+        bias_enabled: bool = True,
+        output_padding: Optional[int] = None,
+    ) -> None:
+        groups = in_channels if depthwise else 1
+        builder = GraphBuilder()
+        x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
+        weights_shape = [in_channels, out_channels // groups, kernel]
+        weights = builder.placeholder(
+            "weights",
+            torch.randn(weights_shape, dtype=torch.float32),
+        )
+        bias = (
+            builder.placeholder(
+                "bias", torch.randn([out_channels], dtype=torch.float32)
+            )
+            if bias_enabled
+            else None
+        )
+        convolution = builder.call_operator(
+            op=exir_ops.edge.aten.convolution.default,
+            args=(
+                x,
+                weights,
+                bias,
+                [stride],
+                [padding],
+                [dilation],
+                True,
+                [output_padding] if output_padding else [0],
+                groups,
+            ),
+        )
+        builder.output([convolution])
+        original_gm = builder.get_graph_module()
+
+        replacement_pass_result = (
+            ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)
+        )
+        self.assertIsNotNone(replacement_pass_result)
+        graph_after_passes = replacement_pass_result.graph_module
+
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
+            0,
+        )
+        self.assertEqual(
+            count_node(
+                graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default
+            ),
+            1,
+        )
+
+        inputs = (x.to_tensor(), weights.to_tensor())
+        if bias is not None:
+            inputs += (bias.to_tensor(),)
+        self.assertTensorMetadataIsSame(
+            pytree.tree_flatten(original_gm.forward(*inputs))[0],
+            pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],
+        )
+
     @expand(
         [
             [(1, 8, 33), 8, 16, 3],
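Note on the new tests: both compare the eager outputs of the original and rewritten graph modules by flattening whatever forward returns with torch.utils._pytree (a private PyTorch utility, imported in the first hunk) and checking the shape and dtype of each leaf via assertTensorMetadataIsSame. A small illustration of what tree_flatten yields (the example values are hypothetical):

import torch
from torch.utils import _pytree as pytree

# tree_flatten returns (leaves, treespec); the tests compare only the leaves.
outputs = [torch.ones(2, 3), (torch.zeros(4),)]
leaves, _spec = pytree.tree_flatten(outputs)
print([t.shape for t in leaves])  # [torch.Size([2, 3]), torch.Size([4])]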
