Skip to content

Commit 4cd2eef

Browse files
hsharma35facebook-github-bot
authored andcommitted
Update conv replacement pass + tests. (pytorch#14363)
Summary: Fix aten->cadence convolution conversion for `wav2letter` torchaudio model. Differential Revision: D82577503
1 parent c1b7ec5 commit 4cd2eef

File tree

2 files changed

+210
-43
lines changed

2 files changed

+210
-43
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 20 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -693,43 +693,27 @@ def call_operator(self, op, args, kwargs, meta):
693693
# graph operation (in this case a transpose_copy op) to be an explicit
694694
# ProxyValue as well. If not, the view op can be done directly on the
695695
# tensor.
696-
transposed_weight = (
697-
super().call_operator(
698-
exir_ops.edge.aten.transpose_copy.int,
699-
(
700-
weight,
701-
0,
702-
1,
703-
),
704-
kwargs,
705-
meta,
706-
)
707-
if isinstance(weight, ProxyValue)
708-
else weight.transpose(0, 1)
696+
transposed_weight = super().call_operator(
697+
exir_ops.edge.aten.transpose_copy.int,
698+
(
699+
weight,
700+
0,
701+
1,
702+
),
703+
kwargs,
704+
meta,
709705
)
710706

711-
flipped_weight = (
712-
super().call_operator(
713-
exir_ops.edge.aten.flip.default,
714-
(
715-
transposed_weight,
716-
[-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
717-
),
718-
kwargs,
719-
meta,
720-
)
721-
if isinstance(transposed_weight, ProxyValue)
722-
else (
723-
transposed_weight.flip(-1)
724-
if transposed_weight.dim() == 3
725-
else transposed_weight.flip(-1, -2)
726-
)
707+
flipped_weight = super().call_operator(
708+
exir_ops.edge.aten.flip.default,
709+
(
710+
transposed_weight,
711+
[-1] if transposed_weight.to_tensor().dim() == 3 else [-1, -2],
712+
),
713+
kwargs,
714+
meta,
727715
)
728716

729-
# From the previous checks, if flipped_weight is a FakeTensor, it has to be
730-
# a constant (if not, it would be a ProxyValue). Mark it as such.
731-
if isinstance(flipped_weight, FakeTensor):
732-
flipped_weight.constant = flipped_weight
733717
new_args = (
734718
in_tensor,
735719
flipped_weight,
@@ -745,16 +729,10 @@ def call_operator(self, op, args, kwargs, meta):
745729
# Verify that output_padding is 0.
746730
assert all(
747731
x == 0 for x in output_padding
748-
), "Cannot handle padded output in convolution"
732+
), f"Cannot handle padded output in convolution. Got {output_padding=}"
749733

750-
# If the innermost dim of output tensor is 1, then the stride
751-
# should be 1. Note that the first dimension of output tensor is
752-
# channel
753-
new_stride = stride.copy()
754-
out_shape = meta["val"].shape
755-
assert out_shape is not None
756-
for i, e in enumerate(out_shape[2:]):
757-
new_stride[i] = 1 if e == 1 else stride[i]
734+
# Keep the original stride to maintain correct output dimensions
735+
new_stride = stride
758736

759737
new_args = (
760738
in_tensor,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 190 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@
5252

5353
from executorch.backends.cadence.aot.typing_stubs import expand
5454
from executorch.exir.dialects._ops import ops as exir_ops
55-
from executorch.exir.pass_base import ExportPass
55+
from executorch.exir.pass_base import ExportPass, ProxyValue
5656
from executorch.exir.passes import dead_code_elimination_pass
5757
from torch.fx.passes.infra.pass_base import PassResult
58+
from torch.utils import _pytree as pytree
5859

5960

6061
class TestReplaceOpsPasses(unittest.TestCase):
@@ -345,6 +346,194 @@ def test_replace_functionally_equivalent_op_targets_unsafe_split(
345346
count_node(graph_after_passes, exir_ops.edge.aten.unsafe_split.Tensor), 0, x
346347
)
347348

349+
def assertTensorMetadataIsSame(
350+
self, a: Sequence[torch.Tensor], b: Sequence[torch.Tensor]
351+
) -> None:
352+
for i, (_a, _b) in enumerate(zip(a, b)):
353+
# TODO: actually compare the tensors.
354+
self.assertTrue(
355+
_a.shape == _b.shape, f"Tensor {i}: {_a.shape} != {_b.shape}"
356+
)
357+
self.assertTrue(
358+
_a.dtype == _b.dtype, f"Tensor {i}: {_a.dtype} != {_b.dtype}"
359+
)
360+
361+
@expand(
362+
[
363+
[(1, 8, 18), 8, 16, 3],
364+
[(1, 8, 18), 8, 16, 5, 2],
365+
# depthwise + bias
366+
[(1, 8, 18), 8, 16, 5, 2, 0, 1, True],
367+
# no bias
368+
[(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],
369+
# bias + transposed
370+
[(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],
371+
# Stride of 2 needed.
372+
[(1, 8, 3), 8, 8, 48, 2, 23],
373+
]
374+
)
375+
@torch.no_grad()
376+
def test_replace_aten_conv_with_cadence_conv(
377+
self,
378+
shape: Tuple[int, ...],
379+
in_channels: int,
380+
out_channels: int,
381+
kernel: int,
382+
stride: int = 1,
383+
padding: int = 0,
384+
dilation: int = 1,
385+
depthwise: bool = False,
386+
bias_enabled: bool = True,
387+
output_padding: Optional[int] = None,
388+
) -> None:
389+
groups = in_channels if depthwise else 1
390+
builder = GraphBuilder()
391+
x_tensor = torch.randn(*shape, dtype=torch.float32)
392+
x = builder.placeholder("x", x_tensor)
393+
weights_tensor = torch.randn(
394+
[out_channels, in_channels // groups, kernel], dtype=torch.float32
395+
)
396+
weights = builder.placeholder("weights", weights_tensor)
397+
bias: Optional[ProxyValue] = None
398+
bias_tensor: Optional[torch.Tensor] = None
399+
if bias_enabled:
400+
bias_tensor = torch.randn([out_channels], dtype=torch.float32)
401+
bias = builder.placeholder("bias", bias_tensor)
402+
convolution = builder.call_operator(
403+
op=exir_ops.edge.aten.convolution.default,
404+
args=(
405+
x,
406+
weights,
407+
bias,
408+
[stride],
409+
[padding],
410+
[dilation],
411+
False,
412+
[output_padding] if output_padding else [0],
413+
groups,
414+
),
415+
)
416+
builder.output([convolution])
417+
original_gm = builder.get_graph_module()
418+
419+
replacement_pass_result = (
420+
ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)
421+
)
422+
self.assertIsNotNone(replacement_pass_result)
423+
graph_after_passes = replacement_pass_result.graph_module
424+
425+
self.assertEqual(
426+
count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
427+
0,
428+
)
429+
self.assertEqual(
430+
count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
431+
1,
432+
)
433+
self.assertEqual(
434+
count_node(
435+
graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default
436+
),
437+
0,
438+
)
439+
440+
inputs = (x.to_tensor(), weights.to_tensor())
441+
if bias is not None:
442+
inputs += (bias.to_tensor(),)
443+
self.assertTensorMetadataIsSame(
444+
pytree.tree_flatten(original_gm.forward(*inputs))[0],
445+
pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],
446+
)
447+
448+
@expand(
449+
[
450+
[(1, 8, 18), 8, 16, 3],
451+
[(1, 8, 18), 8, 16, 5, 2],
452+
# depthwise + bias
453+
[(1, 8, 18), 8, 16, 5, 2, 0, 1, True, True],
454+
# no bias
455+
[(1, 8, 18), 8, 16, 3, 2, 4, 3, False, False],
456+
# depthwise + no bias
457+
[(1, 8, 18), 8, 16, 3, 1, 0, 1, True, False],
458+
# bias
459+
[(1, 8, 18), 8, 16, 5, 2, 0, 1, False, True],
460+
]
461+
)
462+
@torch.no_grad()
463+
def test_replace_aten_transposed_conv_with_cadence_transposed_conv(
464+
self,
465+
shape: Tuple[int, ...],
466+
in_channels: int,
467+
out_channels: int,
468+
kernel: int,
469+
stride: int = 1,
470+
padding: int = 0,
471+
dilation: int = 1,
472+
depthwise: bool = False,
473+
bias_enabled: bool = True,
474+
output_padding: Optional[int] = None,
475+
) -> None:
476+
groups = in_channels if depthwise else 1
477+
builder = GraphBuilder()
478+
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
479+
weights_shape = [in_channels, out_channels // groups, kernel]
480+
weights = builder.placeholder(
481+
"weights",
482+
torch.randn(weights_shape, dtype=torch.float32),
483+
)
484+
bias = (
485+
builder.placeholder(
486+
"bias", torch.randn([out_channels], dtype=torch.float32)
487+
)
488+
if bias_enabled
489+
else None
490+
)
491+
convolution = builder.call_operator(
492+
op=exir_ops.edge.aten.convolution.default,
493+
args=(
494+
x,
495+
weights,
496+
bias,
497+
[stride],
498+
[padding],
499+
[dilation],
500+
True,
501+
[output_padding] if output_padding else [0],
502+
groups,
503+
),
504+
)
505+
builder.output([convolution])
506+
original_gm = builder.get_graph_module()
507+
508+
replacement_pass_result = (
509+
ReplaceAtenConvolutionWithCadenceConvolutionPass().call(original_gm)
510+
)
511+
self.assertIsNotNone(replacement_pass_result)
512+
graph_after_passes = replacement_pass_result.graph_module
513+
514+
self.assertEqual(
515+
count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
516+
0,
517+
)
518+
self.assertEqual(
519+
count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
520+
0,
521+
)
522+
self.assertEqual(
523+
count_node(
524+
graph_after_passes, exir_ops.edge.cadence.transposed_convolution.default
525+
),
526+
1,
527+
)
528+
529+
inputs = (x.to_tensor(), weights.to_tensor())
530+
if bias is not None:
531+
inputs += (bias.to_tensor(),)
532+
self.assertTensorMetadataIsSame(
533+
pytree.tree_flatten(original_gm.forward(*inputs))[0],
534+
pytree.tree_flatten(graph_after_passes.forward(*inputs))[0],
535+
)
536+
348537
@expand(
349538
[
350539
[(1, 8, 33), 8, 16, 3],

0 commit comments

Comments
 (0)