     QuantizationStatus,
 )
 from compressed_tensors.quantization.lifecycle import apply_quantization_config
+from compressed_tensors.utils import match_named_modules
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM

@@ -103,60 +104,27 @@ def test_target_prioritization(mock_frozen):

 def test_apply_quantization_config_tinyllama():
     quant_config = get_sample_tinyllama_quant_config(
-        status=QuantizationStatus.CALIBRATION
+        status=QuantizationStatus.INITIALIZED
     )
     model = get_tinyllama_model()

     # check that model is not already quantized
     for module in model.modules():
         _test_layer_quantization_status(module, inputs=False, weights=False)

-    count_layer_names = ("Linear", "Embeddidng", "LlamaRotaryEmbedding")
-    count_layer_num = defaultdict(int)
-
-    for name, module in model.named_modules():
-        if name in quant_config.ignore:
-            continue
-        module_type = module.__class__.__name__
-        if module_type in count_layer_names:
-            count_layer_num[module_type] += 1
-
-    assert len(count_layer_num) > 0, f"None of {count_layer_names} found in model"
-    assert all(value > 0 for value in count_layer_num.values())
-
     # apply quant config to model
     apply_quantization_config(model, quant_config)

     # check for correct application of quant config
-    for name, module in model.named_modules():
-        if name in quant_config.ignore:
-            continue
-        module_type = module.__class__.__name__
-        if module_type in count_layer_names:
-            count_layer_num[module_type] -= 1
-            _inputs = module_type == "Linear"
-            _weights = not module_type == "LlamaRotaryEmbedding"
-            _test_layer_quantization_status(module, inputs=_inputs, weights=_weights)
-
-    assert all(
-        value == 0 for value in count_layer_num.values()
-    ), "Not all values are zero"
-
-    # test quantization compression
-    # sample forward pass to fill scales, zps
-    model(torch.zeros((1, 1), dtype=int), torch.zeros((1, 1), dtype=int))
-    quant_config.quantization_status = QuantizationStatus.COMPRESSED
-    apply_quantization_config(model, quant_config)
-    for name, module in model.named_modules():
-        if name in quant_config.ignore:
-            continue
-        module_type = module.__class__.__name__
-        if module_type == "Linear":
+    for quant_scheme in quant_config.config_groups.values():
+        for name, module in match_named_modules(
+            model, quant_scheme.targets, quant_config.ignore
+        ):
             _test_layer_quantization_status(
                 module,
-                inputs=True,
-                weights=True,
-                expected_status=QuantizationStatus.COMPRESSED,
+                inputs=quant_scheme.input_activations is not None,
+                weights=quant_scheme.weights is not None,
+                expected_status=QuantizationStatus.INITIALIZED,
             )

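Context for the new helper: `match_named_modules` resolves each scheme's `targets` against the model's named modules while skipping anything in `ignore`, replacing the hand-rolled class-name counting the old test did. Below is a minimal sketch of that matching behavior, assuming targets are either module class names (e.g. `"Linear"`) or `re:`-prefixed regex patterns over dotted module paths, as is the convention in compressed-tensors configs; the real utility in `compressed_tensors.utils` may differ in detail.

```python
import re
from typing import Iterable, Iterator, Tuple

import torch.nn as nn


def match_named_modules_sketch(
    model: nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str] = (),
) -> Iterator[Tuple[str, nn.Module]]:
    """Yield (name, module) pairs matching any target, skipping ignored names."""
    ignored = set(ignore or ())
    for name, module in model.named_modules():
        if name in ignored:
            continue
        for target in targets:
            if target.startswith("re:"):
                # regex targets match against the dotted module path
                if re.match(target[3:], name):
                    yield name, module
                    break
            elif module.__class__.__name__ == target:
                # plain targets match the module class name, e.g. "Linear"
                yield name, module
                break
```

Under this reading, a config group targeting `["Linear"]` makes the rewritten loop visit exactly the `nn.Linear` layers the old `count_layer_num` bookkeeping tried to track, with the ignore list applied in one place and the per-scheme `inputs`/`weights` expectations derived directly from the config instead of being hardcoded per class name.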