diff --git a/exir/passes/remove_noop_pass.py b/exir/passes/remove_noop_pass.py index d9b99556636..a00b3f254dc 100644 --- a/exir/passes/remove_noop_pass.py +++ b/exir/passes/remove_noop_pass.py @@ -104,7 +104,7 @@ def call(self, graph_module: GraphModule) -> PassResult: if node.op != "call_function": continue - if node.target not in (torch.ops.aten._to_copy.default,): + if node.target not in (torch.ops.aten._to_copy.default, torch.ops.aten.clone.default): continue orig_tensor = node.args[0].meta["val"]