
Commit dd53de5

fix QBmm detection and default behavior
Signed-off-by: cliu-us <[email protected]>
1 parent a0c2aae commit dd53de5

File tree: 3 files changed, +60 -18 lines changed
  .gitignore
  fms_mo/fx/dynamo_utils.py
  fms_mo/utils/utils.py

.gitignore (1 addition, 0 deletions)

@@ -45,3 +45,4 @@ fms_mo.log
 data*_train/
 data*_test/
 act_scales/
+examples/

fms_mo/fx/dynamo_utils.py (35 additions, 6 deletions)

@@ -32,6 +32,14 @@
 
 logger = logging.getLogger(__name__)
 
+# From PyTorch 2.5+, graphModule received in dynamo custom backend will be Aten IR instead of FX IR,
+# i.e. no "call_module" nodes, all parameter tensors become "placeholder" nodes, and etc...
+# This following flag will make dynamo behaves like PyTorch 2.4. Only use it when model_analyzer()
+# really stop working and hard to recover.
+# Ref: https://pytorch.org/tutorials/recipes/regional_compilation.html
+
+# torch._dynamo.config.inline_inbuilt_nn_modules = False
+
 
 def run_fwd_once(model, sample_inp):
     """Convenient function to run model once using correct input unpack."""
@@ -836,14 +844,16 @@ def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] =
         return_dict["which2patch_contextmanager"] = "torch.matmul"
         LUT2sort = all_matmuls
     else:
-        warn_msg = None
         if Nbmm_found > 0 and Nmatmul_found > 0:
-            warn_msg = "Both bmm and matmul are found. Not sure which to patch."
-        elif Nbmm_found == 0 and Nmatmul_found == 0 and len(all_sdpas) > 0:
-            warn_msg = "No bmm and matmul are found. Likely SDPA is enabled."
+            raise RuntimeError(
+                "Both bmm and matmul are found. Not sure which to patch."
+            )
+        if Nbmm_found == 0 and Nmatmul_found == 0 and len(all_sdpas) > 0:
+            logger.warning(
+                "No bmm and matmul are found. Likely SDPA is enabled. "
+                "Will patch nothing!"
+            )
 
-        if warn_msg:
-            logger.warning(f"{warn_msg} Will patch nothing.")
         return return_dict
 
     LUTmodname2linenum = {}  # see Note 4
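
To make the new behavior concrete, here is a self-contained sketch of the same branching on a toy FX graph. This is not the fms_mo implementation (which also inspects SDPA nodes); the toy module and counting are illustrative only.

import logging
import torch
import torch.fx

logger = logging.getLogger(__name__)

class ToyAttn(torch.nn.Module):  # toy module for illustration
    def forward(self, q, k):
        return torch.matmul(q, k.transpose(-1, -2))

gm = torch.fx.symbolic_trace(ToyAttn())
Nbmm_found = sum(n.op == "call_function" and n.target is torch.bmm for n in gm.graph.nodes)
Nmatmul_found = sum(n.op == "call_function" and n.target is torch.matmul for n in gm.graph.nodes)

if Nbmm_found > 0 and Nmatmul_found > 0:
    # ambiguous graphs now fail fast instead of patching nothing behind a warning
    raise RuntimeError("Both bmm and matmul are found. Not sure which to patch.")
if Nbmm_found == 0 and Nmatmul_found == 0:
    logger.warning("No bmm and matmul are found. Likely SDPA is enabled. Will patch nothing!")
print(Nbmm_found, Nmatmul_found)  # 0 1 for this toy module
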
@@ -1085,6 +1095,25 @@ def cus_backend_model_analyzer(
                 "which2patch_contextmanager"
             ]
             qcfg["bmm_prep"]["layers_with_bmm"].update(temp_dict["layers_with_bmm"])
+            # make sure there are ONLY 2 bmm per layer (self_attention). some models may use
+            # additional bmm/matmuls. Raise warning if that's the case.
+            num_layers = len(temp_dict["layers_with_bmm"])
+            num_bmms = 0
+            seen_line_num = []
+            for line_nums in temp_dict["layers_with_bmm"].values():
+                num_bmms += len(line_nums)
+                for line_num in line_nums:
+                    if line_num not in seen_line_num:
+                        seen_line_num.append(line_num)
+            qcfg["bmm_prep"]["bmm_only_in_self_attn"] = True
+            if num_bmms != num_layers * 2 or len(seen_line_num) != 2:
+                qcfg["bmm_prep"]["bmm_only_in_self_attn"] = False
+                logger.warning(
+                    "This model uses additional matmul/bmm other than those in self-attention. "
+                    "If you plan to quantize self-attention, please note that the additional bmms "
+                    "may also be quantized!"
+                    f"{temp_dict['layers_with_bmm']}\n"
+                )
 
     # Check 7: QKV
     temp_dict = find_qkvsync_candidates_gm(
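
As a quick illustration of the check added above: with exactly two matmul call sites repeated across every decoder layer the flag stays True, and any extra call site flips it to False. The layer names and line numbers below are made up.

# Toy stand-in for temp_dict["layers_with_bmm"]: module name -> matmul call-site line numbers
layers_with_bmm = {
    "model.layers.0.self_attn": [241, 262],
    "model.layers.1.self_attn": [241, 262],
}

num_layers = len(layers_with_bmm)
num_bmms = sum(len(line_nums) for line_nums in layers_with_bmm.values())
seen_line_num = {ln for line_nums in layers_with_bmm.values() for ln in line_nums}

# 2 bmms per layer and only 2 distinct call sites -> bmms live only in self-attention
bmm_only_in_self_attn = num_bmms == num_layers * 2 and len(seen_line_num) == 2
print(bmm_only_in_self_attn)  # True for this toy dict
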

fms_mo/utils/utils.py (24 additions, 12 deletions)

@@ -71,7 +71,7 @@ def move_to(obj, device):
     return obj
 
 
-def mockbmm(mat1, mat2):
+def mockbmm(mat1, mat2, default_to_torch=False):
     """
     This function is used to mock the behavior of the bmm function in PyTorch.
     It is used to work around the fact that the bmm function in PyTorch is not
@@ -86,20 +86,23 @@ def mockbmm(mat1, mat2):
     """
     cf = sys._getframe()
     qbmm_mod = None
+    qbmm_lineno = cf.f_back.f_lineno
     while cf.f_back and qbmm_mod is None:
         # First frame is QBmm's forward itself, can start searching from previous stack
         cf = cf.f_back
-        if "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name:
+        if (
+            "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name
+        ) and "self" in cf.f_locals:
             mod_calling_bmm_function = cf.f_locals["self"]
             # If not found -> default to torch.matmul
-            qbmm_mod = getattr(
-                mod_calling_bmm_function, "QBmm" + str(cf.f_lineno), torch.matmul
-            )
+            qbmm_mod = getattr(mod_calling_bmm_function, f"QBmm{qbmm_lineno}", None)
     del cf
+    if qbmm_mod is None and default_to_torch:
+        qbmm_mod = torch.matmul
     return qbmm_mod(mat1, mat2)
 
 
-def mockmatmul(mat1, mat2):
+def mockmatmul(mat1, mat2, default_to_torch=False):
     """
     Patches torch.matmul() with QBmm( torch.bmm() )
 
@@ -109,31 +112,37 @@ def mockmatmul(mat1, mat2):
 
     Returns:
         torch.Tensor: The result of the mock matrix multiplication.
+    NOTE:
+    1. First frame is mockmatmul itself. One frame back (cf.f_back) is where torch.matmul
+        happened, whose line number is the one used for QBmm<xxx>
+    2. QBmm module may not be attached to the immediate frame where torch.matmul happened. Need
+        to trace back and find the frame with both "forward" in name and "self" in locals, i.e.
+        a class (nn.module) has a function named "forward" something
+    3. Keep default_to_torch=False unless really needed, otherwise if something went wrong with
+        QBmm detection, it could go to default silently, which would be very difficult to debug.
     """
     cf = sys._getframe()
     qbmm_mod = None
+    qbmm_lineno = cf.f_back.f_lineno
     while cf.f_back and qbmm_mod is None:
-        # First frame is QBmm's forward itself, can start searching from previous stack
         cf = cf.f_back
         if (
             "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name
         ) and "self" in cf.f_locals:
             mod_calling_bmm_function = cf.f_locals["self"]
             # If not found -> default to torch.bmm
-            qbmm_mod = getattr(
-                mod_calling_bmm_function, "QBmm" + str(cf.f_lineno), torch.bmm
-            )
+            qbmm_mod = getattr(mod_calling_bmm_function, f"QBmm{qbmm_lineno}", None)
     del cf
 
     # Didn't find the corresponding QBmm, default the call to torch.bmm
-    if qbmm_mod == torch.bmm:
+    if qbmm_mod is None and default_to_torch:
         org_batch_header = mat1.shape[:2]
         # Need to double check m1/m2 are 3d, otherwise reshape
         if len(mat1.shape) > 3:
             mat1 = mat1.reshape([-1, mat1.shape[-2], mat1.shape[-1]])
         if len(mat2.shape) > 3:
             mat2 = mat2.reshape([-1, mat2.shape[-2], mat2.shape[-1]])
-        output = qbmm_mod(mat1, mat2)
+        output = torch.bmm(mat1, mat2)
         output = output.reshape([*org_batch_header, *output.shape[1:]])
         return output
     return qbmm_mod(mat1, mat2)
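
The lookup above is easiest to see end to end. Below is a hedged, self-contained sketch (not the fms_mo code): a toy module gets a QBmm<lineno> attribute named after the torch.matmul call-site line, torch.matmul is temporarily swapped for a simplified mock, and the frame walk resolves the attribute. The class name, the attribute value (plain torch.bmm standing in for a real QBmm module), and the manual monkey-patching are all illustrative assumptions.

import inspect
import sys

import torch

def mockmatmul_sketch(mat1, mat2, default_to_torch=False):
    """Simplified stand-in for mockmatmul above (sketch only)."""
    cf = sys._getframe()
    qbmm_lineno = cf.f_back.f_lineno  # line of the original torch.matmul call site
    qbmm_mod = None
    while cf.f_back and qbmm_mod is None:
        cf = cf.f_back
        if (
            "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name
        ) and "self" in cf.f_locals:
            qbmm_mod = getattr(cf.f_locals["self"], f"QBmm{qbmm_lineno}", None)
    del cf
    if qbmm_mod is None and default_to_torch:
        qbmm_mod = torch.bmm
    return qbmm_mod(mat1, mat2)

class TinyAttn(torch.nn.Module):  # toy attention-like module; 3D inputs so bmm is valid
    def forward(self, q, k):
        return torch.matmul(q, k)  # this call site's line number names the QBmm attribute

attn = TinyAttn()
src, start = inspect.getsourcelines(TinyAttn.forward)
call_line = start + next(i for i, s in enumerate(src) if "torch.matmul(q, k)" in s)
setattr(attn, f"QBmm{call_line}", torch.bmm)  # a real QBmm module would be attached here

orig_matmul = torch.matmul
torch.matmul = mockmatmul_sketch  # roughly what the patch_torch_bmm context manager does
try:
    out = attn(torch.randn(2, 4, 8), torch.randn(2, 8, 4))
finally:
    torch.matmul = orig_matmul
print(out.shape)  # torch.Size([2, 4, 4])

With default_to_torch left at False, a failed QBmm lookup makes qbmm_mod stay None and the final call raise immediately, which is exactly the loud-failure behavior the NOTE above argues for.
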
@@ -149,6 +158,9 @@ def patch_torch_bmm(qcfg):
     if qcfg is not None:
         # could be 'torch.bmm', 'torch.matmul', or None
         ops_to_patch = qcfg.get("which2patch_contextmanager", None)
+        # if qcfg["bmm_prep"]["bmm_only_in_self_attn"] is False, may need to enable default_to_torch
+        # in mock functions, e.g. partial(mockmatmul, default_to_torch=True)
+        # This is in case a model uses extra matmuls, and QBmmXXX is not found or attached properly.
         new_target = (
             mockbmm
             if ops_to_patch == "torch.bmm"
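
Following the comment above, a hedged sketch of how that fallback could be opted into with functools.partial. The toy qcfg dict and the choice to key the fallback off bmm_only_in_self_attn are assumptions for illustration, not the shipped patch_torch_bmm behavior.

from functools import partial

from fms_mo.utils.utils import mockbmm, mockmatmul  # module shown in this diff

# Toy qcfg for illustration; real values are filled in by cus_backend_model_analyzer.
qcfg = {
    "which2patch_contextmanager": "torch.matmul",
    "bmm_prep": {"bmm_only_in_self_attn": False},
}

# Enable the torch fallback only when extra matmul/bmm call sites were detected,
# keeping the stricter default (fail loudly) for the common self-attention-only case.
allow_fallback = not qcfg["bmm_prep"]["bmm_only_in_self_attn"]
ops_to_patch = qcfg.get("which2patch_contextmanager", None)
new_target = (
    partial(mockbmm, default_to_torch=allow_fallback)
    if ops_to_patch == "torch.bmm"
    else partial(mockmatmul, default_to_torch=allow_fallback)
)
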
