@@ -1010,6 +1010,8 @@ def cus_backend_model_analyzer(
10101010 }
10111011 prefix = None
10121012 if qcfg ["N_backend_called" ] > 1 : # subgraph found, see Note 2
1013+ # TODO: this approach only works for FX IR (call_module nodes are not functionalized);
1014+ # it needs an update for ATen IR cases.
10131015 for n in gm_fx .graph .nodes :
10141016 if n .op == "call_module" :
10151017 mod = gm_fx .get_submodule (n .target )
@@ -1220,6 +1222,38 @@ def call_seq_hook(mod, *_args, **_kwargs):
12201222 # ------ model analysis is finished, but there are a few remaining things to be done
12211223
12221224 # a) qkvsync dict update from "module names" to "module instances"
1225+ # NOTE: when a graph break happens, qkvsync() may only find partial QKV names. For example,
1226+ # instead of ["model.layers.0.self_attn.q_proj", ..., "model.layers.1.self_attn.q_proj", ...]
1227+ # it may report ["self_attn.q_proj", "self_attn.k_proj", ...].
1228+ # As a result, qcfg["qkvsync_my_1st_sibling"] will be much shorter, and the keys of this dict
1229+ # won't exist in the full list of module names (like all_linears below).
1230+ all_linears = set (
1231+ [n for n , m in model .named_modules () if isinstance (m , torch .nn .Linear )]
1232+ )
1233+
1234+ if any (k not in all_linears for k in qcfg ["qkvsync_my_1st_sibling" ]):
1235+ # qcfg["qkvsync_my_1st_sibling"] is a dict like {q:q, k:q, v:q,...}; here we need a simpler
1236+ # dict like {q:[q,k,v], gate:[gate, up]}, keyed by the first sibling, which is listed first.
1237+ lut_all_siblings = {}
1238+ for me , sib_1st in qcfg ["qkvsync_my_1st_sibling" ].items ():
1239+ if sib_1st not in lut_all_siblings :
1240+ lut_all_siblings [sib_1st ] = [sib_1st ]
1241+ elif me not in lut_all_siblings [sib_1st ]:
1242+ lut_all_siblings [sib_1st ].append (me )
1243+
1244+ full_sib_list = {}
1245+ for me in lut_all_siblings :
1246+ partial_matches = [lin for lin in all_linears if me in lin ]
1247+ all_sibs = lut_all_siblings [me ]
1248+ # here lin is a full module name; me and all_sibs are partial names
1249+ for lin in partial_matches :
1250+ prefix = lin [: lin .index (me )]
1251+ for sib in all_sibs :
1252+ full_sib_list [prefix + sib ] = prefix + me
1253+ all_linears .remove (prefix + sib )
1254+ # all_linears will still have down_proj, out_proj, lm_head, and maybe others
1255+ qcfg ["qkvsync_my_1st_sibling" ] = full_sib_list
1256+
12231257 updated_dict = {
12241258 model .get_submodule (mod ): model .get_submodule (sib )
12251259 for mod , sib in qcfg ["qkvsync_my_1st_sibling" ].items ()
@@ -1303,7 +1337,7 @@ def qbmm_auto_check(_mod, *_args, **_kwargs):
13031337 if qcfg ["N_backend_called" ] > 1 :
13041338 logger .warning (
13051339 f"Found { qcfg ['N_backend_called' ]} graph breaks during Dynamo tracing!! \n "
1306- f"First/Last layer, which usually needs to stay unquantized, cannot be identified"
1340+ f"First/Last layer, which usually needs to stay unquantized, may not be identified"
13071341 f" correctly now. Please double-check layers being skipped:\n "
13081342 f"{ qcfg ['qskip_layer_name' ]} \n NOTE: Users can control layer selection by adding layer"
13091343 f"names to:\n "
0 commit comments