diff --git a/backends/cadence/aot/remove_ops.py b/backends/cadence/aot/remove_ops.py
index 97cb1ae49d1..996dfa43f8f 100644
--- a/backends/cadence/aot/remove_ops.py
+++ b/backends/cadence/aot/remove_ops.py
@@ -235,10 +235,7 @@ def call_operator(
         kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
-        if op not in {
-            exir_ops.edge.aten.linalg_vector_norm.default,
-            exir_ops.edge.cadence.linalg_vector_norm.default,
-        }:
+        if op is not exir_ops.edge.aten.linalg_vector_norm.default:
             return super().call_operator(op, args, kwargs, meta)
 
         # If the op has three args or less, it can't be a nop
diff --git a/backends/cadence/aot/tests/test_remove_ops_passes.py b/backends/cadence/aot/tests/test_remove_ops_passes.py
index aed62089cea..72d09911072 100644
--- a/backends/cadence/aot/tests/test_remove_ops_passes.py
+++ b/backends/cadence/aot/tests/test_remove_ops_passes.py
@@ -465,10 +465,7 @@ def forward(self, x: torch.Tensor):
 
         # Expect the linalg_vector_norm op to be removed by the pass
         self.assertEqual(
-            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default)
-            + count_node(
-                graph_module, exir_ops.edge.cadence.linalg_vector_norm.default
-            ),
+            count_node(graph_module, exir_ops.edge.aten.linalg_vector_norm.default),
             0,
         )