Skip to content

Commit 79851eb

Browse files
fix qbmm
Signed-off-by: chichun-charlie-liu <[email protected]>
1 parent 5b5f33f commit 79851eb

File tree

2 files changed

+42
-42
lines changed

2 files changed

+42
-42
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 35 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ def dfs_gm(
4646
prescreenOp=None,
4747
hook=None,
4848
return_nodes=False,
49-
LUTfx_mod_name_to_org={},
49+
lut_fx_mod_name_to_org={},
5050
):
5151
"""Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5252
Because FX IR is a higher level IR, should have much fewer
@@ -217,7 +217,7 @@ def _dfs(curr_node, depth):
217217
for n_ln, d in node_found.items():
218218
n, line_num = n_ln # unpack tuple
219219
org_mod_names[
220-
get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org), line_num
220+
get_org_mod_name_of_fx_node(n, gm, lut_fx_mod_name_to_org), line_num
221221
] = d # see Note 2
222222

223223
return dict(
@@ -227,7 +227,7 @@ def _dfs(curr_node, depth):
227227
return node_found
228228

229229

230-
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={}):
230+
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, lut_fx_mod_name_to_org={}):
231231
"""Identify Conv on shortcut using FX GM DFS
232232
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
233233
original model, not FX module names)
@@ -335,18 +335,18 @@ def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={})
335335
conv_mod = gm.get_submodule(n_conv_i.target)
336336
else:
337337
conv_mod = get_org_mod_name_of_fx_node(
338-
n_conv_i, LUTfx2org=LUTfx_mod_name_to_org
338+
n_conv_i, lut_fx2org=lut_fx_mod_name_to_org
339339
)
340340
if conv_mod.out_channels > conv_mod.in_channels: # see Note 2
341341
qconv_candidate.append(
342-
get_org_mod_name_of_fx_node(n_conv_i, gm, LUTfx_mod_name_to_org)
342+
get_org_mod_name_of_fx_node(n_conv_i, gm, lut_fx_mod_name_to_org)
343343
)
344344

345345
return qconv_candidate
346346

347347

348348
def find_1st_last_gm(
349-
gm, firstOps=None, lastOps=None, return_1st_last_sep=False, LUTfx_mod_name_to_org={}
349+
gm, firstOps=None, lastOps=None, return_1st_last_sep=False, lut_fx_mod_name_to_org={}
350350
):
351351
"""Identify the first and last layer of interests
352352
Usually only interested in Conv and Linear, but could be others as well
@@ -366,14 +366,14 @@ def find_1st_last_gm(
366366
gm,
367367
targetOp=firstOps,
368368
find1stOnly=True,
369-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
369+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
370370
)
371371
last_candidates = dfs_gm(
372372
gm,
373373
targetOp=lastOps,
374374
find1stOnly=True,
375375
reverse=True,
376-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
376+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
377377
)
378378

379379
min_depth = min(list(first_candidates.values()) + [999])
@@ -397,7 +397,7 @@ def find_1st_last_gm(
397397

398398

399399
def find_single_sided_op_gm(
400-
gm, op_of_interest=None, return_LUTs=False, verbose=False, LUTfx_mod_name_to_org={}
400+
gm, op_of_interest=None, return_LUTs=False, verbose=False, lut_fx_mod_name_to_org={}
401401
):
402402
"""Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
403403
itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -546,7 +546,7 @@ def find_single_sided_op_gm(
546546
gm,
547547
targetOp=op_of_interest,
548548
return_nodes=True,
549-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
549+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
550550
)
551551

552552
SingleSidedOps = []
@@ -557,7 +557,7 @@ def find_single_sided_op_gm(
557557
risky_nodes.append(n)
558558
if all(input_pos):
559559
SingleSidedOps.append(
560-
get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org)
560+
get_org_mod_name_of_fx_node(n, gm, lut_fx_mod_name_to_org)
561561
)
562562

