Skip to content

Commit fcb9618

Browse files
linting
Signed-off-by: cliu-us <[email protected]>
1 parent 496cc5d commit fcb9618

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434

3535

3636
def run_fwd_once(model, sample_inp):
37+
"""Convenient function to run model once using correct input unpack."""
3738
with torch.no_grad():
3839
if isinstance(sample_inp, dict) or all(
3940
hasattr(sample_inp, k) for k in ("keys", "values", "items")
@@ -252,7 +253,7 @@ def _dfs(curr_node, depth):
252253
def find_conv_on_shortcut_gm(
253254
gm: torch.fx.GraphModule,
254255
lut_fx_mod_name_to_org: Optional[Dict[str, str]] = None,
255-
lut_name_to_mod={},
256+
lut_name_to_mod=None,
256257
):
257258
"""Identify Conv on shortcut using FX GM DFS
258259
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -277,6 +278,9 @@ def find_conv_on_shortcut_gm(
277278
5. count levels of each branch, decide which one is the shortcut
278279
"""
279280

281+
if lut_name_to_mod is None:
282+
lut_name_to_mod = {}
283+
280284
# 1. Find "add" nodes, including inplace add as some may use "out+=shortcut"
281285
nodes_add = dfs_gm(gm, ["add"], return_nodes=True)
282286

fms_mo/fx/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ def get_org_mod_name_of_fx_node(
343343
str: corresponding name on original graph
344344
"""
345345
org_name = f"Unknown:{node.name}"
346-
if lut_fx2org == None:
346+
if lut_fx2org is None:
347347
lut_fx2org = {}
348348
if "nn_module_stack" in node.meta:
349349
n_fx_mod_name = list(node.meta["nn_module_stack"].keys())[-1]

0 commit comments

Comments
 (0)