Skip to content

Commit 16ac11d

Browse files
authored
Fix typo in naive_call_method_expand_pass. (#507)
1 parent cdf9f09 commit 16ac11d

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

graph_net/torch/dim_gen_passes/naive_call_method_expand_pass.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def need_rewrite(self, traced_module: fx.GraphModule) -> bool:
2121
def _node_need_rewrite(self, node) -> bool:
2222
if not (node.op == "call_method"):
2323
return False
24-
if not (node.op == "expand"):
24+
if not (node.target == "expand"):
2525
return False
2626
input_tensor_node = node.args[0]
2727
input_meta = input_tensor_node.meta.get("tensor_meta")

0 commit comments

Comments
 (0)