@@ -266,7 +266,16 @@ def test_bert_dynamo_wi_qbmm(
266266 config_int8 : dict ,
267267):
268268 """
269- Perform int8 quantization on BERT w/ Dynamo tracer and QBmm modules
269+ Perform int8 quantization on BERT w/ Dynamo tracer and QBmm modules. QBmms will be run in place
270+ of torch.matmul/torch.bmm automatically, if everything is set up correctly. See the 3 checks
271+ below for more details.
272+ NOTE:
273+ 1. QBmm modules will be added after qmodel_prep(), see check 1.
274+ 2. The self-attention forward() will still call torch.matmul as written in the original
275+ python code, i.e. if we check QLinear.num_module_called and QBmm.num_module_called, they
276+ will be 1 and 0, respectively, meaning QBmms were attached but not called.
277+ 3. By using patch_torch_bmm() context manager, QBmm modules will be triggered and those
278+ torch.matmul (usually 2 per attn module) calls will be redirected to QBmm's forward.
270279
271280 Args:
272281 model_bert (transformers.models.bert.modeling_bert.BertModel): BERT model + weights
@@ -281,15 +290,29 @@ def test_bert_dynamo_wi_qbmm(
281290 # check 1: make sure QBmm are added, i.e. 72 QLinear + 24 QBmm
282291 qmodule_error (model_bert_eager , 1 , 96 )
283292
284- # check 2: make sure context manager can reach QBmm
285293 _ , fms_qmodules = count_qmodules (model_bert_eager )
286- with torch .no_grad (), patch_torch_bmm (config_int8 ):
294+ qbmms = []
295+ other_qmodules = []
296+ for n , m in fms_qmodules :
297+ if "QBmm" in n :
298+ qbmms .append (m )
299+ else :
300+ other_qmodules .append (m )
301+
302+ # check 2: model call without our "patch" context manager, will not reach QBmm
303+ with torch .no_grad ():
287304 model_bert_eager (** input_bert )
288- qbmms = [m for n , m in fms_qmodules if "QBmm" in n ]
305+ assert all (
306+ m .num_module_called == 0 for m in qbmms
307+ ), "Some QBmm was called when they shouldn't be."
289308
309+ # check 3: model call with context manager, will reach QBmm
310+ with torch .no_grad (), patch_torch_bmm (config_int8 ):
311+ model_bert_eager (** input_bert )
290312 assert all (
291313 m .num_module_called == 1 for m in qbmms
292314 ), "Some QBmm was not called properly."
315+
293316 assert all (
294- m .num_module_called == 1 for _ , m in fms_qmodules
295- ), "Some module was not called properly."
317+ m .num_module_called == 2 for m in other_qmodules
318+ ), "Modules other than QBmm were not called properly."
0 commit comments