|
15 | 15 | import logging
|
16 | 16 | from collections import OrderedDict
|
17 | 17 | from copy import deepcopy
|
18 |
| -from typing import Dict, List, Optional |
| 18 | +from typing import Dict, Iterable, List, Optional |
19 | 19 | from typing import OrderedDict as OrderedDictType
|
20 | 20 | from typing import Union
|
21 | 21 |
|
|
50 | 50 | "load_pretrained_quantization_parameters",
|
51 | 51 | "apply_quantization_config",
|
52 | 52 | "apply_quantization_status",
|
| 53 | + "find_name_or_class_matches", |
53 | 54 | ]
|
54 | 55 |
|
55 | 56 | from compressed_tensors.quantization.utils.helpers import is_module_quantized
|
@@ -242,6 +243,39 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
|
242 | 243 | model.apply(compress_quantized_weights)
|
243 | 244 |
|
244 | 245 |
|
| 246 | +def find_name_or_class_matches( |
| 247 | + name: str, module: Module, targets: Iterable[str], check_contains: bool = False |
| 248 | +) -> List[str]: |
| 249 | + """ |
| 250 | + DEPRECATED: Use `match_targets` instead. |
| 251 | +
|
| 252 | + This function is deprecated and will be removed in a future release. |
| 253 | + Please use `match_targets` from `compressed_tensors.utils.match` instead. |
| 254 | +
|
| 255 | + Returns all targets that match the given name or the class name. |
| 256 | + Returns empty list otherwise. |
| 257 | + The order of the output `matches` list matters. |
| 258 | + The entries are sorted in the following order: |
| 259 | + 1. matches on exact strings |
| 260 | + 2. matches on regex patterns |
| 261 | + 3. matches on module names |
| 262 | + """ |
| 263 | + import warnings |
| 264 | + |
| 265 | + warnings.warn( |
| 266 | + "find_name_or_class_matches is deprecated and will be removed in a future release. " |
| 267 | + "Please use compressed_tensors.utils.match.match_targets instead.", |
| 268 | + DeprecationWarning, |
| 269 | + stacklevel=2, |
| 270 | + ) |
| 271 | + if check_contains: |
| 272 | + raise NotImplementedError( |
| 273 | + "This function is deprecated, and the check_contains=True option has been removed." |
| 274 | + ) |
| 275 | + |
| 276 | + return match_targets(name, module, targets) |
| 277 | + |
| 278 | + |
245 | 279 | def _infer_status(model: Module) -> Optional[QuantizationStatus]:
|
246 | 280 | for module in model.modules():
|
247 | 281 | status = getattr(module, "quantization_status", None)
|
|
0 commit comments