File tree Expand file tree Collapse file tree 2 files changed +6
-2
lines changed
Expand file tree Collapse file tree 2 files changed +6
-2
lines changed Original file line number Diff line number Diff line change 3434
3535
3636def run_fwd_once (model , sample_inp ):
37+ """Convenient function to run model once using correct input unpack."""
3738 with torch .no_grad ():
3839 if isinstance (sample_inp , dict ) or all (
3940 hasattr (sample_inp , k ) for k in ("keys" , "values" , "items" )
@@ -252,7 +253,7 @@ def _dfs(curr_node, depth):
252253def find_conv_on_shortcut_gm (
253254 gm : torch .fx .GraphModule ,
254255 lut_fx_mod_name_to_org : Optional [Dict [str , str ]] = None ,
255- lut_name_to_mod = {} ,
256+ lut_name_to_mod = None ,
256257):
257258 """Identify Conv on shortcut using FX GM DFS
258259 It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
@@ -277,6 +278,9 @@ def find_conv_on_shortcut_gm(
277278 5. count levels of each branch, decide which one is the shortcut
278279 """
279280
281+ if lut_name_to_mod is None :
282+ lut_name_to_mod = {}
283+
280284 # 1. Find "add" nodes, including inplace add as some may use "out+=shortcut"
281285 nodes_add = dfs_gm (gm , ["add" ], return_nodes = True )
282286
Original file line number Diff line number Diff line change @@ -343,7 +343,7 @@ def get_org_mod_name_of_fx_node(
343343 str: corresponding name on original graph
344344 """
345345 org_name = f"Unknown:{ node .name } "
346- if lut_fx2org == None :
346+ if lut_fx2org is None :
347347 lut_fx2org = {}
348348 if "nn_module_stack" in node .meta :
349349 n_fx_mod_name = list (node .meta ["nn_module_stack" ].keys ())[- 1 ]
You can’t perform that action at this time.
0 commit comments