Commit d6c0fe9

enhance qbmm unit test to demo the effect of context manager patch_torch_bmm()
Signed-off-by: cliu-us <[email protected]>
1 parent 357622d commit d6c0fe9

File tree

1 file changed: +29 -6 lines changed

tests/models/test_qmodelprep.py

Lines changed: 29 additions & 6 deletions
@@ -266,7 +266,16 @@ def test_bert_dynamo_wi_qbmm(
         config_int8: dict,
     ):
         """
-        Perform int8 quantization on BERT w/ Dynamo tracer and QBmm modules
+        Perform int8 quantization on BERT w/ Dynamo tracer and QBmm modules. QBmms will be run in
+        place of torch.matmul/torch.bmm automatically, if everything is set up correctly. See the
+        3 checks below for more details.
+        NOTE:
+        1. QBmm modules will be added after qmodel_prep(), see check 1.
+        2. The self-attention forward() will still call torch.matmul as written in the original
+           Python code, i.e. if we check QLinear.num_called and QBmm.num_called, they will be 1 and
+           0, respectively, meaning QBmms were attached but not called.
+        3. By using the patch_torch_bmm() context manager, QBmm modules will be triggered and those
+           torch.matmul calls (usually 2 per attn module) will be redirected to QBmm's forward.

         Args:
             model_bert (transformers.models.bert.modeling_bert.BertModel): BERT model + weights
@@ -281,15 +290,29 @@ def test_bert_dynamo_wi_qbmm(
         # check 1: make sure QBmm are added, i.e. 72 QLinear + 24 QBmm
         qmodule_error(model_bert_eager, 1, 96)

-        # check 2: make sure context manager can reach QBmm
         _, fms_qmodules = count_qmodules(model_bert_eager)
-        with torch.no_grad(), patch_torch_bmm(config_int8):
+        qbmms = []
+        other_qmodules = []
+        for n, m in fms_qmodules:
+            if "QBmm" in n:
+                qbmms.append(m)
+            else:
+                other_qmodules.append(m)
+
+        # check 2: a model call without our "patch" context manager will not reach QBmm
+        with torch.no_grad():
             model_bert_eager(**input_bert)
-        qbmms = [m for n, m in fms_qmodules if "QBmm" in n]
+        assert all(
+            m.num_module_called == 0 for m in qbmms
+        ), "Some QBmm was called when it shouldn't be."

+        # check 3: a model call with the context manager will reach QBmm
+        with torch.no_grad(), patch_torch_bmm(config_int8):
+            model_bert_eager(**input_bert)
         assert all(
             m.num_module_called == 1 for m in qbmms
         ), "Some QBmm was not called properly."
+
         assert all(
-            m.num_module_called == 1 for _, m in fms_qmodules
-        ), "Some module was not called properly."
+            m.num_module_called == 2 for m in other_qmodules
+        ), "Modules other than QBmm were not called properly."
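The mechanic the test exercises, i.e. code written against a plain function being transparently redirected to a counting module only while a context manager is active, can be sketched without torch or the fms-mo library. All names below (FakeQBmm, patch_matmul, attention_like_forward) are illustrative stand-ins, not the actual patch_torch_bmm() implementation:

```python
import contextlib

class FakeQBmm:
    """Stand-in for a QBmm module that counts how often it is called."""
    def __init__(self):
        self.num_module_called = 0

    def __call__(self, a, b):
        self.num_module_called += 1
        # naive matrix multiply on nested lists
        return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def matmul(a, b):
    """The 'library' function model code calls directly (stand-in for torch.matmul)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

qbmm = FakeQBmm()

@contextlib.contextmanager
def patch_matmul(target_module):
    """Temporarily redirect this module's `matmul` to the quantized module."""
    global matmul
    original = matmul
    matmul = target_module  # calls now hit FakeQBmm.__call__
    try:
        yield
    finally:
        matmul = original  # always restore, even if the body raises

def attention_like_forward(a, b):
    # Written against the plain function, just like the original self-attention code
    # still calls torch.matmul even after QBmm modules have been attached.
    return matmul(a, b)

# Without the context manager: QBmm is attached but never called (check 2).
attention_like_forward([[1, 2]], [[3], [4]])
assert qbmm.num_module_called == 0

# With the context manager: the same call is redirected to QBmm (check 3).
with patch_matmul(qbmm):
    attention_like_forward([[1, 2]], [[3], [4]])
assert qbmm.num_module_called == 1
```

Because `attention_like_forward` looks up `matmul` in module globals at call time, rebinding the name inside the context manager is enough to reroute the call; the `finally` block guarantees the original function is restored afterwards, which is why calls outside the `with` block never touch the quantized path.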
