Skip to content

Commit 04e4cb1

Browse files
qkvsync bug fix: graph breaks induce errors in the qkv sibling list, because only partial names are found, which causes problems
Signed-off-by: cliu-us <[email protected]>
1 parent a60a4b8 commit 04e4cb1

File tree

1 file changed

+35
-1
lines changed

1 file changed

+35
-1
lines changed

fms_mo/fx/dynamo_utils.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1010,6 +1010,8 @@ def cus_backend_model_analyzer(
10101010
}
10111011
prefix = None
10121012
if qcfg["N_backend_called"] > 1: # subgraph found, see Note 2
1013+
# TODO this approach only works for FX IR (call_module nodes are not functionalized)
1014+
# need an update for Aten IR cases
10131015
for n in gm_fx.graph.nodes:
10141016
if n.op == "call_module":
10151017
mod = gm_fx.get_submodule(n.target)
@@ -1220,6 +1222,38 @@ def call_seq_hook(mod, *_args, **_kwargs):
12201222
# ------ model analysis is finished, but there are a few remaining things to be done
12211223

12221224
# a) qkvsync dict update from "module names" to "module instances"
1225+
# NOTE: when a graph break happens, qkvsync() may find only partial QKV names. For example,
1226+
# as opposed to ["model.layers.0.self_attn.q_proj", ..., "model.layers.1.self_attn.q_proj", ...]
1227+
# it may report ["self_attn.q_proj", "self_attn.k_proj", ...]
1228+
# Therefore, length of qcfg["qkvsync_my_1st_sibling"] will be much shorter and keys of this dict
1229+
# won't exist in full list (like all_linears below).
1230+
all_linears = set(
1231+
[n for n, m in model.named_modules() if isinstance(m, torch.nn.Linear)]
1232+
)
1233+
1234+
if any(k not in all_linears for k in qcfg["qkvsync_my_1st_sibling"]):
1235+
# qcfg["qkvsync_my_1st_sibling"] dict is like {q:q, k:q, v:q,...}, here we need a simpler
1236+
# dict like {q:[q,k,v], gate:[up, gate]}
1237+
lut_all_siblings = {}
1238+
for me, sib_1st in qcfg["qkvsync_my_1st_sibling"].items():
1239+
if sib_1st not in lut_all_siblings:
1240+
lut_all_siblings[sib_1st] = [sib_1st]
1241+
elif me not in lut_all_siblings[sib_1st]:
1242+
lut_all_siblings[sib_1st].append(me)
1243+
1244+
full_sib_list = {}
1245+
for me in lut_all_siblings:
1246+
partial_matches = [lin for lin in all_linears if me in lin]
1247+
all_sibs = lut_all_siblings[me]
1248+
# here lin is full_name, me and all_sibs are partial
1249+
for lin in partial_matches:
1250+
prefix = lin[: lin.index(me)]
1251+
for sib in all_sibs:
1252+
full_sib_list[prefix + sib] = prefix + me
1253+
all_linears.remove(prefix + sib)
1254+
# all_linears will still have down_proj, out_proj, lm_head, and maybe others
1255+
qcfg["qkvsync_my_1st_sibling"] = full_sib_list
1256+
12231257
updated_dict = {
12241258
model.get_submodule(mod): model.get_submodule(sib)
12251259
for mod, sib in qcfg["qkvsync_my_1st_sibling"].items()
@@ -1303,7 +1337,7 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
13031337
if qcfg["N_backend_called"] > 1:
13041338
logger.warning(
13051339
f"Found {qcfg['N_backend_called']} graph breaks during Dynamo tracing!! \n"
1306-
f"First/Last layer, which usually needs to stay unquantized, cannot be identified"
1340+
f"First/Last layer, which usually needs to stay unquantized, may not be identified"
13071341
f" correctly now. Please double-check layers being skipped:\n"
13081342
f"{qcfg['qskip_layer_name']}\n NOTE: Users can control layer selection by adding layer"
13091343
f"names to:\n"

0 commit comments

Comments
 (0)