Skip to content

Commit 39f1f85

Browse files
variable renaming
1 parent 1b7e829 commit 39f1f85

File tree

2 files changed

+37
-37
lines changed

2 files changed

+37
-37
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def dfs_gm(
4646
prescreenOp=None,
4747
hook=None,
4848
return_nodes=False,
49-
LUTfx_mod_name_to_org={},
49+
lut_fx_mod_name_to_org={},
5050
):
5151
"""Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5252
Because FX IR is a higher level IR, should have much fewer
@@ -217,7 +217,7 @@ def _dfs(curr_node, depth):
217217
for n_ln, d in node_found.items():
218218
n, line_num = n_ln # unpack tuple
219219
org_mod_names[
220-
get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org), line_num
220+
get_org_mod_name_of_fx_node(n, gm, lut_fx_mod_name_to_org), line_num
221221
] = d # see Note 2
222222

223223
return dict(
@@ -227,7 +227,7 @@ def _dfs(curr_node, depth):
227227
return node_found
228228

229229

230-
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={}):
230+
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, lut_fx_mod_name_to_org={}):
231231
"""Identify Conv on shortcut using FX GM DFS
232232
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
233233
original model, not FX module names)
@@ -335,18 +335,18 @@ def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={})
335335
conv_mod = gm.get_submodule(n_conv_i.target)
336336
else:
337337
conv_mod = get_org_mod_name_of_fx_node(
338-
n_conv_i, LUTfx2org=LUTfx_mod_name_to_org
338+
n_conv_i, lut_fx2org=lut_fx_mod_name_to_org
339339
)
340340
if conv_mod.out_channels > conv_mod.in_channels: # see Note 2
341341
qconv_candidate.append(
342-
get_org_mod_name_of_fx_node(n_conv_i, gm, LUTfx_mod_name_to_org)
342+
get_org_mod_name_of_fx_node(n_conv_i, gm, lut_fx_mod_name_to_org)
343343
)
344344

345345
return qconv_candidate
346346

347347

348348
def find_1st_last_gm(
349-
gm, firstOps=None, lastOps=None, return_1st_last_sep=False, LUTfx_mod_name_to_org={}
349+
gm, firstOps=None, lastOps=None, return_1st_last_sep=False, lut_fx_mod_name_to_org={}
350350
):
351351
"""Identify the first and last layer of interests
352352
Usually only interested in Conv and Linear, but could be others as well
@@ -366,14 +366,14 @@ def find_1st_last_gm(
366366
gm,
367367
targetOp=firstOps,
368368
find1stOnly=True,
369-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
369+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
370370
)
371371
last_candidates = dfs_gm(
372372
gm,
373373
targetOp=lastOps,
374374
find1stOnly=True,
375375
reverse=True,
376-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
376+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
377377
)
378378

379379
min_depth = min(list(first_candidates.values()) + [999])
@@ -397,7 +397,7 @@ def find_1st_last_gm(
397397

398398

399399
def find_single_sided_op_gm(
400-
gm, op_of_interest=None, return_LUTs=False, verbose=False, LUTfx_mod_name_to_org={}
400+
gm, op_of_interest=None, return_LUTs=False, verbose=False, lut_fx_mod_name_to_org={}
401401
):
402402
"""Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
403403
itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -546,7 +546,7 @@ def find_single_sided_op_gm(
546546
gm,
547547
targetOp=op_of_interest,
548548
return_nodes=True,
549-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
549+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
550550
)
551551

552552
SingleSidedOps = []
@@ -557,7 +557,7 @@ def find_single_sided_op_gm(
557557
risky_nodes.append(n)
558558
if all(input_pos):
559559
SingleSidedOps.append(
560-
get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org)
560+
get_org_mod_name_of_fx_node(n, gm, lut_fx_mod_name_to_org)
561561
)
562562

