Skip to content

Commit 5b5f33f

Browse files
Fix QBmm detection issue caused by the GraphModule (gm) module naming convention change in PyTorch 2.4
Signed-off-by: chichun-charlie-liu <[email protected]>
1 parent cea5bc7 commit 5b5f33f

File tree

4 files changed

+130
-39
lines changed

4 files changed

+130
-39
lines changed

.gitignore

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,8 @@ venv/
3333
# Generated by spelling check
3434
dictionary.dic
3535

36+
# Files generated from running examples
37+
fms_mo.log
38+
data_train/
39+
data_test/
40+
act_scales/

fms_mo/dq.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -65,22 +65,25 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
6565
tokenized dataset
6666
fms_mo_args (fms_mo.training_args.FMSMOArguments): Parameters to use for DQ quantization
6767
output_dir (str): Output directory to write to
68+
NOTE:
69+
Use dynamo tracing instead of torchscript by default. If torchscript is needed, change
70+
1) config_kwargs and 2) use_dynamo in qmodel_prep()
6871
"""
6972
# for attention or kv-cache quantization, need to use eager attention
7073
attn_bits = [
7174
fms_mo_args.nbits_bmm1,
7275
fms_mo_args.nbits_bmm2,
7376
fms_mo_args.nbits_kvcache,
7477
]
75-
if any(attn_bits) != 32:
78+
if any(x != 32 for x in attn_bits):
7679
attn_implementation = "eager"
7780
else:
7881
attn_implementation = None
7982
config_kwargs = {
8083
"cache_dir": model_args.cache_dir,
8184
"revision": model_args.model_revision,
8285
"use_auth_token": True if model_args.use_auth_token else None,
83-
"torchscript": True,
86+
"torchscript": False,
8487
"attn_implementation": attn_implementation,
8588
}
8689
config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
@@ -123,6 +126,7 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
123126
if torch.cuda.is_available():
124127
total_gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
125128
model_size = model_size_Wb(model, unit="GB")
129+
gpu_mem_util_per = model_size / total_gpu_memory
126130

127131
known_large_models = [
128132
"Llama-2-70b",
@@ -134,7 +138,7 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
134138
]
135139
qcfg["large_model"] = any(
136140
name in model_args.model_name_or_path for name in known_large_models
137-
) or (model_size > 0.7 * total_gpu_memory)
141+
) or (gpu_mem_util_per > 0.7)
138142
dev = "cpu" if qcfg["large_model"] else "cuda:0"
139143

140144
if hasattr(model.config, "model_type"):
@@ -184,6 +188,9 @@ def run_dq(model_args, data_args, fms_mo_args, output_dir):
184188
if qcfg["large_model"]:
185189
act_scales = get_act_scales_1gpu(model, dq_dataloader, qcfg)
186190
else:
191+
if gpu_mem_util_per < 0.7:
192+
model.to(dev)
193+
187194
act_scales = get_act_scales(model, dq_dataloader, qcfg)
188195
scale_file = f"{act_scale_directory}/{qcfg['model'].replace('/', '-')}" + ".pt"
189196
torch.save(act_scales, scale_file)

fms_mo/fx/dynamo_utils.py

Lines changed: 91 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def dfs_gm(
4646
prescreenOp=None,
4747
hook=None,
4848
return_nodes=False,
49+
LUTfx_mod_name_to_org={},
4950
):
5051
"""Depth-First Search at FX IR level, to replace our old TorchScript equivalent func
5152
Because FX IR is a higher level IR, should have much fewer
@@ -215,9 +216,9 @@ def _dfs(curr_node, depth):
215216
org_mod_names = {}
216217
for n_ln, d in node_found.items():
217218
n, line_num = n_ln # unpack tuple
218-
org_mod_names[get_org_mod_name_of_fx_node(n, gm), line_num] = (
219-
d # see Note 2
220-
)
219+
org_mod_names[
220+
get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org), line_num
221+
] = d # see Note 2
221222

