 13 |  13 | # limitations under the License.
 14 |  14 |
 15 |  15 | import re
    |  16 | +from collections import defaultdict
 16 |  17 | from typing import Optional
 17 |  18 | from unittest.mock import MagicMock
 18 |  19 |
@@ -114,31 +115,36 @@ def test_apply_quantization_config_tinyllama():
114 | 115 |     for module in model.modules():
115 | 116 |         _test_layer_quantization_status(module, inputs=False, weights=False)
116 | 117 |
    | 118 | +    count_layer_names = ("Linear", "Embedding", "LlamaRotaryEmbedding")
    | 119 | +    count_layer_num = defaultdict(int)
    | 120 | +
    | 121 | +    for name, module in model.named_modules():
    | 122 | +        if name in quant_config.ignore:
    | 123 | +            continue
    | 124 | +        module_type = module.__class__.__name__
    | 125 | +        if module_type in count_layer_names:
    | 126 | +            count_layer_num[module_type] += 1
    | 127 | +
    | 128 | +    assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model"
    | 129 | +    assert all(value > 0 for value in count_layer_num.values())
    | 130 | +
117 | 131 |     # apply quant config to model
118 | 132 |     apply_quantization_config(model, quant_config)
119 | 133 |
120 | 134 |     # check for correct application of quant config
121 |     | -    num_linears = 0
122 |     | -    num_embeddings = 0
123 |     | -    num_rotary_embeddings = 0
124 | 135 |     for name, module in model.named_modules():
125 | 136 |         if name in quant_config.ignore:
126 | 137 |             continue
127 | 138 |         module_type = module.__class__.__name__
128 |     | -        if module_type == "Linear":
129 |     | -            num_linears += 1
130 |     | -            _test_layer_quantization_status(module, inputs=True, weights=True)
131 |     | -        elif module_type == "Embedding":
132 |     | -            num_embeddings += 1
133 |     | -            _test_layer_quantization_status(module, inputs=False, weights=True)
134 |     | -        elif module_type == "LlamaRotaryEmbedding":
135 |     | -            num_rotary_embeddings += 1
136 |     | -            _test_layer_quantization_status(module, inputs=False, weights=False)
137 |     | -
138 |     | -    # sanity check correct number of layers targeted
139 |     | -    assert num_linears == 154  # 155 Linear layers - 1 that gets ignored
140 |     | -    assert num_embeddings == 1
141 |     | -    assert num_rotary_embeddings == 23  # model updated, now has model.rotary_embedding
    | 139 | +        if module_type in count_layer_names:
    | 140 | +            count_layer_num[module_type] -= 1
    | 141 | +            _inputs = module_type == "Linear"
    | 142 | +            _weights = module_type != "LlamaRotaryEmbedding"
    | 143 | +            _test_layer_quantization_status(module, inputs=_inputs, weights=_weights)
    | 144 | +
    | 145 | +    assert all(
    | 146 | +        value == 0 for value in count_layer_num.values()
    | 147 | +    ), "Not all values are zero"
142 | 148 |
143 | 149 |     # test quantization compression
144 | 150 |     # sample forward pass to fill scales, zps
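
The refactored check replaces hard-coded layer counts with a count-then-decrement pattern: one pass tallies the targeted layer types before the quant config is applied, a second pass decrements the tally for each layer actually verified, and an all-zeros assertion confirms both passes covered the same modules. Below is a minimal, self-contained sketch of that pattern; the stand-in classes, module names, and ignore set are illustrative assumptions, not part of the actual test or model.

```python
from collections import defaultdict

# Hypothetical stand-in module classes; in the real test these come from the TinyLlama model.
class Linear: ...
class Embedding: ...
class LlamaRotaryEmbedding: ...

# Hypothetical stand-ins for model.named_modules() and quant_config.ignore.
named_modules = [
    ("model.layers.0.self_attn.q_proj", Linear()),
    ("model.embed_tokens", Embedding()),
    ("model.rotary_emb", LlamaRotaryEmbedding()),
    ("lm_head", Linear()),
]
ignore = {"lm_head"}

target_types = ("Linear", "Embedding", "LlamaRotaryEmbedding")
counts = defaultdict(int)

# Pass 1: tally non-ignored modules of each targeted type.
for name, module in named_modules:
    if name in ignore:
        continue
    if (module_type := module.__class__.__name__) in target_types:
        counts[module_type] += 1

assert counts, f"None of {target_types} found"

# Pass 2 (run after applying the quant config in the real test): decrement per
# module checked, so the final all-zeros assertion proves every counted layer
# was exercised exactly once.
for name, module in named_modules:
    if name in ignore:
        continue
    if (module_type := module.__class__.__name__) in target_types:
        counts[module_type] -= 1

assert all(value == 0 for value in counts.values()), "Not all values are zero"
```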