563563
if risky_nodes:
@@ -570,7 +570,7 @@ def find_single_sided_op_gm(
570570
return SingleSidedOps
571571

572572

573-
def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={}):
573+
def find_qkvsync_candidates_gm(gm, return_nodes=False, lut_fx_mod_name_to_org={}):
574574
"""Identify groups of Linears that share the same parent. It's a transformer-specific feature.
575575
576576
NOTE:
@@ -609,7 +609,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
609609
for depth, nodes in LUTdep2nodes.items():
610610
parents = [ni.all_input_nodes[0] for ni in nodes]
611611
org_mod_names = [
612-
get_org_mod_name_of_fx_node(ni, gm, LUTfx_mod_name_to_org) for ni in nodes
612+
get_org_mod_name_of_fx_node(ni, gm, lut_fx_mod_name_to_org) for ni in nodes
613613
]
614614
if all(p == parents[0] for p in parents[1:]):
615615
Nshared_parents += 1
@@ -620,7 +620,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={})
620620
return my_1st_sibling
621621

622622

623-
def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
623+
def find_silu_gm(gm, lut_fx_mod_name_to_org={}):
624624
"""Special handle for Conv following silu, specific for EffDet and etc
625625
LLM could use SiLU as well (llama?), but not relevant to this func
626626
"""
@@ -633,14 +633,14 @@ def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
633633
gpOp = get_target_op_from_node(gp_nodes[0], gm) if gp_nodes else None
634634

635635
if torch.nn.functional.silu in [pOp, gpOp]:
636-
siluConv[get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org)] = {
636+
siluConv[get_org_mod_name_of_fx_node(n, gm, lut_fx_mod_name_to_org)] = {
637637
"qa_mode": "qsilu"
638638
}
639639

640640
return siluConv
641641

642642

643-
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
643+
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, lut_fx_mod_name_to_org={}):
644644
"""For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
645645
are commonly used. prefer to skip them, but may be ok to quantize in some cases.
646646
@@ -703,7 +703,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
703703
targetOp=[torch.nn.Conv2d],
704704
start_nodes=fpn_st_nodes,
705705
stop_nodes=[fpn_end_node],
706-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
706+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
707707
)
708708
fpn_convs = [mod_name for mod_name, ln in fpn_convs.keys()] # see Note 4
709709
fpn_adds = dfs_gm(
@@ -730,7 +730,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
730730
):
731731
fpn_inner_blocks.append(
732732
get_org_mod_name_of_fx_node(
733-
gp, LUTfx2org=LUTfx_mod_name_to_org
733+
gp, lut_fx2org=lut_fx_mod_name_to_org
734734
)
735735
)
736736
fpn_convs += fpn_inner_blocks
@@ -744,7 +744,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
744744
return fpn_convs
745745

746746

747-
def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
747+
def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org={}):
748748
"""Previously with TorchScript, we use this func to perform 2 tasks:
749749
a) create QBmms, and then attach them to the model,
750750
b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -798,7 +798,7 @@ def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
798798
LUTmodname2linenum = {} # see Note 4
799799
for node_line_num, depth in LUT2sort.items():
800800
node, line_num = node_line_num
801-
org_mod_name = get_org_mod_name_of_fx_node(node, gm, LUTfx_mod_name_to_org)
801+
org_mod_name = get_org_mod_name_of_fx_node(node, gm, lut_fx_mod_name_to_org)
802802
if org_mod_name in LUTmodname2linenum:
803803
LUTmodname2linenum[org_mod_name] += [(node, line_num, depth)]
804804
else:
@@ -880,7 +880,7 @@ def model_analyzer(
880880
2. Use Dynamo to replace TorchScript tracing in old qmodel_prep(),
881881
882882
NOTE:
883-
1. Will use LUTweight2modname to find the prefix for subgraphs, should graph break. As module
883+
1. Will use lut_weight2modname to find the prefix for subgraphs, should graph break. As module
884884
seems to have extra layer of wrapper from Dynamo, matching module, i.e. id(module), may lead
885885
to incorrect results, matching weights (tensor) should be consistent.
886886
2. For subgraph, we might be getting a partial "original name", such as layer.0.xxx instead of
@@ -894,7 +894,7 @@ def model_analyzer(
894894
"""
895895

