
Commit 357622d

add a unit test to cover QBmm attachment and verify all QBmm modules are reachable
Signed-off-by: cliu-us <[email protected]>
1 parent dd53de5

File tree

3 files changed: +52 -2 lines changed

tests/models/conftest.py

Lines changed: 13 additions & 0 deletions

@@ -1092,3 +1092,16 @@ def model_bert():
         transformers.models.bert.modeling_bert.BertModel: BERT model
     """
     return BertModel.from_pretrained("google-bert/bert-base-uncased", torchscript=True)
+
+
+@pytest.fixture(scope="function")
+def model_bert_eager():
+    """
+    Get a BERT model using the eager attention implementation
+
+    Returns:
+        transformers.models.bert.modeling_bert.BertModel: BERT model
+    """
+    return BertModel.from_pretrained(
+        "google-bert/bert-base-uncased", torchscript=True, attn_implementation="eager"
+    )
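
Why the new fixture pins attn_implementation="eager": with the default SDPA path, Hugging Face BERT routes attention through torch.nn.functional.scaled_dot_product_attention, a fused kernel that leaves no discrete torch.matmul/torch.bmm call sites for a bmm-patching context manager to intercept. A minimal sketch of the difference (editor's illustration, not part of the commit; the exact class names depend on your transformers version):

from transformers import BertModel

# Eager attention computes Q @ K^T and attn_probs @ V with explicit matmul
# calls, which is what QBmm patching relies on; SDPA fuses them into one op.
eager = BertModel.from_pretrained(
    "google-bert/bert-base-uncased", attn_implementation="eager"
)
sdpa = BertModel.from_pretrained(
    "google-bert/bert-base-uncased", attn_implementation="sdpa"
)
print(type(eager.encoder.layer[0].attention.self).__name__)  # e.g. BertSelfAttention
print(type(sdpa.encoder.layer[0].attention.self).__name__)   # e.g. BertSdpaSelfAttention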

tests/models/test_model_utils.py

Lines changed: 2 additions & 1 deletion

@@ -26,6 +26,7 @@
 import torch

 # Local
+from fms_mo.modules.bmm import QBmm
 from fms_mo.modules.conv import DetQConv2d, QConv2d, QConv2dPTQ, QConv2dPTQv2
 from fms_mo.modules.linear import QLinear
 from fms_mo.utils.qconfig_utils import serialize_config
@@ -99,7 +100,7 @@ def count_qmodules(model: torch.nn.Module):
     """
     torch_modules, fms_qmodules = [], []
     for n, m in model.named_modules():
-        if isinstance(m, (QConv2d, QLinear)):
+        if isinstance(m, (QConv2d, QLinear, QBmm)):
             fms_qmodules.append((n, m))
         elif isinstance(m, (Conv2d, Linear)):
             torch_modules.append((n, m))
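
With this change, count_qmodules classifies QBmm alongside QConv2d and QLinear when partitioning a model's modules. Note that the fms_mo classes are checked before the plain torch ones, which matters if they subclass Conv2d/Linear, as quantized drop-in replacements typically do. A usage sketch (editor's illustration, not part of the commit; assumes model has already been through qmodel_prep):

from tests.models.test_model_utils import count_qmodules

# Partition named modules into fms_mo quantized (QConv2d/QLinear/QBmm)
# and remaining plain torch (Conv2d/Linear) modules.
torch_modules, fms_qmodules = count_qmodules(model)
print(f"{len(fms_qmodules)} quantized vs. {len(torch_modules)} plain modules")
for name, module in fms_qmodules:
    print(name, type(module).__name__)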

tests/models/test_qmodelprep.py

Lines changed: 37 additions & 1 deletion

@@ -26,7 +26,8 @@
 # fms_mo imports
 from fms_mo import qmodel_prep
 from fms_mo.prep import has_quantized_module
-from tests.models.test_model_utils import delete_config, qmodule_error
+from fms_mo.utils.utils import patch_torch_bmm
+from tests.models.test_model_utils import count_qmodules, delete_config, qmodule_error

 ################
 # Qmodel tests #
@@ -257,3 +258,38 @@ def test_bert_dynamo(
     delete_config()
     qmodel_prep(model_bert, input_bert, config_int8, use_dynamo=True)
     qmodule_error(model_bert, 1, 72)
+
+
+def test_bert_dynamo_wi_qbmm(
+    model_bert_eager: transformers.models.bert.modeling_bert.BertModel,
+    input_bert: torch.FloatTensor,
+    config_int8: dict,
+):
+    """
+    Perform int8 quantization on BERT w/ Dynamo tracer and QBmm modules
+
+    Args:
+        model_bert_eager (transformers.models.bert.modeling_bert.BertModel): BERT model + weights
+        input_bert (torch.FloatTensor): Tokenized input for BERT
+        config_int8 (dict): Recipe Config w/ int8 settings
+    """
+    delete_config()
+    config_int8["nbits_bmm1"] = 8
+    config_int8["nbits_bmm2"] = 8
+    qmodel_prep(model_bert_eager, input_bert, config_int8, use_dynamo=True)
+
+    # check 1: make sure QBmm modules were added, i.e. 72 QLinear + 24 QBmm
+    qmodule_error(model_bert_eager, 1, 96)
+
+    # check 2: make sure the context manager can reach every QBmm
+    _, fms_qmodules = count_qmodules(model_bert_eager)
+    with torch.no_grad(), patch_torch_bmm(config_int8):
+        model_bert_eager(**input_bert)
+    qbmms = [m for n, m in fms_qmodules if "QBmm" in n]
+
+    assert all(
+        m.num_module_called == 1 for m in qbmms
+    ), "Some QBmm was not called properly."
+    assert all(
+        m.num_module_called == 1 for _, m in fms_qmodules
+    ), "Some module was not called properly."