563563
if risky_nodes:
@@ -570,7 +570,7 @@ def find_single_sided_op_gm(
570570
return SingleSidedOps
571571

572572

573-
def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={}):
573+
def find_qkvsync_candidates_gm(gm, return_nodes=False, lut_fx_mod_name_to_org={}):
574574
"""Identify groups of Linears that share the same parent. It's a transformer-specific feature.
575575
576576
NOTE:
@@ -609,7 +609,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
609609
for depth, nodes in LUTdep2nodes.items():
610610
parents = [ni.all_input_nodes[0] for ni in nodes]
611611
org_mod_names = [
612-
get_org_mod_name_of_fx_node(ni, gm, LUTfx_mod_name_to_org) for ni in nodes
612+
get_org_mod_name_of_fx_node(ni, gm, lut_fx_mod_name_to_org) for ni in nodes
613613
]
614614
if all(p == parents[0] for p in parents[1:]):
615615
Nshared_parents += 1
@@ -620,7 +620,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
620620
return my_1st_sibling
621621

622622

623-
def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
623+
def find_silu_gm(gm, lut_fx_mod_name_to_org={}):
624624
"""Special handle for Conv following silu, specific for EffDet and etc
625625
LLM could use SiLU as well (llama?), but not relavent to this func
626626
"""
@@ -633,14 +633,14 @@ def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
633633
gpOp = get_target_op_from_node(gp_nodes[0], gm) if gp_nodes else None
634634

635635
if torch.nn.functional.silu in [pOp, gpOp]:
636-
siluConv[get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org)] = {
636+
siluConv[get_org_mod_name_of_fx_node(n, gm, lut_fx_mod_name_to_org)] = {
637637
"qa_mode": "qsilu"
638638
}
639639

640640
return siluConv
641641

642642

643-
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
643+
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, lut_fx_mod_name_to_org={}):
644644
"""For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
645645
are commonly used. prefer to skip them, but may be ok to quantize in some cases.
646646
@@ -703,7 +703,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
703703
targetOp=[torch.nn.Conv2d],
704704
start_nodes=fpn_st_nodes,
705705
stop_nodes=[fpn_end_node],
706-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
706+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
707707
)
708708
fpn_convs = [mod_name for mod_name, ln in fpn_convs.keys()] # see Note 4
709709
fpn_adds = dfs_gm(
@@ -730,7 +730,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
730730
):
731731
fpn_inner_blocks.append(
732732
get_org_mod_name_of_fx_node(
733-
gp, LUTfx2org=LUTfx_mod_name_to_org
733+
gp, lut_fx2org=lut_fx_mod_name_to_org
734734
)
735735
)
736736
fpn_convs += fpn_inner_blocks
@@ -744,7 +744,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
744744
return fpn_convs
745745

746746

