diff --git a/graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py b/graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py index e372cc84c..99fb7f552 100644 --- a/graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py +++ b/graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py @@ -21,7 +21,7 @@ def need_rewrite(self, traced_module: fx.GraphModule) -> bool: def _node_need_rewrite(self, node) -> bool: if not (node.op == "call_method"): return False - if not (node.op == "expand"): + if not (node.target == "expand"): return False input_tensor_node = node.args[0] input_meta = input_tensor_node.meta.get("tensor_meta")