
Commit 9938a6e

rahul-tuli and horheynm authored
Update: Test for Compatibility with Transformers 4.48 (#239)
* Update: test for new transformers release
* fix bug
* check layers exist with counting

---------

Co-authored-by: George Ohashi <[email protected]>
1 parent 6fffbd7 commit 9938a6e

1 file changed: 23 additions, 17 deletions


tests/test_quantization/lifecycle/test_apply.py

Lines changed: 23 additions & 17 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import re
+from collections import defaultdict
 from typing import Optional
 from unittest.mock import MagicMock
 
@@ -114,31 +115,36 @@ def test_apply_quantization_config_tinyllama():
     for module in model.modules():
         _test_layer_quantization_status(module, inputs=False, weights=False)
 
+    count_layer_names = ("Linear", "Embedding", "LlamaRotaryEmbedding")
+    count_layer_num = defaultdict(int)
+
+    for name, module in model.named_modules():
+        if name in quant_config.ignore:
+            continue
+        module_type = module.__class__.__name__
+        if module_type in count_layer_names:
+            count_layer_num[module_type] += 1
+
+    assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model"
+    assert all(value > 0 for value in count_layer_num.values())
+
     # apply quant config to model
     apply_quantization_config(model, quant_config)
 
     # check for correct application of quant config
-    num_linears = 0
-    num_embeddings = 0
-    num_rotary_embeddings = 0
     for name, module in model.named_modules():
         if name in quant_config.ignore:
             continue
         module_type = module.__class__.__name__
-        if module_type == "Linear":
-            num_linears += 1
-            _test_layer_quantization_status(module, inputs=True, weights=True)
-        elif module_type == "Embedding":
-            num_embeddings += 1
-            _test_layer_quantization_status(module, inputs=False, weights=True)
-        elif module_type == "LlamaRotaryEmbedding":
-            num_rotary_embeddings += 1
-            _test_layer_quantization_status(module, inputs=False, weights=False)
-
-    # sanity check correct number of layers targeted
-    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored
-    assert num_embeddings == 1
-    assert num_rotary_embeddings == 23  # model updated, now has model.rotary_embedding
+        if module_type in count_layer_names:
+            count_layer_num[module_type] -= 1
+            _inputs = module_type == "Linear"
+            _weights = not module_type == "LlamaRotaryEmbedding"
+            _test_layer_quantization_status(module, inputs=_inputs, weights=_weights)
+
+    assert all(
+        value == 0 for value in count_layer_num.values()
+    ), "Not all values are zero"
 
     # test quantization compression
     # sample forward pass to fill scales, zps
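
For readers skimming the diff, the sketch below isolates the count-then-decrement pattern this commit introduces. It is illustrative only: the small nn.Sequential model and the two-entry type tuple are stand-ins for the TinyLlama model used in the real test, which also calls _test_layer_quantization_status on each counted module and skips anything listed in quant_config.ignore.

# Minimal, self-contained sketch of the counting check (illustrative, not the repo's test).
from collections import defaultdict

import torch.nn as nn

# Placeholder model standing in for the TinyLlama checkpoint used by the real test.
model = nn.Sequential(nn.Embedding(10, 8), nn.Linear(8, 8), nn.Linear(8, 4))

count_layer_names = ("Linear", "Embedding")
count_layer_num = defaultdict(int)

# Pass 1: record how many modules of each targeted type exist up front.
for _, module in model.named_modules():
    module_type = module.__class__.__name__
    if module_type in count_layer_names:
        count_layer_num[module_type] += 1

assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model"

# ... the transformation under test would run here, e.g. apply_quantization_config ...

# Pass 2: decrement for every targeted module still present afterwards.
for _, module in model.named_modules():
    module_type = module.__class__.__name__
    if module_type in count_layer_names:
        count_layer_num[module_type] -= 1

# Every counter returns to zero only if the targeted modules survive in the same
# numbers, so the test no longer hard-codes layer counts such as 154 or 23.
assert all(value == 0 for value in count_layer_num.values()), "Not all values are zero"

Dropping the hard-coded counts is what appears to make the test robust to Transformers 4.48, which (per the removed comment "model updated, now has model.rotary_embedding") changes how many LlamaRotaryEmbedding modules a Llama model instantiates.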