222223
return dict(
223224
sorted(org_mod_names.items(), key=lambda item: item[1])
@@ -226,7 +227,7 @@ def _dfs(curr_node, depth):
226227
return node_found
227228

228229

229-
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule):
230+
def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule, LUTfx_mod_name_to_org={}):
230231
"""Identify Conv on shortcut using FX GM DFS
231232
It's (almost) specific for ResNet-like CNNs, will return a list of module names (as used in the
232233
original model, not FX module names)
@@ -333,14 +334,20 @@ def find_conv_on_shortcut_gm(gm: torch.fx.GraphModule):
333334
if n_conv_i.op == "call_module":
334335
conv_mod = gm.get_submodule(n_conv_i.target)
335336
else:
336-
conv_mod = get_org_mod_name_of_fx_node(n_conv_i)
337+
conv_mod = get_org_mod_name_of_fx_node(
338+
n_conv_i, LUTfx2org=LUTfx_mod_name_to_org
339+
)
337340
if conv_mod.out_channels > conv_mod.in_channels: # see Note 2
338-
qconv_candidate.append(get_org_mod_name_of_fx_node(n_conv_i, gm))
341+
qconv_candidate.append(
342+
get_org_mod_name_of_fx_node(n_conv_i, gm, LUTfx_mod_name_to_org)
343+
)
339344

340345
return qconv_candidate
341346

342347

343-
def find_1st_last_gm(gm, firstOps=None, lastOps=None, return_1st_last_sep=False):
348+
def find_1st_last_gm(
349+
gm, firstOps=None, lastOps=None, return_1st_last_sep=False, LUTfx_mod_name_to_org={}
350+
):
344351
"""Identify the first and last layer of interests
345352
Usually only interested in Conv and Linear, but could be others as well
346353
NOTE:
@@ -355,8 +362,19 @@ def find_1st_last_gm(gm, firstOps=None, lastOps=None, return_1st_last_sep=False)
355362
firstOps = [torch.nn.Conv2d, torch.nn.Linear]
356363
if lastOps is None:
357364
lastOps = [torch.nn.Conv2d, torch.nn.Linear]
358-
first_candidates = dfs_gm(gm, targetOp=firstOps, find1stOnly=True)
359-
last_candidates = dfs_gm(gm, targetOp=lastOps, find1stOnly=True, reverse=True)
365+
first_candidates = dfs_gm(
366+
gm,
367+
targetOp=firstOps,
368+
find1stOnly=True,
369+
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
370+
)
371+
last_candidates = dfs_gm(
372+
gm,
373+
targetOp=lastOps,
374+
find1stOnly=True,
375+
reverse=True,
376+
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
377+
)
360378

361379
min_depth = min(list(first_candidates.values()) + [999])
362380
real_first = [
@@ -379,10 +397,7 @@ def find_1st_last_gm(gm, firstOps=None, lastOps=None, return_1st_last_sep=False)
379397

380398

381399
def find_single_sided_op_gm(
382-
gm,
383-
op_of_interest=None,
384-
return_LUTs=False,
385-
verbose=False,
400+
gm, op_of_interest=None, return_LUTs=False, verbose=False, LUTfx_mod_name_to_org={}
386401
):
387402
"""Try to determine the "polarity" of output of "every nodes" based on their inputs and the Op
388403
itself, then decide which Conv/Linear (or user-specified Op) will use single-sided quantizer
@@ -527,7 +542,12 @@ def find_single_sided_op_gm(
527542
if return_LUTs:
528543
return isActOutPositiveOnly, isActOutBounded
529544

530-
node_of_interest = dfs_gm(gm, targetOp=op_of_interest, return_nodes=True)
545+
node_of_interest = dfs_gm(
546+
gm,
547+
targetOp=op_of_interest,
548+
return_nodes=True,
549+
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
550+
)
531551

532552
SingleSidedOps = []
533553
risky_nodes = []
@@ -536,7 +556,9 @@ def find_single_sided_op_gm(
536556
if (False not in input_pos) and (None in input_pos): # see Note g
537557
risky_nodes.append(n)
538558
if all(input_pos):
539-
SingleSidedOps.append(get_org_mod_name_of_fx_node(n, gm))
559+
SingleSidedOps.append(
560+
get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org)
561+
)
540562

541563
if risky_nodes:
542564
logger.warning(
@@ -548,7 +570,7 @@ def find_single_sided_op_gm(
548570
return SingleSidedOps
549571

550572

551-
def find_qkvsync_candidates_gm(gm, return_nodes=False):
573+
def find_qkvsync_candidates_gm(gm, return_nodes=False, LUTfx_mod_name_to_org={}):
552574
"""Identify groups of Linears that share the same parent. It's a transformer-specific feature.
553575
554576
NOTE:
@@ -586,7 +608,9 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False):
586608
Nshared_parents = 0
587609
for depth, nodes in LUTdep2nodes.items():
588610
parents = [ni.all_input_nodes[0] for ni in nodes]
589-
org_mod_names = [get_org_mod_name_of_fx_node(ni, gm) for ni in nodes]
611+
org_mod_names = [
612+
get_org_mod_name_of_fx_node(ni, gm, LUTfx_mod_name_to_org) for ni in nodes
613+
]
590614
if all(p == parents[0] for p in parents[1:]):
591615
Nshared_parents += 1
592616
for org_name_i in org_mod_names:
@@ -596,7 +620,7 @@ def find_qkvsync_candidates_gm(gm, return_nodes=False):
596620
return my_1st_sibling
597621

