
Commit 51d4cf1

Merge pull request #87 from chichun-charlie-liu/main
fix: fix QBmm detection and default behavior
2 parents dc2ad5d + f12edec commit 51d4cf1

File tree: 7 files changed (+165, -20 lines)

.gitignore

Lines changed: 1 addition & 0 deletions

@@ -45,3 +45,4 @@ fms_mo.log
 data*_train/
 data*_test/
 act_scales/
+examples/

fms_mo/fx/dynamo_utils.py

Lines changed: 61 additions & 6 deletions

@@ -32,6 +32,14 @@
 
 logger = logging.getLogger(__name__)
 
+# From PyTorch 2.5+, the graphModule received in a dynamo custom backend will be Aten IR instead
+# of FX IR, i.e. no "call_module" nodes, all parameter tensors become "placeholder" nodes, etc.
+# The following flag makes dynamo behave like PyTorch 2.4. Only use it when model_analyzer()
+# really stops working and is hard to recover.
+# Ref: https://pytorch.org/tutorials/recipes/regional_compilation.html
+
+# torch._dynamo.config.inline_inbuilt_nn_modules = False
+
 
 def run_fwd_once(model, sample_inp):
     """Convenient function to run model once using correct input unpack."""

@@ -836,14 +844,16 @@ def find_and_prep_bmm_gm(gm, lut_fx_mod_name_to_org: Optional[Dict[str, str]] =
         return_dict["which2patch_contextmanager"] = "torch.matmul"
         LUT2sort = all_matmuls
     else:
-        warn_msg = None
         if Nbmm_found > 0 and Nmatmul_found > 0:
-            warn_msg = "Both bmm and matmul are found. Not sure which to patch."
-        elif Nbmm_found == 0 and Nmatmul_found == 0 and len(all_sdpas) > 0:
-            warn_msg = "No bmm and matmul are found. Likely SDPA is enabled."
+            raise RuntimeError(
+                "Both bmm and matmul are found. Not sure which to patch."
+            )
+        if Nbmm_found == 0 and Nmatmul_found == 0 and len(all_sdpas) > 0:
+            logger.warning(
+                "No bmm and matmul are found. Likely SDPA is enabled. "
+                "Will patch nothing!"
+            )
 
-        if warn_msg:
-            logger.warning(f"{warn_msg} Will patch nothing.")
         return return_dict
 
     LUTmodname2linenum = {}  # see Note 4

@@ -1085,6 +1095,25 @@ def cus_backend_model_analyzer(
                 "which2patch_contextmanager"
             ]
             qcfg["bmm_prep"]["layers_with_bmm"].update(temp_dict["layers_with_bmm"])
+            # make sure there are ONLY 2 bmms per layer (self-attention). Some models may use
+            # additional bmm/matmuls; raise a warning if that's the case.
+            num_layers = len(temp_dict["layers_with_bmm"])
+            num_bmms = 0
+            seen_line_num = []
+            for line_nums in temp_dict["layers_with_bmm"].values():
+                num_bmms += len(line_nums)
+                for line_num in line_nums:
+                    if line_num not in seen_line_num:
+                        seen_line_num.append(line_num)
+            qcfg["bmm_prep"]["bmm_only_in_self_attn"] = True
+            if num_bmms != num_layers * 2 or len(seen_line_num) != 2:
+                qcfg["bmm_prep"]["bmm_only_in_self_attn"] = False
+                logger.warning(
+                    "This model uses additional matmul/bmm other than those in self-attention. "
+                    "If you plan to quantize self-attention, please note that the additional "
+                    "bmms may also be quantized!\n"
+                    f"{temp_dict['layers_with_bmm']}\n"
+                )
 
             # Check 7: QKV
             temp_dict = find_qkvsync_candidates_gm(

@@ -1213,6 +1242,32 @@ def call_seq_hook(mod, *_args, **_kwargs):
                )
                setattr(mod_bmm_happened, f"QBmm{ln}", newQBmm)
 
+        # add an auto QBmm check to the last layer if any QBmms are in the model (transformers only)
+        def qbmm_auto_check(_mod, *_args, **_kwargs):
+            """Automatic QBmm check. This hook will be attached to the last module and will check
+            only once, at the end of the first forward() call. Throws a "warning" if the model has
+            QBmm attached but not called (as that could be intentional).
+            """
+            num_called_qbmms = []
+            for lay, line_nums in qcfg["bmm_prep"]["layers_with_bmm"].items():
+                for ln in line_nums:
+                    qbmm_i = model.get_submodule(f"{lay}.QBmm{ln}")
+                    num_called_qbmms.append(qbmm_i.num_module_called == 1)
+
+            if not all(num_called_qbmms):
+                err_msg = (
+                    "QBmms were attached but not called during forward(). "
+                    "Possibly the patch_torch_bmm() context manager is missing."
+                )
+                if qcfg["force_stop_if_qbmm_auto_check_failed"]:
+                    raise RuntimeError(err_msg)
+                logger.warning(err_msg)
+
+            qcfg["hook_qbmm_auto_check"].remove()
+
+        last_mod = model.get_submodule(qcfg["mod_call_seq"][-1])
+        qcfg["hook_qbmm_auto_check"] = last_mod.register_forward_hook(qbmm_auto_check)
+
         # c) identify RPN/FPN
         # TODO this hack only works for torchvision models. will use find_rpn_fpn_gm()

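The auto check added above relies on a standard PyTorch idiom: register a forward hook on the last module in the call sequence, let it run once after the first forward(), then detach it through its handle. A minimal, standalone sketch of that idiom follows; the toy model and names here are illustrative, not part of fms_mo.

import torch
from torch import nn

toy_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
state = {}

def run_once_check(_mod, *_args, **_kwargs):
    # fires after the first forward() of the hooked module, then removes itself,
    # mirroring how qbmm_auto_check calls qcfg["hook_qbmm_auto_check"].remove()
    print("post-forward check ran once")
    state["handle"].remove()

# attach to the last submodule, like hooking the last entry of qcfg["mod_call_seq"]
last_mod = list(toy_model.children())[-1]
state["handle"] = last_mod.register_forward_hook(run_once_check)

toy_model(torch.randn(1, 4))  # the check runs here
toy_model(torch.randn(1, 4))  # hook already removed, nothing printed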
fms_mo/utils/qconfig_utils.py

Lines changed: 2 additions & 0 deletions

@@ -198,6 +198,7 @@ def qconfig_init(recipe: str = None, args: Any = None):
     qcfg["which2patch_contextmanager"] = (
         None  # an internal var that should not be set by user
     )
+    qcfg["force_stop_if_qbmm_auto_check_failed"] = False
 
     # LSTM related, if any of these is not None, then last layer (FC) will not be skipped.
     qcfg["nbits_w_lstm"] = None

@@ -372,6 +373,7 @@ def remove_unwanted_from_config(config):
         "LUTmodule_name",
         "qkvsync_my_1st_sibling",
         "graph_in_out",
+        "hook_qbmm_auto_check",
     ]
     len_before = len(config)
     dump = {k: config.pop(k) for k in unwanted_items if k in config}

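Because the new qcfg["force_stop_if_qbmm_auto_check_failed"] entry defaults to False, the auto check only logs a warning. If a hard failure is preferred, the flag can be flipped on the config before tracing; a brief sketch, assuming the qconfig_init() entry point from this file is used to build the config:

from fms_mo.utils.qconfig_utils import qconfig_init

qcfg = qconfig_init()
# raise RuntimeError instead of only warning when QBmms are attached but never called
qcfg["force_stop_if_qbmm_auto_check_failed"] = True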
fms_mo/utils/utils.py

Lines changed: 24 additions & 12 deletions

@@ -71,7 +71,7 @@ def move_to(obj, device):
     return obj
 
 
-def mockbmm(mat1, mat2):
+def mockbmm(mat1, mat2, default_to_torch=False):
     """
     This function is used to mock the behavior of the bmm function in PyTorch.
     It is used to work around the fact that the bmm function in PyTorch is not

@@ -86,20 +86,23 @@ def mockbmm(mat1, mat2):
     """
     cf = sys._getframe()
     qbmm_mod = None
+    qbmm_lineno = cf.f_back.f_lineno
     while cf.f_back and qbmm_mod is None:
         # First frame is QBmm's forward itself, can start searching from previous stack
         cf = cf.f_back
-        if "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name:
+        if (
+            "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name
+        ) and "self" in cf.f_locals:
             mod_calling_bmm_function = cf.f_locals["self"]
             # If not found -> default to torch.matmul
-            qbmm_mod = getattr(
-                mod_calling_bmm_function, "QBmm" + str(cf.f_lineno), torch.matmul
-            )
+            qbmm_mod = getattr(mod_calling_bmm_function, f"QBmm{qbmm_lineno}", None)
     del cf
+    if qbmm_mod is None and default_to_torch:
+        qbmm_mod = torch.matmul
     return qbmm_mod(mat1, mat2)
 
 
-def mockmatmul(mat1, mat2):
+def mockmatmul(mat1, mat2, default_to_torch=False):
     """
     Patches torch.matmul() with QBmm( torch.bmm() )
 

@@ -109,31 +112,37 @@ def mockmatmul(mat1, mat2):
 
     Returns:
         torch.Tensor: The result of the mock matrix multiplication.
+    NOTE:
+    1. First frame is mockmatmul itself. One frame back (cf.f_back) is where torch.matmul
+       happened, whose line number is the one used for QBmm<xxx>.
+    2. The QBmm module may not be attached to the immediate frame where torch.matmul happened;
+       trace back to find the frame with both "forward" in its name and "self" in its locals,
+       i.e. a class (nn.Module) with a method named "forward".
+    3. Keep default_to_torch=False unless really needed, otherwise if something goes wrong with
+       QBmm detection it could fall back to the default silently, which is very hard to debug.
     """
     cf = sys._getframe()
     qbmm_mod = None
+    qbmm_lineno = cf.f_back.f_lineno
     while cf.f_back and qbmm_mod is None:
-        # First frame is QBmm's forward itself, can start searching from previous stack
         cf = cf.f_back
         if (
             "forward" in cf.f_code.co_name or "_attn" in cf.f_code.co_name
         ) and "self" in cf.f_locals:
             mod_calling_bmm_function = cf.f_locals["self"]
             # If not found -> default to torch.bmm
-            qbmm_mod = getattr(
-                mod_calling_bmm_function, "QBmm" + str(cf.f_lineno), torch.bmm
-            )
+            qbmm_mod = getattr(mod_calling_bmm_function, f"QBmm{qbmm_lineno}", None)
     del cf
 
     # Didn't find the corresponding QBmm, default the call to torch.bmm
-    if qbmm_mod == torch.bmm:
+    if qbmm_mod is None and default_to_torch:
         org_batch_header = mat1.shape[:2]
         # Need to double check m1/m2 are 3d, otherwise reshape
         if len(mat1.shape) > 3:
             mat1 = mat1.reshape([-1, mat1.shape[-2], mat1.shape[-1]])
         if len(mat2.shape) > 3:
             mat2 = mat2.reshape([-1, mat2.shape[-2], mat2.shape[-1]])
-        output = qbmm_mod(mat1, mat2)
+        output = torch.bmm(mat1, mat2)
         output = output.reshape([*org_batch_header, *output.shape[1:]])
         return output
     return qbmm_mod(mat1, mat2)

@@ -149,6 +158,9 @@ def patch_torch_bmm(qcfg):
     if qcfg is not None:
         # could be 'torch.bmm', 'torch.matmul', or None
         ops_to_patch = qcfg.get("which2patch_contextmanager", None)
+        # If qcfg["bmm_prep"]["bmm_only_in_self_attn"] is False, may need to enable default_to_torch
+        # in the mock functions, e.g. partial(mockmatmul, default_to_torch=True). This is in case a
+        # model uses extra matmuls and QBmmXXX is not found or attached properly.
         new_target = (
             mockbmm
             if ops_to_patch == "torch.bmm"

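The key change in mockbmm/mockmatmul is where the line number comes from: it is now captured once from the immediate caller (cf.f_back.f_lineno, i.e. the line that invoked torch.matmul/torch.bmm) before walking up the stack, and the lookup returns None instead of silently falling back to the torch op. A tiny standalone illustration of that frame introspection, with hypothetical names and no fms_mo code involved:

import sys

def where_was_i_called():
    # one frame back is the call site; its line number is what the mock functions
    # would use to build the "QBmm<lineno>" attribute name
    caller = sys._getframe().f_back
    return caller.f_code.co_name, caller.f_lineno

def forward():
    name, lineno = where_was_i_called()  # the lookup key would be f"QBmm{lineno}"
    print(f"called from {name}() at line {lineno}")

forward()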
tests/models/conftest.py

Lines changed: 13 additions & 0 deletions

@@ -1092,3 +1092,16 @@ def model_bert():
         transformers.models.bert.modeling_bert.BertModel: BERT model
     """
     return BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)
+
+
+@pytest.fixture(scope="function")
+def model_bert_eager():
+    """
+    Get a BERT model with eager attention implementation
+
+    Returns:
+        transformers.models.bert.modeling_bert.BertModel: BERT model
+    """
+    return BertModel.from_pretrained(
+        "google-bert/bert-base-uncased", torchscript=True, attn_implementation="eager"
+    )

tests/models/test_model_utils.py

Lines changed: 2 additions & 1 deletion

@@ -26,6 +26,7 @@
 import torch
 
 # Local
+from fms_mo.modules.bmm import QBmm
 from fms_mo.modules.conv import DetQConv2d, QConv2d, QConv2dPTQ, QConv2dPTQv2
 from fms_mo.modules.linear import QLinear
 from fms_mo.utils.qconfig_utils import serialize_config

@@ -99,7 +100,7 @@ def count_qmodules(model: torch.nn.Module):
     """
     torch_modules, fms_qmodules = [], []
     for n, m in model.named_modules():
-        if isinstance(m, (QConv2d, QLinear)):
+        if isinstance(m, (QConv2d, QLinear, QBmm)):
             fms_qmodules.append((n, m))
         elif isinstance(m, (Conv2d, Linear)):
             torch_modules.append((n, m))

tests/models/test_qmodelprep.py

Lines changed: 62 additions & 1 deletion

@@ -26,7 +26,8 @@
 # fms_mo imports
 from fms_mo import qmodel_prep
 from fms_mo.prep import has_quantized_module
-from tests.models.test_model_utils import delete_config, qmodule_error
+from fms_mo.utils.utils import patch_torch_bmm
+from tests.models.test_model_utils import count_qmodules, delete_config, qmodule_error
 
 ################
 # Qmodel tests #

@@ -257,3 +258,63 @@ def test_bert_dynamo(
     delete_config()
     qmodel_prep(model_bert, input_bert, config_int8, use_dynamo=True)
     qmodule_error(model_bert, 1, 72)
+
+
+def test_bert_dynamo_wi_qbmm(
+    model_bert_eager: transformers.models.bert.modeling_bert.BertModel,
+    input_bert: torch.FloatTensor,
+    config_int8: dict,
+):
+    """
+    Perform int8 quantization on BERT w/ Dynamo tracer and QBmm modules. QBmms will run in place
+    of torch.matmul/torch.bmm automatically, if everything is set up correctly. See the 3 checks
+    below for more details.
+    NOTE:
+    1. QBmm modules will be added after qmodel_prep(), see check 1.
+    2. The self-attention forward() will still call torch.matmul as written in the original
+       Python code, i.e. if we check QLinear.num_called and QBmm.num_called, they will be 1 and
+       0, respectively, meaning QBmms were attached but not called.
+    3. By using the patch_torch_bmm() context manager, QBmm modules will be triggered and those
+       torch.matmul calls (usually 2 per attn module) will be redirected to QBmm's forward.
+
+    Args:
+        model_bert_eager (transformers.models.bert.modeling_bert.BertModel): BERT model + weights
+        input_bert (torch.FloatTensor): Tokenized input for BERT
+        config_int8 (dict): Recipe Config w/ int8 settings
+    """
+    delete_config()
+    config_int8["nbits_bmm1"] = 8
+    config_int8["nbits_bmm2"] = 8
+    qmodel_prep(model_bert_eager, input_bert, config_int8, use_dynamo=True)
+
+    # check 1: make sure QBmms are added, i.e. 72 QLinear + 24 QBmm
+    qmodule_error(model_bert_eager, 1, 96)
+
+    _, fms_qmodules = count_qmodules(model_bert_eager)
+    qbmms = []
+    other_qmodules = []
+    for n, m in fms_qmodules:
+        if "QBmm" in n:
+            qbmms.append(m)
+        else:
+            other_qmodules.append(m)
+
+    # check 2: a model call without our "patch" context manager will not reach QBmm.
+    # An auto check is in place, but it only logs a warning unless
+    # qcfg["force_stop_if_qbmm_auto_check_failed"] = True
+    with torch.no_grad():
+        model_bert_eager(**input_bert)
+    assert all(
+        m.num_module_called == 0 for m in qbmms
+    ), "Some QBmm was called when they shouldn't be."
+
+    # check 3: a model call with the context manager will reach QBmm
+    with torch.no_grad(), patch_torch_bmm(config_int8):
+        model_bert_eager(**input_bert)
+    assert all(
+        m.num_module_called == 1 for m in qbmms
+    ), "Some QBmm was not called properly."
+
+    assert all(
+        m.num_module_called == 2 for m in other_qmodules
+    ), "Modules other than QBmm were not called properly."

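Outside the test suite, the flow exercised by test_bert_dynamo_wi_qbmm is the intended usage pattern: prepare the model with the bmm bit-widths set, then wrap inference in patch_torch_bmm() so the attention matmuls are routed through QBmm. A rough sketch under stated assumptions: the tokenizer/model setup is illustrative, and qconfig_init() stands in for the recipe-based config_int8 used by the test.

import torch
from transformers import AutoModel, AutoTokenizer

from fms_mo import qmodel_prep
from fms_mo.utils.qconfig_utils import qconfig_init
from fms_mo.utils.utils import patch_torch_bmm

tok = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = AutoModel.from_pretrained(
    "google-bert/bert-base-uncased", attn_implementation="eager"
)
inputs = tok("hello world", return_tensors="pt")

qcfg = qconfig_init()  # assumption: a default config; the test uses config_int8 instead
qcfg["nbits_bmm1"] = 8
qcfg["nbits_bmm2"] = 8
qmodel_prep(model, inputs, qcfg, use_dynamo=True)

with torch.no_grad(), patch_torch_bmm(qcfg):
    out = model(**inputs)  # torch.matmul calls in attention are redirected to QBmm here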