747-
def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
747+
def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org={}):
748748
"""Previously with TorchScript, we use this func to perform 2 tasks:
749749
a) create QBmms, and then attach them to the model,
750750
b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -798,7 +798,7 @@ def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
798798
LUTmodname2linenum = {} # see Note 4
799799
for node_line_num, depth in LUT2sort.items():
800800
node, line_num = node_line_num
801-
org_mod_name = get_org_mod_name_of_fx_node(node, gm, LUTfx_mod_name_to_org)
801+
org_mod_name = get_org_mod_name_of_fx_node(node, gm, lut_fx_mod_name_to_org)
802802
if org_mod_name in LUTmodname2linenum:
803803
LUTmodname2linenum[org_mod_name] += [(node, line_num, depth)]
804804
else:
@@ -941,7 +941,7 @@ def cus_backend_model_analyzer(
941941
942942
"""
943943
qcfg["N_backend_called"] += 1
944-
LUTfx_mod_name_to_org = {
944+
lut_fx_mod_name_to_org = {
945945
n.replace(".weight", ""): LUTweight2modname[p]
946946
for n, p in gm_fx.named_parameters()
947947
if p in LUTweight2modname
@@ -954,7 +954,7 @@ def cus_backend_model_analyzer(
954954
if isinstance(mod, (torch.nn.Linear, torch.nn.Conv2d)):
955955
real_org_modname = LUTweight2modname[mod.weight]
956956
part_org_modname = get_org_mod_name_of_fx_node(
957-
n, gm_fx, LUTfx_mod_name_to_org
957+
n, gm_fx, lut_fx_mod_name_to_org
958958
)
959959
idx = real_org_modname.rindex(part_org_modname)
960960
if idx > 1:
@@ -972,7 +972,7 @@ def cus_backend_model_analyzer(
972972
outputname=plotsvg,
973973
show_details=True,
974974
Nnode_to_plot=1000,
975-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
975+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
976976
)
977977

978978
# Graph checks begin. Use append, add prefix if needed
@@ -984,7 +984,7 @@ def cus_backend_model_analyzer(
984984
if isinstance(m, torch.nn.Conv2d) or issubclass(type(m), torch.nn.Conv2d)
985985
]
986986
if len(all_conv) > 0:
987-
skip_candidates += find_conv_on_shortcut_gm(gm_fx, LUTfx_mod_name_to_org)
987+
skip_candidates += find_conv_on_shortcut_gm(gm_fx, lut_fx_mod_name_to_org)
988988

989989
# Check 2. first/last, see Note 2 and 3
990990
if qcfg["N_backend_called"] > 1:
@@ -993,19 +993,19 @@ def cus_backend_model_analyzer(
993993
_, last_only = find_1st_last_gm(
994994
gm_fx,
995995
return_1st_last_sep=True,
996-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
996+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
997997
)
998998
skip_candidates += last_only
999999
else:
10001000
# see Note 4
10011001
skip_candidates += find_1st_last_gm(
1002-
gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org
1002+
gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org
10031003
)
10041004
qcfg["qskip_layer_name"] += add_prefix_to_list_or_dict(skip_candidates, prefix)
10051005

10061006
# Check 3: single/double sided
10071007
qcfg["qsinglesided_name"] += add_prefix_to_list_or_dict(
1008-
find_single_sided_op_gm(gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org),
1008+
find_single_sided_op_gm(gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org),
10091009
prefix,
10101010
)
10111011

@@ -1016,12 +1016,12 @@ def cus_backend_model_analyzer(
10161016
# Check 5: Conv+SiLU
10171017
qcfg["qspecial_layers"].update(
10181018
add_prefix_to_list_or_dict(
1019-
find_silu_gm(gm_fx, LUTfx_mod_name_to_org), prefix
1019+
find_silu_gm(gm_fx, lut_fx_mod_name_to_org), prefix
10201020
)
10211021
)
10221022

10231023
# Check 6: BMM
1024-
temp_dict = find_and_prep_bmm_gm(gm_fx, LUTfx_mod_name_to_org) # see Note 5
1024+
temp_dict = find_and_prep_bmm_gm(gm_fx, lut_fx_mod_name_to_org) # see Note 5
10251025
if len(temp_dict["layers_with_bmm"]) > 0:
10261026
temp_dict["layers_with_bmm"] = add_prefix_to_list_or_dict(
10271027
temp_dict["layers_with_bmm"], prefix
@@ -1033,7 +1033,7 @@ def cus_backend_model_analyzer(
10331033

10341034
# Check 7: QKV
10351035
temp_dict = find_qkvsync_candidates_gm(
1036-
gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org
1036+
gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org
10371037
) # see Note 6
10381038
temp_dict = add_prefix_to_list_or_dict(
10391039
temp_dict, prefix, update_both_k_and_v=True

fms_mo/fx/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def lname_to_org_name(Lname):
317317
return org_mod_name
318318

319319

320-
def get_org_mod_name_of_fx_node(node, gm=None, LUTfx2org={}):
320+
def get_org_mod_name_of_fx_node(node, gm=None, lut_fx2org={}):
321321
"""Given a FX node, could be call_module or call_fuction, find out the original module name,
322322
based on meta data
323323
@@ -335,7 +335,7 @@ def get_org_mod_name_of_fx_node(node, gm=None, LUTfx2org={}):
335335
node (fx.node): fx node of interest
336336
gm (GraphModule, optional): FX graph containing the given fx node. could be useful when
337337
parsing the node name
338-
LUTfx2org (dict, optional): LUT from fx module name to original module name
338+
lut_fx2org (dict, optional): LUT from fx module name to original module name
339339
340340
Returns:
341341
str: corresponding name on original graph
@@ -344,13 +344,13 @@ def get_org_mod_name_of_fx_node(node, gm=None, LUTfx2org={}):
344344
if "nn_module_stack" in node.meta:
345345
n_fx_mod_name = list(node.meta["nn_module_stack"].keys())[-1]
346346
n_fx_org_mod_name = list(node.meta["nn_module_stack"].values())[-1][0]
347-
if n_fx_mod_name in LUTfx2org:
348-
org_name = LUTfx2org[n_fx_mod_name]
347+
if n_fx_mod_name in lut_fx2org:
348+
org_name = lut_fx2org[n_fx_mod_name]
349349
elif gm and isinstance(node.target, str):
350350
LUT = gm.meta.get("dynamo_flat_name_to_original_fqn", {}) # see Note 2
351351
org_name = LUT.get(node.target, None)
352352
else:
353-
for k, v in LUTfx2org.items():
353+
for k, v in lut_fx2org.items():
354354
if k.startswith(n_fx_mod_name):
355355
suffix = k[len(n_fx_mod_name) :]
356356
suffix = "." + suffix[1:] # replace leading "_" with "."
@@ -489,7 +489,7 @@ def plot_graph_module(
489489
skip_nodes=None,
490490
Nnode_to_plot=None,
491491
additional_coloring_rules=None,
492-
LUTfx_mod_name_to_org={},
492+
lut_fx_mod_name_to_org={},
493493
):
494494
"""Plots a GraphModule in .SVG format to visualize the compute graph. If graphviz/pygraphviz is
495495
not installed properly, this function will just print out a message and do nothing.
@@ -562,7 +562,7 @@ def plot_graph_module(
562562
n_tar += f": {str(node_ptr.target).replace('<','').replace('>','')}"
563563
elif ntype in ["call_module", "get_attr"]:
564564
org_mod_name = get_org_mod_name_of_fx_node(
565-
node_ptr, LUTfx2org=LUTfx_mod_name_to_org
565+
node_ptr, lut_fx2org=lut_fx_mod_name_to_org
566566
)
567567
n_tar += f": {org_mod_name}"
568568
if node_ptr.target.startswith(fx_mod_name + "_"):

0 commit comments

Comments
 (0)