896896
qcfg["N_backend_called"] = 0
897-
LUTweight2modname = {
897+
lut_weight2modname = {
898898
mod.weight: n
899899
for n, mod in model.named_modules()
900900
if isinstance(mod, (torch.nn.Linear, torch.nn.Conv2d))
@@ -941,20 +941,20 @@ def cus_backend_model_analyzer(
941941
942942
"""
943943
qcfg["N_backend_called"] += 1
944-
LUTfx_mod_name_to_org = {
945-
n.replace(".weight", ""): LUTweight2modname[p]
944+
lut_fx_mod_name_to_org = {
945+
n.replace(".weight", ""): lut_weight2modname[p]
946946
for n, p in gm_fx.named_parameters()
947-
if p in LUTweight2modname
947+
if p in lut_weight2modname
948948
}
949949
prefix = None
950950
if qcfg["N_backend_called"] > 1: # subgraph found, see Note 2
951951
for n in gm_fx.graph.nodes:
952952
if n.op == "call_module":
953953
mod = gm_fx.get_submodule(n.target)
954954
if isinstance(mod, (torch.nn.Linear, torch.nn.Conv2d)):
955-
real_org_modname = LUTweight2modname[mod.weight]
955+
real_org_modname = lut_weight2modname[mod.weight]
956956
part_org_modname = get_org_mod_name_of_fx_node(
957-
n, gm_fx, LUTfx_mod_name_to_org
957+
n, gm_fx, lut_fx_mod_name_to_org
958958
)
959959
idx = real_org_modname.rindex(part_org_modname)
960960
if idx > 1:
@@ -972,7 +972,7 @@ def cus_backend_model_analyzer(
972972
outputname=plotsvg,
973973
show_details=True,
974974
Nnode_to_plot=1000,
975-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
975+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
976976
)
977977

978978
# Graph checks begin. Use append, add prefix if needed
@@ -984,7 +984,7 @@ def cus_backend_model_analyzer(
984984
if isinstance(m, torch.nn.Conv2d) or issubclass(type(m), torch.nn.Conv2d)
985985
]
986986
if len(all_conv) > 0:
987-
skip_candidates += find_conv_on_shortcut_gm(gm_fx, LUTfx_mod_name_to_org)
987+
skip_candidates += find_conv_on_shortcut_gm(gm_fx, lut_fx_mod_name_to_org)
988988

989989
# Check 2. first/last, see Note 2 and 3
990990
if qcfg["N_backend_called"] > 1:
@@ -993,19 +993,19 @@ def cus_backend_model_analyzer(
993993
_, last_only = find_1st_last_gm(
994994
gm_fx,
995995
return_1st_last_sep=True,
996-
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
996+
lut_fx_mod_name_to_org=lut_fx_mod_name_to_org,
997997
)
998998
skip_candidates += last_only
999999
else:
10001000
# see Note 4
10011001
skip_candidates += find_1st_last_gm(
1002-
gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org
1002+
gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org
10031003
)
10041004
qcfg["qskip_layer_name"] += add_prefix_to_list_or_dict(skip_candidates, prefix)
10051005

10061006
# Check 3: single/double sided
10071007
qcfg["qsinglesided_name"] += add_prefix_to_list_or_dict(
1008-
find_single_sided_op_gm(gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org),
1008+
find_single_sided_op_gm(gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org),
10091009
prefix,
10101010
)
10111011

@@ -1016,12 +1016,12 @@ def cus_backend_model_analyzer(
10161016
# Check 5: Conv+SiLU
10171017
qcfg["qspecial_layers"].update(
10181018
add_prefix_to_list_or_dict(
1019-
find_silu_gm(gm_fx, LUTfx_mod_name_to_org), prefix
1019+
find_silu_gm(gm_fx, lut_fx_mod_name_to_org), prefix
10201020
)
10211021
)
10221022

10231023
# Check 6: BMM
1024-
temp_dict = find_and_prep_bmm_gm(gm_fx, LUTfx_mod_name_to_org) # see Note 5
1024+
temp_dict = find_and_prep_bmm_gm(gm_fx, lut_fx_mod_name_to_org) # see Note 5
10251025
if len(temp_dict["layers_with_bmm"]) > 0:
10261026
temp_dict["layers_with_bmm"] = add_prefix_to_list_or_dict(
10271027
temp_dict["layers_with_bmm"], prefix
@@ -1033,7 +1033,7 @@ def cus_backend_model_analyzer(
10331033

10341034
# Check 7: QKV
10351035
temp_dict = find_qkvsync_candidates_gm(
1036-
gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org
1036+
gm_fx, lut_fx_mod_name_to_org=lut_fx_mod_name_to_org
10371037
) # see Note 6
10381038
temp_dict = add_prefix_to_list_or_dict(
10391039
temp_dict, prefix, update_both_k_and_v=True

fms_mo/fx/utils.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ def lname_to_org_name(Lname):
317317
return org_mod_name
318318

319319

320-
def get_org_mod_name_of_fx_node(node, gm=None, LUTfx2org={}):
320+
def get_org_mod_name_of_fx_node(node, gm=None, lut_fx2org={}):
321321
"""Given a FX node, could be call_module or call_function, find out the original module name,
322322
based on meta data
323323
@@ -335,7 +335,7 @@ def get_org_mod_name_of_fx_node(node, gm=None, LUTfx2org={}):
335335
node (fx.node): fx node of interest
336336
gm (GraphModule, optional): FX graph containing the given fx node. could be useful when
337337
parsing the node name
338-
LUTfx2org (dict, optional): LUT from fx module name to original module name
338+
lut_fx2org (dict, optional): LUT from fx module name to original module name
339339
340340
Returns:
341341
str: corresponding name on original graph
@@ -344,13 +344,13 @@ def get_org_mod_name_of_fx_node(node, gm=None, LUTfx2org={}):
344344
if "nn_module_stack" in node.meta:
345345
n_fx_mod_name = list(node.meta["nn_module_stack"].keys())[-1]
346346
n_fx_org_mod_name = list(node.meta["nn_module_stack"].values())[-1][0]
347-
if n_fx_mod_name in LUTfx2org:
348-
org_name = LUTfx2org[n_fx_mod_name]
347+
if n_fx_mod_name in lut_fx2org:
348+
org_name = lut_fx2org[n_fx_mod_name]
349349
elif gm and isinstance(node.target, str):
350350
LUT = gm.meta.get("dynamo_flat_name_to_original_fqn", {}) # see Note 2
351351
org_name = LUT.get(node.target, None)
352352
else:
353-
for k, v in LUTfx2org.items():
353+
for k, v in lut_fx2org.items():
354354
if k.startswith(n_fx_mod_name):
355355
suffix = k[len(n_fx_mod_name) :]
356356
suffix = "." + suffix[1:] # replace leading "_" with "."
@@ -489,7 +489,7 @@ def plot_graph_module(
489489
skip_nodes=None,
490490
Nnode_to_plot=None,
491491
additional_coloring_rules=None,
492-
LUTfx_mod_name_to_org={},
492+
lut_fx_mod_name_to_org={},
493493
):
494494
"""Plots a GraphModule in .SVG format to visualize the compute graph. If graphviz/pygraphviz is
495495
not installed properly, this function will just print out a message and do nothing.
@@ -562,7 +562,7 @@ def plot_graph_module(
562562
n_tar += f": {str(node_ptr.target).replace('<','').replace('>','')}"
563563
elif ntype in ["call_module", "get_attr"]:
564564
org_mod_name = get_org_mod_name_of_fx_node(
565-
node_ptr, LUTfx2org=LUTfx_mod_name_to_org
565+
node_ptr, lut_fx2org=lut_fx_mod_name_to_org
566566
)
567567
n_tar += f": {org_mod_name}"
568568
if node_ptr.target.startswith(fx_mod_name + "_"):

0 commit comments

Comments
 (0)