 40 |  40 |     is_kv_cache_quant_scheme,
 41 |  41 | )
 42 |  42 | from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
    |  43 | +from compressed_tensors.utils.match import match_named_modules
 43 |  44 | from compressed_tensors.utils.offload import update_parameter_data
 44 |  45 | from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 45 |  46 | from safetensors import safe_open
@@ -147,47 +148,35 @@ def apply_quantization_config(
147 | 148 |     if run_compressed:
148 | 149 |         from compressed_tensors.linear.compressed_linear import CompressedLinear
149 | 150 |
150 |     | -    # list of submodules to ignore
151 |     | -    ignored_submodules = defaultdict(list)
152 | 151 |     # mark appropriate layers for quantization by setting their quantization schemes
153 |     | -    for name, submodule in model.named_modules():
154 |     | -        # potentially fix module name to remove FSDP wrapper prefix
155 |     | -        name = fix_fsdp_module_name(name)
156 |     | -        if matches := find_name_or_class_matches(name, submodule, config.ignore):
157 |     | -            for match in matches:
158 |     | -                ignored_submodules[match].append(name)
159 |     | -            continue  # layer matches ignore list, continue
160 |     | -
161 |     | -        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
162 |     | -
163 |     | -        if targets:
164 |     | -            # mark modules to be quantized by adding
165 |     | -            # quant scheme to the matching layers
166 |     | -            scheme = _scheme_from_targets(target_to_scheme, targets, name)
167 |     | -            if run_compressed:
168 |     | -                format = config.format
169 |     | -                if format != CompressionFormat.dense.value:
170 |     | -                    if isinstance(submodule, torch.nn.Linear):
171 |     | -                        # TODO: expand to more module types
172 |     | -                        compressed_linear = CompressedLinear.from_linear(
173 |     | -                            submodule,
174 |     | -                            quantization_scheme=scheme,
175 |     | -                            quantization_format=format,
176 |     | -                        )
177 |     | -                        replace_module(model, name, compressed_linear)
178 |     | -
179 |     | -            # target matched - add layer and scheme to target list
180 |     | -            submodule.quantization_scheme = scheme
181 |     | -
182 |     | -            names_to_scheme[name] = submodule.quantization_scheme
183 |     | -
184 |     | -    if config.ignore is not None and ignored_submodules is not None:
185 |     | -        if set(config.ignore) - set(ignored_submodules):
186 |     | -            _LOGGER.warning(
187 |     | -                "Some layers that were to be ignored were "
188 |     | -                "not found in the model: "
189 |     | -                f"{set(config.ignore) - set(ignored_submodules)}"
190 |     | -            )
    | 152 | +    for name, submodule, matched_targets in match_named_modules(
    | 153 | +        model,
    | 154 | +        target_to_scheme,
    | 155 | +        config.ignore or [],
    | 156 | +        warn_on_fail=True,
    | 157 | +        warn_on_unmatched_ignores=True,
    | 158 | +        return_matched_targets=True,
    | 159 | +        preprocess_name=fix_fsdp_module_name,
    | 160 | +    ):
    | 161 | +        # mark modules to be quantized by adding
    | 162 | +        # quant scheme to the matching layers
    | 163 | +        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
    | 164 | +        if run_compressed:
    | 165 | +            format = config.format
    | 166 | +            if format != CompressionFormat.dense.value:
    | 167 | +                if isinstance(submodule, torch.nn.Linear):
    | 168 | +                    # TODO: expand to more module types
    | 169 | +                    compressed_linear = CompressedLinear.from_linear(
    | 170 | +                        submodule,
    | 171 | +                        quantization_scheme=scheme,
    | 172 | +                        quantization_format=format,
    | 173 | +                    )
    | 174 | +                    replace_module(model, name, compressed_linear)
    | 175 | +
    | 176 | +        # target matched - add layer and scheme to target list
    | 177 | +        submodule.quantization_scheme = scheme
    | 178 | +
    | 179 | +        names_to_scheme[name] = submodule.quantization_scheme
191 | 180 |
192 | 181 |     # apply current quantization status across all targeted layers
193 | 182 |     apply_quantization_status(model, config.quantization_status)
@@ -429,7 +418,6 @@ def _scheme_from_targets(
429 | 418 | def _merge_schemes(
430 | 419 |     schemes_to_merge: List[QuantizationScheme], name: str
431 | 420 | ) -> QuantizationScheme:
432 |     | -
433 | 421 |     kv_cache_quantization_scheme = [
434 | 422 |         scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
435 | 423 |     ]
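The refactor folds the hand-rolled matching loop, including the deleted `ignored_submodules` bookkeeping and the trailing unmatched-ignore warning, into a single `match_named_modules` generator. Below is a minimal sketch of what such a helper could look like, inferred only from the call site in the diff; the pattern matching is a simplified stand-in for the library's `find_name_or_class_matches`, and the `warn_on_fail` semantics are an assumption, not the actual `compressed_tensors.utils.match` implementation.

import logging
import re
from typing import Callable, Iterable, Iterator, List, Tuple, Union

import torch

_LOGGER = logging.getLogger(__name__)


def _find_matches(name: str, module: torch.nn.Module, patterns: Iterable[str]) -> List[str]:
    # Simplified stand-in for the library's name/class matcher: a pattern hits
    # on an exact module name, an exact class name, or an "re:"-prefixed regex.
    hits = []
    for pattern in patterns:
        if pattern.startswith("re:"):
            if re.match(pattern[3:], name):
                hits.append(pattern)
        elif pattern in (name, module.__class__.__name__):
            hits.append(pattern)
    return hits


def match_named_modules(
    model: torch.nn.Module,
    targets: Iterable[str],
    ignore: Iterable[str],
    warn_on_fail: bool = False,
    warn_on_unmatched_ignores: bool = False,
    return_matched_targets: bool = False,
    preprocess_name: Callable[[str], str] = lambda name: name,
) -> Iterator[Union[Tuple[str, torch.nn.Module], Tuple[str, torch.nn.Module, List[str]]]]:
    matched_ignores = set()
    matched_any_target = False
    for name, submodule in model.named_modules():
        name = preprocess_name(name)  # e.g. strip FSDP wrapper prefixes
        if ignore_hits := _find_matches(name, submodule, ignore):
            matched_ignores.update(ignore_hits)
            continue  # module is on the ignore list, skip it
        if matched_targets := _find_matches(name, submodule, targets):
            matched_any_target = True
            if return_matched_targets:
                yield name, submodule, matched_targets
            else:
                yield name, submodule
    # Assumed meaning of warn_on_fail: warn when nothing matched any target
    if warn_on_fail and not matched_any_target:
        _LOGGER.warning("No modules in the model matched the given targets")
    if warn_on_unmatched_ignores and (missed := set(ignore) - matched_ignores):
        _LOGGER.warning(f"Ignore patterns not found in the model: {missed}")

Driving the replacement for-loop through a generator like this keeps the ignore/target precedence (ignores win) and the unmatched-ignore warning of the deleted inline code in one reusable place.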