Skip to content

Commit fd04fe0

Browse files
authored
Update ReplaceConvolutionOptionalArgsWithConcreteArgsPass to work with cadence.convolution
Differential Revision: D82842567 Pull Request resolved: #14445
1 parent a954a75 commit fd04fe0

File tree

2 files changed

+5
-8
lines changed

2 files changed

+5
-8
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -438,11 +438,11 @@ class ReplaceConvolutionOptionalArgsWithConcreteArgsPass(ExportPass):
438438
"""
439439

440440
def call_operator(self, op, args, kwargs, meta):
441-
if get_edge_overload_packet(op) != exir_ops.edge.aten.convolution:
441+
if get_edge_overload_packet(op) != exir_ops.edge.cadence.convolution:
442442
return super().call_operator(op, args, kwargs, meta)
443443

444444
# Check if the bias is already concrete
445-
assert len(args) == 9
445+
assert len(args) == 8
446446
if args[2] is not None:
447447
return super().call_operator(op, args, kwargs, meta)
448448

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -455,8 +455,6 @@ def test_replace_convolution_optional_args_with_concrete_args(
455455
bias_enabled: bool = True,
456456
channel_last: bool = False,
457457
) -> None:
458-
transposed = True
459-
output_padding = [0]
460458
groups = in_channels if depthwise else 1
461459
builder = GraphBuilder()
462460
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
@@ -477,17 +475,16 @@ def test_replace_convolution_optional_args_with_concrete_args(
477475
args=(x, [0, 2, 1]),
478476
)
479477
convolution = builder.call_operator(
480-
op=exir_ops.edge.aten.convolution.default,
478+
op=exir_ops.edge.cadence.convolution.default,
481479
args=(
482480
x,
483481
weights,
484482
bias,
485483
[stride],
486484
[padding],
487485
[dilation],
488-
transposed,
489-
output_padding,
490486
groups,
487+
False,
491488
),
492489
)
493490
if channel_last:
@@ -504,7 +501,7 @@ def test_replace_convolution_optional_args_with_concrete_args(
504501
1,
505502
)
506503
self.assertEqual(
507-
count_node(graph_after_passes, exir_ops.edge.aten.convolution.default),
504+
count_node(graph_after_passes, exir_ops.edge.cadence.convolution.default),
508505
1,
509506
)
510507

0 commit comments

Comments
 (0)