Skip to content

Commit ae847c2

Browse files
Merge pull request #161 from chichun-charlie-liu/qkvsync_bug_fix
fix: qkvsync bug fix
2 parents 0e98567 + d6fd553 commit ae847c2

File tree

1 file changed

+34
-1
lines changed

1 file changed

+34
-1
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,8 @@ def cus_backend_model_analyzer(
10101010
}
10111011
prefix = None
10121012
if qcfg["N_backend_called"] > 1: # subgraph found, see Note 2
1013+
# TODO: this approach only works for FX IR (call_module nodes are not functionalized);
1014+
# it needs an update for Aten IR cases
10131015
for n in gm_fx.graph.nodes:
10141016
if n.op == "call_module":
10151017
mod = gm_fx.get_submodule(n.target)
@@ -1226,6 +1228,37 @@ def call_seq_hook(mod, *_args, **kwargs):
12261228
# ------ model analysis is finished, but there are a few remaining things to be done
12271229

12281230
# a) qkvsync dict update from "module names" to "module instances"
1231+
# NOTE when a graph break happens, qkvsync() may only find partial QKV names. For example,
1232+
# as opposed to ["model.layers.0.self_attn.q_proj", ..., "model.layers.1.self_attn.q_proj", ...]
1233+
# it may report ["self_attn.q_proj", "self_attn.k_proj", ...]
1234+
# Therefore, the length of qcfg["qkvsync_my_1st_sibling"] will be much shorter and the keys of
1235+
# this dict won't exist in the full list of Linear module names (like all_linears below).
1236+
all_linears = set(
1237+
n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)
1238+
)
1239+
1240+
if any(k not in all_linears for k in qcfg["qkvsync_my_1st_sibling"]):
1241+
# qcfg["qkvsync_my_1st_sibling"] dict is like {q:q, k:q, v:q,...}; here we need a simpler
1242+
# dict like {q:[q,k,v], gate:[gate, up]}, keyed by the first sibling
1243+
lut_all_siblings = {}
1244+
for me, sib_1st in qcfg["qkvsync_my_1st_sibling"].items():
1245+
if sib_1st not in lut_all_siblings:
1246+
lut_all_siblings[sib_1st] = [sib_1st]
1247+
elif me not in lut_all_siblings[sib_1st]:
1248+
lut_all_siblings[sib_1st].append(me)
1249+
1250+
full_sib_list = {}
1251+
for me, all_sibs in lut_all_siblings.items():
1252+
partial_matches = [lin for lin in all_linears if me in lin]
1253+
# here lin is full_name, me and all_sibs are partial
1254+
for lin in partial_matches:
1255+
prefix = lin[: lin.index(me)]
1256+
for sib in all_sibs:
1257+
full_sib_list[prefix + sib] = prefix + me
1258+
all_linears.remove(prefix + sib)
1259+
# all_linears will still have down_proj, out_proj, lm_head, and maybe others
1260+
qcfg["qkvsync_my_1st_sibling"] = full_sib_list
1261+
12291262
updated_dict = {
12301263
model.get_submodule(mod): model.get_submodule(sib)
12311264
for mod, sib in qcfg["qkvsync_my_1st_sibling"].items()
@@ -1309,7 +1342,7 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
13091342
if qcfg["N_backend_called"] > 1:
13101343
logger.warning(
13111344
f"Found {qcfg['N_backend_called']} graph breaks during Dynamo tracing!! \n"
1312-
f"First/Last layer, which usually needs to stay unquantized, cannot be identified"
1345+
f"First/Last layer, which usually needs to stay unquantized, may not be identified"
13131346
f" correctly now. Please double-check layers being skipped:\n"
13141347
f"{qcfg['qskip_layer_name']}\n NOTE: Users can control layer selection by adding layer"
13151348
f"names to:\n"

0 commit comments

Comments
 (0)