diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index f567e5a6..e1fed34f 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -41,7 +41,6 @@
     apply_quantization_config,
     load_pretrained_quantization_parameters,
 )
-from compressed_tensors.quantization.lifecycle import expand_target_names
 from compressed_tensors.quantization.utils import is_module_quantized
 from compressed_tensors.utils import (
     align_module_device,
@@ -58,6 +57,7 @@
     fix_fsdp_module_name,
     is_compressed_tensors_config,
 )
+from compressed_tensors.utils.match import match_named_modules
 from torch import Tensor
 from torch.nn import Module
 from tqdm import tqdm
@@ -292,13 +292,15 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
+
             missing_keys.update(
-                merge_names(target, "weight") for target in sparse_targets
+                merge_names(target_name, "weight")
+                for target_name, _module in sparse_targets
             )
 
         # Determine missing keys due to pack quantization
@@ -308,13 +310,14 @@ def get_missing_module_keys(self, model: Module) -> List[str]:
             == CompressionFormat.pack_quantized.value
         ):
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 missing_keys.update(
-                    merge_names(target, "weight") for target in quant_targets
+                    merge_names(target_name, "weight")
+                    for target_name, _module in quant_targets
                 )
 
         return list(missing_keys)
@@ -345,28 +348,28 @@ def get_unexpected_file_keys(self, model: Module) -> List[str]:
             self.sparsity_compressor
             and self.sparsity_config.format != CompressionFormat.dense.value
         ):
-            sparse_targets: Set[str] = expand_target_names(
+            sparse_targets = match_named_modules(
                 model=model,
                 targets=self.sparsity_config.targets,
                 ignore=self.sparsity_config.ignore,
             )
             unexpected_keys.update(
-                merge_names(target, param)
-                for target in sparse_targets
+                merge_names(target_name, param)
+                for target_name, _module in sparse_targets
                 for param in self.sparsity_compressor.compression_param_names
             )
 
         # Identify unexpected keys from quantization compression
         if self.quantization_compressor:
             for scheme in self.quantization_config.config_groups.values():
-                quant_targets: Set[str] = expand_target_names(
+                quant_targets = match_named_modules(
                     model=model,
                     targets=scheme.targets,
                     ignore=self.quantization_config.ignore,
                 )
                 unexpected_keys.update(
-                    merge_names(target, param)
-                    for target in quant_targets
+                    merge_names(target_name, param)
+                    for target_name, _module in quant_targets
                     for param in self.quantization_compressor.compression_param_names
                     if param != "weight"
                 )
@@ -383,58 +386,65 @@ def compress_model(self, model: Module):
         :param model: model containing parameters to compress
         """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Compressing model"):
-
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                module_device = get_execution_device(module)
-                is_meta = module_device.type == "meta"
-
-                exec_device = "meta" if is_meta else "cpu"
-                onloading_device = "meta" if is_meta else module_device
-
-                # in the future, support compression on same device
-                with align_module_device(module, execution_device=exec_device):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                # quantization first
-                if prefix in module_to_scheme:
-                    state_dict = self.quantization_compressor.compress(
-                        state_dict,
-                        names_to_scheme=module_to_scheme,
-                        show_progress=False,
-                        compression_device=exec_device,
-                    )
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Compressing model",
+        ):
+            module_device = get_execution_device(module)
+            is_meta = module_device.type == "meta"
+
+            exec_device = "meta" if is_meta else "cpu"
+            onloading_device = "meta" if is_meta else module_device
+
+            # in the future, support compression on same device
+            with align_module_device(module, execution_device=exec_device):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+            # quantization first
+            if prefix in module_to_scheme:
+                state_dict = self.quantization_compressor.compress(
+                    state_dict,
+                    names_to_scheme=module_to_scheme,
+                    show_progress=False,
+                    compression_device=exec_device,
+                )
 
-                # sparsity second
-                if prefix in sparse_compression_targets:
-                    state_dict = self.sparsity_compressor.compress(
-                        state_dict,
-                        compression_targets=sparse_compression_targets,
-                        show_progress=False,
-                    )
+            # sparsity second
+            if prefix in sparse_compression_targets:
+                state_dict = self.sparsity_compressor.compress(
+                    state_dict,
+                    compression_targets=sparse_compression_targets,
+                    show_progress=False,
+                )
 
-                # remove any existing parameters
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # remove any existing parameters
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)
 
-                # replace with compressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(onloading_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # replace with compressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(onloading_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)
 
-                module.quantization_status = QuantizationStatus.COMPRESSED
+            module.quantization_status = QuantizationStatus.COMPRESSED
 
         # TODO: consider sparse compression to also be compression
         if (
@@ -451,55 +461,64 @@ def decompress_model(self, model: Module):
 
         :param model: model containing parameters to compress
        """
         module_to_scheme = map_module_to_scheme(model)
-        sparse_compression_targets: Set[str] = expand_target_names(
-            model=model,
-            targets=self.sparsity_config.targets if self.sparsity_config else [],
-            ignore=self.sparsity_config.ignore if self.sparsity_config else [],
-        )
-
-        for prefix, module in tqdm(model.named_modules(), desc="Decompressing model"):
-            if prefix in module_to_scheme or prefix in sparse_compression_targets:
-                # in the future, support decompression on same device
-                with align_module_device(module, execution_device="cpu"):
-                    state_dict = {
-                        f"{prefix}.{name}": param
-                        for name, param in module.named_parameters(recurse=False)
-                    }
-
-                # sparsity first
-                if prefix in sparse_compression_targets:
-                    # sparse_compression_targets are automatically inferred by this fn
-                    generator = self.sparsity_compressor.decompress_from_state_dict(
+        sparse_compression_targets = [
+            module_name
+            for module_name, _module in match_named_modules(
+                model=model,
+                targets=self.sparsity_config.targets if self.sparsity_config else [],
+                ignore=self.sparsity_config.ignore if self.sparsity_config else [],
+            )
+        ]
+
+        for prefix, module in tqdm(
+            match_named_modules(
+                model,
+                [*sparse_compression_targets, *module_to_scheme.keys()],
+                warn_on_fail=True,
+            ),
+            desc="Decompressing model",
+        ):
+            # in the future, support decompression on same device
+            with align_module_device(module, execution_device="cpu"):
+                state_dict = {
+                    f"{prefix}.{name}": param
+                    for name, param in module.named_parameters(recurse=False)
+                }
+
+            # sparsity first
+            if prefix in sparse_compression_targets:
+                # sparse_compression_targets are automatically inferred by this fn
+                generator = self.sparsity_compressor.decompress_from_state_dict(
+                    state_dict,
+                )
+                # generates (param_path, param_val)
+                # of compressed and unused params
+                state_dict = {key: value for key, value in generator}
+
+            # quantization second
+            if prefix in module_to_scheme:
+                state_dict = (
+                    self.quantization_compressor.decompress_module_from_state_dict(
+                        prefix,
                         state_dict,
+                        scheme=module_to_scheme[prefix],
                     )
-                    # generates (param_path, param_val)
-                    # of compressed and unused params
-                    state_dict = {key: value for key, value in generator}
-
-                # quantization second
-                if prefix in module_to_scheme:
-                    state_dict = (
-                        self.quantization_compressor.decompress_module_from_state_dict(
-                            prefix,
-                            state_dict,
-                            scheme=module_to_scheme[prefix],
-                        )
-                    )
+                )
 
-                # remove any existing parameters
-                exec_device = get_execution_device(module)
-                offload_device = get_offloaded_device(module)
-                for name, _ in list(module.named_parameters(recurse=False)):
-                    delete_offload_parameter(module, name)
+            # remove any existing parameters
+            exec_device = get_execution_device(module)
+            offload_device = get_offloaded_device(module)
+            for name, _ in list(module.named_parameters(recurse=False)):
+                delete_offload_parameter(module, name)
 
-                # replace with decompressed parameters
-                for name, value in state_dict.items():
-                    name = name.removeprefix(f"{prefix}.")
-                    value = value.to(exec_device)
-                    param = torch.nn.Parameter(value, requires_grad=False)
-                    register_offload_parameter(module, name, param, offload_device)
+            # replace with decompressed parameters
+            for name, value in state_dict.items():
+                name = name.removeprefix(f"{prefix}.")
+                value = value.to(exec_device)
+                param = torch.nn.Parameter(value, requires_grad=False)
+                register_offload_parameter(module, name, param, offload_device)
 
-                module.quantization_status = QuantizationStatus.FROZEN
+            module.quantization_status = QuantizationStatus.FROZEN
 
     # ----- state dict compression pathways ----- #
@@ -535,11 +554,14 @@ def compress(
             )
 
         if self.sparsity_compressor is not None:
-            sparse_compression_targets: Set[str] = expand_target_names(
-                model=model,
-                targets=self.sparsity_config.targets,
-                ignore=self.sparsity_config.ignore,
-            )
+            sparse_compression_targets: Set[str] = {
+                module_name
+                for module_name, _module in match_named_modules(
+                    model=model,
+                    targets=self.sparsity_config.targets,
+                    ignore=self.sparsity_config.ignore,
+                )
+            }
             state_dict = self.sparsity_compressor.compress(
                 state_dict,
                 compression_targets=sparse_compression_targets,
@@ -598,7 +620,6 @@ def decompress(self, model_path: str, model: Module):
         with override_quantization_status(
             self.quantization_config, QuantizationStatus.FROZEN
         ):
-
             names_to_scheme = apply_quantization_config(
                 model, self.quantization_config
             )
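Note: expand_target_names returned a Set[str] of module names, while match_named_modules yields (name, module) pairs, so call sites that still need a plain name set wrap the generator in a comprehension, as compress() does above. A minimal migration sketch in Python, assuming `model` is a loaded torch.nn.Module; the target and ignore patterns below are illustrative only, not taken from this diff:

    from compressed_tensors.utils.match import match_named_modules

    # before: names = expand_target_names(model, targets, ignore)
    names = {
        name
        for name, _module in match_named_modules(
            model, targets=["re:.*q_proj"], ignore=["lm_head"]
        )
    }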
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 7afd2aba..10483d11 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -13,12 +13,11 @@
 # limitations under the License.
 
 import logging
-import re
-from collections import OrderedDict, defaultdict
+from collections import OrderedDict
 from copy import deepcopy
 from typing import Dict, Iterable, List, Optional
 from typing import OrderedDict as OrderedDictType
-from typing import Set, Union
+from typing import Union
 
 import torch
 from compressed_tensors.config import CompressionFormat
@@ -39,7 +38,8 @@
     infer_quantization_status,
     is_kv_cache_quant_scheme,
 )
-from compressed_tensors.utils.helpers import fix_fsdp_module_name, replace_module
+from compressed_tensors.utils.helpers import deprecated, replace_module
+from compressed_tensors.utils.match import match_named_modules, match_targets
 from compressed_tensors.utils.offload import update_parameter_data
 from compressed_tensors.utils.safetensors_load import get_safetensors_folder
 from safetensors import safe_open
@@ -51,8 +51,6 @@
     "apply_quantization_config",
     "apply_quantization_status",
     "find_name_or_class_matches",
-    "expand_target_names",
-    "is_target",
 ]
 
 from compressed_tensors.quantization.utils.helpers import is_module_quantized
@@ -147,47 +145,30 @@ def apply_quantization_config(
     if run_compressed:
         from compressed_tensors.linear.compressed_linear import CompressedLinear
 
-    # list of submodules to ignore
-    ignored_submodules = defaultdict(list)
     # mark appropriate layers for quantization by setting their quantization schemes
-    for name, submodule in model.named_modules():
-        # potentially fix module name to remove FSDP wrapper prefix
-        name = fix_fsdp_module_name(name)
-        if matches := find_name_or_class_matches(name, submodule, config.ignore):
-            for match in matches:
-                ignored_submodules[match].append(name)
-            continue  # layer matches ignore list, continue
-
-        targets = find_name_or_class_matches(name, submodule, target_to_scheme)
-
-        if targets:
-            # mark modules to be quantized by adding
-            # quant scheme to the matching layers
-            scheme = _scheme_from_targets(target_to_scheme, targets, name)
-            if run_compressed:
-                format = config.format
-                if format != CompressionFormat.dense.value:
-                    if isinstance(submodule, torch.nn.Linear):
-                        # TODO: expand to more module types
-                        compressed_linear = CompressedLinear.from_linear(
-                            submodule,
-                            quantization_scheme=scheme,
-                            quantization_format=format,
-                        )
-                        replace_module(model, name, compressed_linear)
-
-            # target matched - add layer and scheme to target list
-            submodule.quantization_scheme = scheme
-
-            names_to_scheme[name] = submodule.quantization_scheme
-
-    if config.ignore is not None and ignored_submodules is not None:
-        if set(config.ignore) - set(ignored_submodules):
-            _LOGGER.warning(
-                "Some layers that were to be ignored were "
-                "not found in the model: "
-                f"{set(config.ignore) - set(ignored_submodules)}"
-            )
+    for name, submodule in match_named_modules(
+        model, target_to_scheme, config.ignore or [], warn_on_fail=True
+    ):
+        # mark modules to be quantized by adding
+        # quant scheme to the matching layers
+        matched_targets = match_targets(name, submodule, target_to_scheme)
+        scheme = _scheme_from_targets(target_to_scheme, matched_targets, name)
+        if run_compressed:
+            format = config.format
+            if format != CompressionFormat.dense.value:
+                if isinstance(submodule, torch.nn.Linear):
+                    # TODO: expand to more module types
+                    compressed_linear = CompressedLinear.from_linear(
+                        submodule,
+                        quantization_scheme=scheme,
+                        quantization_format=format,
+                    )
+                    replace_module(model, name, compressed_linear)
+
+        # target matched - add layer and scheme to target list
+        submodule.quantization_scheme = scheme
+
+        names_to_scheme[name] = submodule.quantization_scheme
 
     # apply current quantization status across all targeted layers
     apply_quantization_status(model, config.quantization_status)
@@ -262,54 +243,10 @@ def apply_quantization_status(model: Module, status: QuantizationStatus):
         model.apply(compress_quantized_weights)
 
 
-def expand_target_names(
-    model: Module,
-    targets: Optional[Iterable[str]] = None,
-    ignore: Optional[Iterable[str]] = None,
-) -> Set[str]:
-    """
-    Finds all unique module names in the model that match the given
-    targets and ignore lists.
-
-    Note: Targets must be regexes, layer types, or full layer names.
-
-    :param model: model to search for targets in
-    :param targets: Iterable of targets to search for
-    :param ignore: Iterable of targets to ignore
-    :return: set of all targets that match the given targets and should
-        not be ignored
-    """
-    return {
-        name
-        for name, module in model.named_modules()
-        if is_target(name, module, targets, ignore)
-    }
-
-
-def is_target(
-    name: str,
-    module: Module,
-    targets: Optional[Iterable[str]] = None,
-    ignore: Optional[Iterable[str]] = None,
-) -> bool:
-    """
-    Determines if a module should be included in the targets based on the
-    targets and ignore lists.
-
-    Note: Targets must be regexes, layer types, or full layer names.
-
-    :param name: name of the module
-    :param module: the module itself
-    :param targets: Iterable of targets to search for
-    :param ignore: Iterable of targets to ignore
-    :return: True if the module is a target and not ignored, False otherwise
-    """
-    return bool(
-        find_name_or_class_matches(name, module, targets or [])
-        and not find_name_or_class_matches(name, module, ignore or [])
-    )
-
-
+@deprecated(
+    message="This function is deprecated and will be removed in a future release. "
+    "Please use `match_targets` from `compressed_tensors.utils.match` instead."
+)
 def find_name_or_class_matches(
     name: str, module: Module, targets: Iterable[str], check_contains: bool = False
 ) -> List[str]:
@@ -322,38 +259,12 @@ def find_name_or_class_matches(
     2. matches on regex patterns
     3. matches on module names
     """
-    from compressed_tensors import InternalModule
-
-    if isinstance(module, InternalModule):
-        return []
-
-    targets = sorted(targets, key=lambda x: ("re:" in x, x))
-    if isinstance(targets, Iterable):
-        matches = _find_matches(name, targets) + _find_matches(
-            module.__class__.__name__, targets, check_contains
+    if check_contains:
+        raise NotImplementedError(
+            "This function is deprecated, and the check_contains=True option has been removed."
         )
-        matches = [match for match in matches if match is not None]
-        return matches
-
-
-def _find_matches(
-    value: str, targets: Iterable[str], check_contains: bool = False
-) -> List[str]:
-    # returns all the targets that match value either
-    # exactly or as a regex after 're:'. if check_contains is set to True,
-    # additionally checks if the target string is contained with value.
-    matches = []
-    for target in targets:
-        if target.startswith("re:"):
-            pattern = target[3:]
-            if re.match(pattern, value):
-                matches.append(target)
-        elif check_contains:
-            if target.lower() in value.lower():
-                matches.append(target)
-        elif target == value:
-            matches.append(target)
-    return matches
+
+    return match_targets(name, module, targets)
 
@@ -429,7 +340,6 @@ def _scheme_from_targets(
 def _merge_schemes(
     schemes_to_merge: List[QuantizationScheme], name: str
 ) -> QuantizationScheme:
-
     kv_cache_quantization_scheme = [
        scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme)
     ]
diff --git a/src/compressed_tensors/utils/match.py b/src/compressed_tensors/utils/match.py
index 30ead256..fa9b00ea 100644
--- a/src/compressed_tensors/utils/match.py
+++ b/src/compressed_tensors/utils/match.py
@@ -15,7 +15,7 @@
 import logging
 import re
 from collections.abc import Generator
-from typing import Iterable, Mapping, Optional, Tuple
+from typing import Iterable, List, Mapping, Optional, Tuple
 
 import torch
 from compressed_tensors.utils.internal import InternalModule
@@ -27,6 +27,7 @@
 __all__ = [
     "match_named_modules",
     "match_named_parameters",
+    "match_targets",
     "match_modules_set",
     "is_match",
 ]
@@ -37,8 +38,8 @@
 
 def match_named_modules(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] = tuple(),
+    targets: Iterable[str] | None,
+    ignore: Iterable[str] | None = None,
     fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module]]:
@@ -54,14 +55,18 @@
     :param warn_on_fail: if True, warns if any targets do not match any modules in model
     :return: generator of module names and modules
     """
+    targets = targets or []
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
+
     for name, module in model.named_modules():
         for target in targets:
             if is_match(name, module, target, fused):
                 unmatched_targets -= {target}
-
                 if not any(is_match(name, module, ign, fused) for ign in ignore):
                     yield name, module
+                break
 
     if warn_on_fail:
         for target in unmatched_targets:
@@ -72,8 +77,8 @@
 
 def match_named_parameters(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] = tuple(),
+    targets: Iterable[str] | None = None,
+    ignore: Iterable[str] | None = None,
     fused: Optional[FusedMappping] = None,
     warn_on_fail: bool = False,
 ) -> Generator[Tuple[str, torch.nn.Module, torch.nn.Parameter]]:
@@ -89,6 +94,9 @@
     :param warn_on_fail: if True, warns if any targets do not match any params in model
     :return: generator of fully-qualified param names, parent modules, and params
     """
+    targets = targets or []
+    ignore = ignore or []
+
     unmatched_targets = set(targets)
     for module_name, module in model.named_modules():
         if isinstance(module, InternalModule):
@@ -110,10 +118,48 @@
             )
 
+
+def match_targets(
+    name: str, module: torch.nn.Module, targets: Iterable[str] | None = None
+) -> List[str]:
+    """
+    Returns the targets that match the given name and module.
+
+    :param name: the name of the module
+    :param module: the module to match
+    :param targets: the target strings, potentially containing "re:" prefixes
+    :return: the targets that match the given name and module
+
+    Outputs are ordered by type: exact name match, regex name match, class name match
+    """
+    targets = targets or []
+
+    if isinstance(module, InternalModule):
+        return []
+
+    # The order of the output `matches` list matters, they are arranged from most
+    # specific to least specific, and this order will be used when merging configs.
+    # The entries are sorted in the following order:
+    # 1. matches on exact strings
+    # 2. matches on regex patterns
+    # 3. matches on module names
+
+    targets = sorted(targets, key=lambda x: ("re:" in x, x))
+    matched_targets = []
+    for target in targets:
+        if _match_name(name, target):
+            matched_targets.append(target)
+
+    for target in targets:
+        if _match_class(module, target) and target not in matched_targets:
+            matched_targets.append(target)
+
+    return matched_targets
+
 
 def match_modules_set(
     model: torch.nn.Module,
-    targets: Iterable[str],
-    ignore: Iterable[str] = tuple(),
+    targets: Iterable[str] | None = None,
+    ignore: Iterable[str] | None = None,
 ) -> Generator[Iterable[torch.nn.Module]]:
@@ -151,6 +197,9 @@
     :param targets: target strings, potentially containing "re:" prefixes
     :param ignore: targets to ignore, potentially containing "re:" prefixes
     """
+    targets = targets or []
+    ignore = ignore or []
+
     matches = dict.fromkeys(targets, None)
     for name, module in model.named_modules():
         # match until we get a full set
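The ordering that match_targets documents above (exact name matches, then regex name matches, then class matches) is what _scheme_from_targets relies on when several targets hit the same layer. A small sketch of that ordering, assuming the layer is a torch.nn.Linear registered under the illustrative name "model.decoder.fc1"; the target strings are made up for illustration:

    import torch.nn as nn

    from compressed_tensors.utils.match import match_targets

    layer = nn.Linear(8, 8)
    matched = match_targets(
        "model.decoder.fc1", layer, ["Linear", "re:.*fc1", "model.decoder.fc1"]
    )
    # ["model.decoder.fc1", "re:.*fc1", "Linear"]: exact name, then regex, then class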
diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py
index 63a9a588..09e12cc5 100644
--- a/tests/test_quantization/lifecycle/test_apply.py
+++ b/tests/test_quantization/lifecycle/test_apply.py
@@ -28,8 +28,6 @@
 from compressed_tensors.quantization.lifecycle import (
     apply_quantization_config,
     apply_quantization_status,
-    expand_target_names,
-    is_target,
 )
 from tests.testing_utils import requires_accelerate
 from transformers import AutoModelForCausalLM
@@ -260,15 +258,13 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"):
 
 @requires_accelerate()
 @pytest.mark.parametrize(
-    "ignore,should_raise_warning",
+    "ignore",
     [
-        [("lm_head", "re:.*gate"), False],
-        [("lm_head", "re:.*foobarbaz"), True],
+        ("lm_head", "re:.*gate"),
+        ("lm_head", "re:.*foobarbaz"),
     ],
 )
-def test_apply_quantization_status(caplog, ignore, should_raise_warning):
-    import logging
-
+def test_apply_quantization_status(ignore):
     # load a dense, unquantized tiny llama model
     model = get_tinyllama_model()
     quantization_config_dict = {
@@ -292,80 +288,4 @@
     config = QuantizationConfig(**quantization_config_dict)
     config.quantization_status = QuantizationStatus.CALIBRATION
 
-    # mismatch in the ignore key of quantization_config_dict
-    with caplog.at_level(logging.WARNING):
-        apply_quantization_config(model, config)
-        if should_raise_warning:
-            assert len(caplog.text) > 0
-        else:
-            assert len(caplog.text) == 0
-
-
-@pytest.mark.parametrize(
-    "targets, ignore, expected_targets",
-    [
-        ([], [], set()),
-        (["layer1", "layer2"], [], {"layer1", "layer2"}),
-        ([], ["layer1"], set()),
-        (["layer1", "layer2"], ["layer2"], {"layer1"}),
-        (["re:layer.*"], ["layer3"], {"layer1", "layer2"}),
-    ],
-)
-def test_expand_targets_with_mock(mock_model, targets, ignore, expected_targets):
-    expanded_targets = expand_target_names(mock_model, targets, ignore)
-    assert expanded_targets == expected_targets
-
-
-@pytest.mark.parametrize(
-    "targets, ignore, expected_targets",
-    [
-        (
-            ["re:model.layers.[01].self_attn.q_proj"],
-            ["re:model.layers.1.self_attn.q_proj"],
-            set(["model.layers.0.self_attn.q_proj"]),
-        ),
-        (
-            ["re:model.layers.[01].self_attn.q_proj"],
-            [],
-            set(["model.layers.0.self_attn.q_proj", "model.layers.1.self_attn.q_proj"]),
-        ),
-        (
-            ["re:model.layers.[0-2].self_attn.q_proj"],
-            ["re:model.layers.1.self_attn.q_proj"],
-            set(["model.layers.0.self_attn.q_proj", "model.layers.2.self_attn.q_proj"]),
-        ),
-        (
-            ["model.layers.0.self_attn.q_proj"],
-            ["model.layers.0.self_attn.q_proj"],
-            set(),
-        ),
-        (
-            ["re:model.layers.*.self_attn.q_proj"],
-            ["re:model.layers.[01].self_attn.q_proj"],
-            set(
-                f"model.layers.{layer_idx}.self_attn.q_proj"
-                for layer_idx in range(2, 6)
-            ),
-        ),
-    ],
-)
-def test_expand_targets_with_llama_stories(
-    llama_stories_model, targets, ignore, expected_targets
-):
-    expanded_targets = expand_target_names(llama_stories_model, targets, ignore)
-    assert expanded_targets == expected_targets
-
-
-@pytest.mark.parametrize(
-    "name, targets, ignore, expected",
-    [
-        ("layer1", ["layer1"], [], True),
-        ("layer1", ["layer1"], ["layer1"], False),
-        ("layer1", ["layer2"], [], False),
-        ("layer1", ["re:layer.*"], [], True),
-        ("layer1", ["re:layer.*"], ["re:layer1"], False),
-    ],
-)
-def test_is_target_with_mock(mock_module, name, targets, ignore, expected):
-    result = is_target(name, mock_module, targets, ignore)
-    assert result == expected
+    apply_quantization_config(model, config)
diff --git a/tests/test_utils/test_match.py b/tests/test_utils/test_match.py
index 7858c7c8..09ad7401 100644
--- a/tests/test_utils/test_match.py
+++ b/tests/test_utils/test_match.py
@@ -26,6 +26,15 @@
     match_named_parameters,
 )
 from compressed_tensors.utils.match import _match_class, _match_name
+from transformers import AutoModelForCausalLM
+
+
+@pytest.fixture
+def llama_stories_model():
+    return AutoModelForCausalLM.from_pretrained(
+        "Xenova/llama2.c-stories15M",
+        torch_dtype="auto",
+    )
 
 
 class DummyModel(nn.Module):
@@ -285,6 +294,58 @@ class InternalLinear(InternalModule, nn.Linear):
         matches = list(match_named_modules(linear, ["re:.*"]))
         assert len(matches) == 0
 
+    @pytest.mark.parametrize(
+        "targets, ignore, expected_targets",
+        [
+            (
+                ["re:model.layers.[01].self_attn.q_proj"],
+                ["re:model.layers.1.self_attn.q_proj"],
+                set(["model.layers.0.self_attn.q_proj"]),
+            ),
+            (
+                ["re:model.layers.[01].self_attn.q_proj"],
+                [],
+                set(
+                    [
+                        "model.layers.0.self_attn.q_proj",
+                        "model.layers.1.self_attn.q_proj",
+                    ]
+                ),
+            ),
+            (
+                ["re:model.layers.[0-2].self_attn.q_proj"],
+                ["re:model.layers.1.self_attn.q_proj"],
+                set(
+                    [
+                        "model.layers.0.self_attn.q_proj",
+                        "model.layers.2.self_attn.q_proj",
+                    ]
+                ),
+            ),
+            (
+                ["model.layers.0.self_attn.q_proj"],
+                ["model.layers.0.self_attn.q_proj"],
+                set(),
+            ),
+            (
+                ["re:model.layers.*.self_attn.q_proj"],
+                ["re:model.layers.[01].self_attn.q_proj"],
+                set(
+                    f"model.layers.{layer_idx}.self_attn.q_proj"
+                    for layer_idx in range(2, 6)
+                ),
+            ),
+        ],
+    )
+    def test_expand_targets_with_llama_stories(
+        self, llama_stories_model, targets, ignore, expected_targets
+    ):
+        expanded_targets = {
+            name
+            for name, _ in match_named_modules(llama_stories_model, targets, ignore)
+        }
+        assert expanded_targets == expected_targets
+
 
 class TestMatchNamedParameters:
     """Test cases for match_named_parameters function"""