598622

599-
def find_silu_gm(gm):
623+
def find_silu_gm(gm, LUTfx_mod_name_to_org={}):
600624
"""Special handle for Conv following silu, specific for EffDet and etc
601625
LLMs could use SiLU as well (llama?), but that is not relevant to this func
602626
"""
@@ -609,12 +633,14 @@ def find_silu_gm(gm):
609633
gpOp = get_target_op_from_node(gp_nodes[0], gm) if gp_nodes else None
610634

611635
if torch.nn.functional.silu in [pOp, gpOp]:
612-
siluConv[get_org_mod_name_of_fx_node(n, gm)] = {"qa_mode": "qsilu"}
636+
siluConv[get_org_mod_name_of_fx_node(n, gm, LUTfx_mod_name_to_org)] = {
637+
"qa_mode": "qsilu"
638+
}
613639

614640
return siluConv
615641

616642

617-
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
643+
def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0, LUTfx_mod_name_to_org={}):
618644
"""For object detection CNN models, RPN (RegionProposalNetwork) and FPN (FeaturePyramidNetwork)
619645
are commonly used. prefer to skip them, but may be ok to quantize in some cases.
620646
@@ -677,6 +703,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
677703
targetOp=[torch.nn.Conv2d],
678704
start_nodes=fpn_st_nodes,
679705
stop_nodes=[fpn_end_node],
706+
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
680707
)
681708
fpn_convs = [mod_name for mod_name, ln in fpn_convs.keys()] # see Note 4
682709
fpn_adds = dfs_gm(
@@ -701,7 +728,11 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
701728
isinstance(tar_op_gp, torch.nn.Conv2d)
702729
and tar_op_gp.kernel_size in [1, (1, 1)]
703730
):
704-
fpn_inner_blocks.append(get_org_mod_name_of_fx_node(gp))
731+
fpn_inner_blocks.append(
732+
get_org_mod_name_of_fx_node(
733+
gp, LUTfx2org=LUTfx_mod_name_to_org
734+
)
735+
)
705736
fpn_convs += fpn_inner_blocks
706737

707738
if verbose:
@@ -713,7 +744,7 @@ def find_rpn_fpn_gm(gm, verbose=False, Nsubgraph=0):
713744
return fpn_convs
714745

715746

716-
def find_and_prep_bmm_gm(gm):
747+
def find_and_prep_bmm_gm(gm, LUTfx_mod_name_to_org={}):
717748
"""Previously with TorchScript, we use this func to perform 2 tasks:
718749
a) create QBmms, and then attach them to the model,
719750
b) set up qcfg["which2patch_contextmanager"] so that patch_torch_bmm() context
@@ -767,7 +798,7 @@ def find_and_prep_bmm_gm(gm):
767798
LUTmodname2linenum = {} # see Note 4
768799
for node_line_num, depth in LUT2sort.items():
769800
node, line_num = node_line_num
770-
org_mod_name = get_org_mod_name_of_fx_node(node, gm)
801+
org_mod_name = get_org_mod_name_of_fx_node(node, gm, LUTfx_mod_name_to_org)
771802
if org_mod_name in LUTmodname2linenum:
772803
LUTmodname2linenum[org_mod_name] += [(node, line_num, depth)]
773804
else:
@@ -910,14 +941,21 @@ def cus_backend_model_analyzer(
910941
911942
"""
912943
qcfg["N_backend_called"] += 1
944+
LUTfx_mod_name_to_org = {
945+
n.replace(".weight", ""): LUTweight2modname[p]
946+
for n, p in gm_fx.named_parameters()
947+
if p in LUTweight2modname
948+
}
913949
prefix = None
914950
if qcfg["N_backend_called"] > 1: # subgraph found, see Note 2
915951
for n in gm_fx.graph.nodes:
916952
if n.op == "call_module":
917953
mod = gm_fx.get_submodule(n.target)
918954
if isinstance(mod, (torch.nn.Linear, torch.nn.Conv2d)):
919955
real_org_modname = LUTweight2modname[mod.weight]
920-
part_org_modname = get_org_mod_name_of_fx_node(n, gm_fx)
956+
part_org_modname = get_org_mod_name_of_fx_node(
957+
n, gm_fx, LUTfx_mod_name_to_org
958+
)
921959
idx = real_org_modname.rindex(part_org_modname)
922960
if idx > 1:
923961
prefix = real_org_modname[: idx - 1] # remove trailing '.'
@@ -930,28 +968,45 @@ def cus_backend_model_analyzer(
930968
if not isinstance(plotsvg, str):
931969
plotsvg = f"debug{qcfg['N_backend_called']}.svg"
932970
plot_graph_module(
933-
gm_fx, outputname=plotsvg, show_details=True, Nnode_to_plot=1000
971+
gm_fx,
972+
outputname=plotsvg,
973+
show_details=True,
974+
Nnode_to_plot=1000,
975+
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
934976
)
935977

936978
# Graph checks begin. Use append, add prefix if needed
937979
skip_candidates = []
938980
# Check 1. Conv on shortcut
939-
skip_candidates += find_conv_on_shortcut_gm(gm_fx)
981+
all_conv = [
982+
m
983+
for _, m in gm_fx.named_modules()
984+
if isinstance(m, torch.nn.Conv2d) or issubclass(type(m), torch.nn.Conv2d)
985+
]
986+
if len(all_conv) > 0:
987+
skip_candidates += find_conv_on_shortcut_gm(gm_fx, LUTfx_mod_name_to_org)
940988

941989
# Check 2. first/last, see Note 2 and 3
942990
if qcfg["N_backend_called"] > 1:
943991
skip_candidates += []
944992
elif is_transformers:
945-
_, last_only = find_1st_last_gm(gm_fx, return_1st_last_sep=True)
993+
_, last_only = find_1st_last_gm(
994+
gm_fx,
995+
return_1st_last_sep=True,
996+
LUTfx_mod_name_to_org=LUTfx_mod_name_to_org,
997+
)
946998
skip_candidates += last_only
947999
else:
9481000
# see Note 4
949-
skip_candidates += find_1st_last_gm(gm_fx)
1001+
skip_candidates += find_1st_last_gm(
1002+
gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org
1003+
)
9501004
qcfg["qskip_layer_name"] += add_prefix_to_list_or_dict(skip_candidates, prefix)
9511005

9521006
# Check 3: single/double sided
9531007
qcfg["qsinglesided_name"] += add_prefix_to_list_or_dict(
954-
find_single_sided_op_gm(gm_fx), prefix
1008+
find_single_sided_op_gm(gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org),
1009+
prefix,
9551010
)
9561011

9571012
# Check 4: identify RPN/FPN
@@ -960,11 +1015,13 @@ def cus_backend_model_analyzer(
9601015
# NOTE: The following 3 funcs return dict instead of list. Use update() instead of append().
9611016
# Check 5: Conv+SiLU
9621017
qcfg["qspecial_layers"].update(
963-
add_prefix_to_list_or_dict(find_silu_gm(gm_fx), prefix)
1018+
add_prefix_to_list_or_dict(
1019+
find_silu_gm(gm_fx, LUTfx_mod_name_to_org), prefix
1020+
)
9641021
)
9651022

9661023
# Check 6: BMM
967-
temp_dict = find_and_prep_bmm_gm(gm_fx) # see Note 5
1024+
temp_dict = find_and_prep_bmm_gm(gm_fx, LUTfx_mod_name_to_org) # see Note 5
9681025
if len(temp_dict["layers_with_bmm"]) > 0:
9691026
temp_dict["layers_with_bmm"] = add_prefix_to_list_or_dict(
9701027
temp_dict["layers_with_bmm"], prefix
@@ -975,7 +1032,9 @@ def cus_backend_model_analyzer(
9751032
qcfg["bmm_prep"]["layers_with_bmm"].update(temp_dict["layers_with_bmm"])
9761033

9771034
# Check 7: QKV
978-
temp_dict = find_qkvsync_candidates_gm(gm_fx) # see Note 6
1035+
temp_dict = find_qkvsync_candidates_gm(
1036+
gm_fx, LUTfx_mod_name_to_org=LUTfx_mod_name_to_org
1037+
) # see Note 6
9791038
temp_dict = add_prefix_to_list_or_dict(
9801039
temp_dict, prefix, update_both_k_and_v=True
9811040
)

0 commit comments

Comments
 (0)