diff --git a/src/compressed_tensors/compressors/model_compressors/model_compressor.py b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
index 2d88a3b02..4b92147c3 100644
--- a/src/compressed_tensors/compressors/model_compressors/model_compressor.py
+++ b/src/compressed_tensors/compressors/model_compressors/model_compressor.py
@@ -264,23 +264,18 @@ def parse_quantization_config(
         return quantization_config
 
-    def _fetch_unique_quantization_formats(self) -> List[str]:
+    def _fetch_unique_quantization_formats(self) -> List[Optional[str]]:
         """
         Get all unique compression formats present in a model.
 
         :return: list of quantization formats
         """
-        quantization_formats = []
-        for _, scheme in self.quantization_config.config_groups.items():
-            if scheme.format is not None and scheme.format not in quantization_formats:
-                quantization_formats.append(scheme.format)
+        quantization_formats = set(
+            scheme.format for scheme in self.quantization_config.config_groups.values()
+        )
+        quantization_formats.add(self.quantization_config.format)
 
-        if (
-            len(quantization_formats) == 0
-            and self.quantization_config.format
-            != CompressionFormat.mixed_precision.value
-        ):
-            quantization_formats.append(self.quantization_config.format)
-        return quantization_formats
+        quantization_formats -= {CompressionFormat.mixed_precision.value, None}
+        return list(quantization_formats)
 
     def __init__(
         self,
@@ -314,6 +309,9 @@ def __init__(
         self.quantization_compressor = {}
 
         for format in self.compression_formats:
+            if format is None:
+                format = CompressionFormat.dense.value
+
             self.quantization_compressor[
                 format
             ] = BaseCompressor.load_from_registry(
@@ -703,9 +701,12 @@ def decompress(self, model_path: str, model: Module):
             with override_quantization_status(
                 self.quantization_config, QuantizationStatus.FROZEN
             ):
-                names_to_scheme = apply_quantization_config(
-                    model, self.quantization_config
-                )
+                apply_quantization_config(model, self.quantization_config)
+                names_to_scheme: Dict[str, QuantizationScheme] = {
+                    name: getattr(module, "quantization_scheme")
+                    for name, module in model.named_modules()
+                    if getattr(module, "quantization_scheme", None) is not None
+                }
             # Load activation scales/zp or any other quantization parameters
             # Conditionally load the weight quantization parameters if we have a
             # dense compressor or if a sparsity compressor has already been applied
diff --git a/src/compressed_tensors/quantization/lifecycle/apply.py b/src/compressed_tensors/quantization/lifecycle/apply.py
index 828f51ec8..a4fe0544b 100644
--- a/src/compressed_tensors/quantization/lifecycle/apply.py
+++ b/src/compressed_tensors/quantization/lifecycle/apply.py
@@ -13,11 +13,8 @@
 # limitations under the License.
import logging -from collections import OrderedDict from copy import deepcopy -from typing import Dict, Iterable, List, Optional -from typing import OrderedDict as OrderedDictType -from typing import Union +from typing import Dict, Iterable, List, Optional, Union import torch from compressed_tensors.config import CompressionFormat @@ -36,7 +33,6 @@ from compressed_tensors.quantization.utils import ( KV_CACHE_TARGETS, infer_quantization_status, - is_kv_cache_quant_scheme, ) from compressed_tensors.utils.helpers import deprecated, replace_module from compressed_tensors.utils.match import match_named_modules, match_targets @@ -44,12 +40,15 @@ from compressed_tensors.utils.safetensors_load import get_safetensors_folder from safetensors import safe_open from torch.nn import Module +from transformers import PreTrainedModel __all__ = [ "load_pretrained_quantization_parameters", "apply_quantization_config", "apply_quantization_status", + "attach_scheme", + "attach_config", "find_name_or_class_matches", ] @@ -114,8 +113,10 @@ def load_pretrained_quantization_parameters( def apply_quantization_config( - model: Module, config: Union[QuantizationConfig, None], run_compressed: bool = False -) -> Dict[str, QuantizationScheme]: + model: PreTrainedModel, + config: Union[QuantizationConfig, None], + run_compressed: bool = False, +): """ Initializes the model for quantization in-place based on the given config. Optionally coverts quantizable modules to compressed_linear modules @@ -125,54 +126,54 @@ def apply_quantization_config( :param run_compressed: Whether the model will be run in compressed mode or decompressed fully on load """ - # Workaround for when HF Quantizer passes None, see PR #180 - if config is None: - return dict() + from compressed_tensors.linear.compressed_linear import CompressedLinear - # remove reference to the original `config` - # argument. This function can mutate it, and we'd - # like to keep the original `config` as it is. 
config = deepcopy(config) - # build mapping of targets to schemes for easier matching - # use ordered dict to preserve target ordering in config - target_to_scheme = OrderedDict() - config = process_quantization_config(config) - names_to_scheme = dict() - for scheme in config.config_groups.values(): - for target in scheme.targets: - target_to_scheme[target] = scheme + if config is None: # see PR #180 + return dict() - if run_compressed: - from compressed_tensors.linear.compressed_linear import CompressedLinear + # preprocess to support kv cache scheme + config = process_quantization_config(config) # mark appropriate layers for quantization by setting their quantization schemes - for name, submodule in match_named_modules( - model, target_to_scheme, config.ignore, warn_on_fail=True - ): - # mark modules to be quantized by adding - # quant scheme to the matching layers - matched_targets = match_targets(name, submodule, target_to_scheme) - scheme = _scheme_from_targets(target_to_scheme, matched_targets, name) - if run_compressed: - format = config.format - if format != CompressionFormat.dense.value: - if isinstance(submodule, torch.nn.Linear): - # TODO: expand to more module types - compressed_linear = CompressedLinear.from_linear( - submodule, - quantization_scheme=scheme, - quantization_format=format, - ) - replace_module(model, name, compressed_linear) - - # target matched - add layer and scheme to target list - submodule.quantization_scheme = scheme - - names_to_scheme[name] = submodule.quantization_scheme + for scheme in config.config_groups.values(): + for name, submodule in match_named_modules( + model, scheme.targets, config.ignore or [], warn_on_fail=True + ): + # attach scheme to module (with merging) + attach_scheme(submodule, scheme) + + # replace with run compressed if applicable + # FUTURE: move this to model compressor + if isinstance(submodule, torch.nn.Linear) and run_compressed: + format = config.format + if format != CompressionFormat.dense.value: + if isinstance(submodule, torch.nn.Linear): + # TODO: expand to more module types + compressed_linear = CompressedLinear.from_linear( + submodule, + quantization_scheme=scheme, + quantization_format=format, + ) + replace_module(model, name, compressed_linear) # apply current quantization status across all targeted layers apply_quantization_status(model, config.quantization_status) - return names_to_scheme + + # attach config for serialization + attach_config(model, config) + + +def attach_scheme(module: Module, scheme: QuantizationScheme): + if existing_scheme := getattr(module, "quantization_scheme", None): + scheme = scheme.merge(existing_scheme) + setattr(module, "quantization_scheme", scheme) + + +def attach_config(model: PreTrainedModel, config: QuantizationConfig): + if existing_config := getattr(model, "quantization_config", None): + config = config.merge(existing_config) + setattr(model, "quantization_config", config) def process_quantization_config(config: QuantizationConfig) -> QuantizationConfig: @@ -268,14 +269,6 @@ def find_name_or_class_matches( return match_targets(name, module, targets) -def _infer_status(model: Module) -> Optional[QuantizationStatus]: - for module in model.modules(): - status = getattr(module, "quantization_status", None) - if status is not None: - return status - return None - - def _load_quant_args_from_mapping( base_name: str, module_name: str, module: Module, mapping: Dict ): @@ -318,67 +311,3 @@ def _load_quant_args_from_mapping( state_dict_zp = f.get_tensor(f"{module_name}.{zp_name}") 
update_parameter_data(module, state_dict_zp, zp_name) - - -def _scheme_from_targets( - target_to_scheme: OrderedDictType[str, QuantizationScheme], - targets: List[str], - name: str, -) -> QuantizationScheme: - if len(targets) == 1: - # if `targets` iterable contains a single element - # use it as the key - return target_to_scheme[targets[0]] - - # otherwise, we need to merge QuantizationSchemes corresponding - # to multiple targets. This is most likely because `name` module - # is being target both as an ordinary quantization target, as well - # as kv cache quantization target - schemes_to_merge = [target_to_scheme[target] for target in targets] - return _merge_schemes(schemes_to_merge, name) - - -def _merge_schemes( - schemes_to_merge: List[QuantizationScheme], name: str -) -> QuantizationScheme: - kv_cache_quantization_scheme = [ - scheme for scheme in schemes_to_merge if is_kv_cache_quant_scheme(scheme) - ] - if not kv_cache_quantization_scheme: - # if the schemes_to_merge do not contain any - # kv cache QuantizationScheme - # return the first scheme (the prioritized one, - # since the order of schemes_to_merge matters) - return schemes_to_merge[0] - else: - # fetch the kv cache QuantizationScheme and the highest - # priority non-kv cache QuantizationScheme and merge them - kv_cache_quantization_scheme = kv_cache_quantization_scheme[0] - quantization_scheme = [ - scheme - for scheme in schemes_to_merge - if not is_kv_cache_quant_scheme(scheme) - ][0] - schemes_to_merge = [kv_cache_quantization_scheme, quantization_scheme] - merged_scheme = {} - for scheme in schemes_to_merge: - scheme_dict = { - k: v for k, v in scheme.model_dump().items() if v is not None - } - # when merging multiple schemes, the final target will be - # the `name` argument - hence erase the original targets - del scheme_dict["targets"] - # make sure that schemes do not "clash" with each other - overlapping_keys = set(merged_scheme.keys()) & set(scheme_dict.keys()) - if overlapping_keys: - raise ValueError( - f"The module: {name} is being modified by two clashing " - f"quantization schemes, that jointly try to override " - f"properties: {overlapping_keys}. Fix the quantization config " - "so that it is not ambiguous." 
- ) - merged_scheme.update(scheme_dict) - - merged_scheme.update(targets=[name]) - - return QuantizationScheme(**merged_scheme) diff --git a/src/compressed_tensors/quantization/quant_config.py b/src/compressed_tensors/quantization/quant_config.py index 42df3a337..93a2d19c1 100644 --- a/src/compressed_tensors/quantization/quant_config.py +++ b/src/compressed_tensors/quantization/quant_config.py @@ -21,12 +21,7 @@ QuantizationScheme, preset_name_to_scheme, ) -from compressed_tensors.quantization.utils import ( - is_module_quantized, - module_type, - parse_out_kv_cache_args, -) -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from torch.nn import Module @@ -35,7 +30,6 @@ "QuantizationConfig", "LIFECYCLE_ORDER", "DEFAULT_QUANTIZATION_METHOD", - "DEFAULT_QUANTIZATION_FORMAT", ] @@ -102,7 +96,6 @@ def __le__(self, other): ] DEFAULT_QUANTIZATION_METHOD = "compressed-tensors" -DEFAULT_QUANTIZATION_FORMAT = "fakequant" class QuantizationConfig(BaseModel): @@ -138,7 +131,7 @@ class QuantizationConfig(BaseModel): config_groups: Dict[str, Union[QuantizationScheme, List[str]]] quant_method: str = DEFAULT_QUANTIZATION_METHOD kv_cache_scheme: Optional[QuantizationArgs] = None - format: str = DEFAULT_QUANTIZATION_FORMAT + format: str = CompressionFormat.dense.value quantization_status: QuantizationStatus = QuantizationStatus.INITIALIZED global_compression_ratio: Optional[float] = None ignore: Optional[List[str]] = Field(default_factory=list) @@ -159,97 +152,75 @@ def model_post_init(self, __context): targets=targets_or_scheme, ) - def to_dict(self): - # for compatibility with HFQuantizer - return self.model_dump() + @field_validator("format", mode="before") + def validate_format(cls, value: Any) -> str: + if value is None: + return CompressionFormat.dense.value - @staticmethod - def from_pretrained( - model: Module, format: Optional[str] = None - ) -> Optional["QuantizationConfig"]: - """ - Converts a model into its associated QuantizationConfig based on the - QuantizationScheme attached to each quantized module + if isinstance(value, list): + if len(value) == 0: + return CompressionFormat.dense.value + + if all(v == value[0] for v in value): + return QuantizationScheme.validate_format(value[0]) - :param model: model to calculate quantization scheme of - :return: filled out QuantizationScheme for the input model - """ - quant_scheme_to_layers = [] - quantization_status = None - ignore = {} - quantization_type_names = set() - for name, submodule in model.named_modules(): - layer_type = module_type(submodule) - if not is_module_quantized(submodule): - if layer_type not in ignore: - ignore[layer_type] = [] - ignore[layer_type].append(name) else: - quantization_status = submodule.quantization_status - scheme = submodule.quantization_scheme - quantization_type_names.add(layer_type) - - match_found = False - for existing_scheme in quant_scheme_to_layers: - if scheme == existing_scheme: - match_found = True - break - if not match_found: - quant_scheme_to_layers.append(scheme) - - if len(quant_scheme_to_layers) == 0: # No quantized layers - return None - - # kv-cache only, no weight/activation quantization - if ( - len(quantization_type_names) == 1 - and "attention" in list(quantization_type_names)[0].lower() - ): - quantization_type_names.add("Linear") - - # clean up ignore list, we can leave out layers types if none of the - # instances are quantized - consolidated_ignore = [] - for layer_type, ignore_names in ignore.items(): - if layer_type 
in quantization_type_names: - # specific layers of a quantized type are ignored - consolidated_ignore += ignore_names - # else we leave it off the ignore list, doesn't fall under any of the - # existing quantization schemes so it won't be quantized - - kv_cache_args, quant_scheme_to_layers = parse_out_kv_cache_args( - quant_scheme_to_layers - ) - kv_cache_scheme = ( - kv_cache_args.model_dump() if kv_cache_args is not None else kv_cache_args - ) + return CompressionFormat.mixed_precision.value - config_groups = {} - for idx, scheme in enumerate(quant_scheme_to_layers): - group_name = "group_" + str(idx) - config_groups[group_name] = scheme + return CompressionFormat(value).value - if format is None: - if quantization_status == QuantizationStatus.COMPRESSED: - format = CompressionFormat.int_quantized.value - else: - format = CompressionFormat.dense.value - elif isinstance(format, list): - format = ( - CompressionFormat.mixed_precision.value - if len(format) > 1 - else format[0] + def merge(self, other: "QuantizationConfig") -> "QuantizationConfig": + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "config_groups": + return value_a + value_b + + if field_name == "format": + return self.validate_format([value_a, value_b]) + + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if value_a == value_b: + return value_a + + if field_name == "ignore": + if set(value_a) == set(value_b): + return value_a + + raise NotImplementedError( + "Cannot merge quantization configs with non-identical ignore lists " + "Please modify your config to resolve this ambiguity." + f"\n{self}\n{other}" + ) + + raise ValueError( + "The following fields have overlapping targets and conflicting values " + f"for `{field_name}`. 
Please modify your config to resolve this " + f"ambiguity.\n{self}\n{other}" ) - return QuantizationConfig( - config_groups=config_groups, - quantization_status=quantization_status, - kv_cache_scheme=kv_cache_scheme, - global_compression_ratio=None, - format=format, - ignore=consolidated_ignore, + dict_a = self.model_dump() + dict_b = other.model_dump() + + assert dict_a.keys() == dict_b.keys() + return self.model_validate( + {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} ) + @classmethod + def from_pretrained( + cls, model: Module, format: Optional[str] = None + ) -> "QuantizationConfig": + default_config = QuantizationConfig(config_groups={}) + config = getattr(model, "quantization_config", default_config) + + # silently override format + config.format = cls.validate_format(format) + return config + def requires_calibration_data(self): if self.kv_cache_scheme is not None: return True @@ -264,5 +235,9 @@ def requires_calibration_data(self): return False + def to_dict(self): + # for compatibility with HFQuantizer + return self.model_dump() + # TODO set `extra="forbid"` when upstream transformers is compatible model_config = ConfigDict(extra="ignore") diff --git a/src/compressed_tensors/quantization/quant_scheme.py b/src/compressed_tensors/quantization/quant_scheme.py index a9c8b45a2..630704f2c 100644 --- a/src/compressed_tensors/quantization/quant_scheme.py +++ b/src/compressed_tensors/quantization/quant_scheme.py @@ -14,7 +14,7 @@ import warnings from copy import deepcopy -from typing import List, Optional +from typing import Any, List, Optional from compressed_tensors.config import CompressionFormat from compressed_tensors.quantization.quant_args import ( @@ -23,7 +23,7 @@ QuantizationStrategy, QuantizationType, ) -from pydantic import BaseModel, ConfigDict, model_validator +from pydantic import BaseModel, ConfigDict, field_validator, model_validator __all__ = [ @@ -50,6 +50,8 @@ class QuantizationScheme(BaseModel): weights: Optional[QuantizationArgs] = None input_activations: Optional[QuantizationArgs] = None output_activations: Optional[QuantizationArgs] = None + + # set exclusively by infer_and_set_per_module_quantization_format format: Optional[str] = None @model_validator(mode="after") @@ -91,6 +93,51 @@ def validate_model_after(model: "QuantizationScheme") -> "QuantizationScheme": return model + @field_validator("format", mode="before") + def validate_format(cls, value: Any) -> str: + if value is None: + return None + + return CompressionFormat(value).value + + def merge(self, other: "QuantizationScheme") -> "QuantizationScheme": + def merge_field(field_name: str, value_a: Any, value_b: Any) -> Any: + if field_name == "targets": + return list(set(value_a + value_b)) + + if field_name == "kv_cache_only": + # nones defer to other value + if value_a is None: + return value_b + if value_b is None: + return value_a + + # kv_cache_only=True overrides + return not ((not value_a) or (not value_b)) + + if value_a is not None and value_b is None: + return value_a + + if value_a is None and value_b is not None: + return value_b + + if value_a == value_b: + return value_a + + raise ValueError( + "The following fields have overlapping targets and conflicting values " + f"for `{field_name}`. 
Please modify your config to resolve this " + f"ambiguity.\n{self}\n{other}" + ) + + dict_a = self.model_dump() + dict_b = other.model_dump() + + assert dict_a.keys() == dict_b.keys() + return self.model_validate( + {key: merge_field(key, dict_a[key], dict_b[key]) for key in dict_a.keys()} + ) + model_config = ConfigDict(extra="forbid") diff --git a/tests/test_quantization/lifecycle/test_apply.py b/tests/test_quantization/lifecycle/test_apply.py index d5fd6c2cd..13e4103c7 100644 --- a/tests/test_quantization/lifecycle/test_apply.py +++ b/tests/test_quantization/lifecycle/test_apply.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import re from collections import defaultdict from typing import Optional from unittest.mock import MagicMock @@ -57,53 +56,6 @@ def llama_stories_model(): ) -def test_target_prioritization(mock_frozen): - # tests that the config_groups are applied in the correct order - # of priority, where exact layer name > regex > module name - config = { - "quant_method": "compressed-tensors", - "format": "fakequant", - "config_groups": { - "group_1": { - "weights": { - "num_bits": 8, - }, - "targets": ["Linear"], - }, - "group_2": { - "weights": { - "num_bits": 4, - }, - "targets": ["re:.*down_proj"], - }, - "group_3": { - "weights": { - "num_bits": 2, - }, - "targets": ["model.layers.0.mlp.down_proj"], - }, - }, - } - - model = AutoModelForCausalLM.from_pretrained( - "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto" - ) - model.eval() - - config = QuantizationConfig(**config) - config.quantization_status = QuantizationStatus.CALIBRATION - apply_quantization_config(model, config) - mock_frozen(model) - - for name, module in model.named_modules(): - if name == "model.layers.0.mlp.down_proj": - assert module.quantization_scheme.weights.num_bits == 2 - elif re.match(".*down_proj", name): - assert module.quantization_scheme.weights.num_bits == 4 - elif isinstance(module, torch.nn.Linear): - assert module.quantization_scheme.weights.num_bits == 8 - - def test_apply_quantization_config_tinyllama(): quant_config = get_sample_tinyllama_quant_config(status="calibration") model = get_tinyllama_model() @@ -174,13 +126,16 @@ def test_serialize_config_tinyllama(): serialized_config = QuantizationConfig.from_pretrained(model) assert len(serialized_config.config_groups) == 2 - assert serialized_config.config_groups["group_0"].targets == ["Embedding"] - assert serialized_config.config_groups["group_0"].input_activations is None - assert serialized_config.config_groups["group_1"].targets == ["Linear"] - assert serialized_config.config_groups["group_1"].input_activations is not None + assert serialized_config.config_groups["group_0"].targets == ["Linear"] + assert serialized_config.config_groups["group_0"].input_activations is not None + assert serialized_config.config_groups["group_1"].targets == ["Embedding"] + assert serialized_config.config_groups["group_1"].input_activations is None assert serialized_config.format == CompressionFormat.dense.value assert serialized_config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert serialized_config.ignore == ["model.layers.1.mlp.down_proj"] + assert serialized_config.ignore == [ + "LlamaRotaryEmbedding", + "model.layers.1.mlp.down_proj", + ] if serialized_config.global_compression_ratio is not None: assert serialized_config.global_compression_ratio > 1.0 assert serialized_config.global_compression_ratio < 8.0 @@ -221,11 +176,11 @@ def 
get_tinyllama_model(): def get_sample_tinyllama_quant_config(status: str = "frozen"): config_dict = { "quant_method": "compressed-tensors", - "format": "fakequant", + "format": "dense", "quantization_status": status, "global_compression_ratio": None, "config_groups": { - "group_1": { + "group_0": { "weights": { "num_bits": 8, "type": "int", @@ -240,7 +195,7 @@ def get_sample_tinyllama_quant_config(status: str = "frozen"): }, "targets": ["Linear"], }, - "group_2": { + "group_1": { "weights": { "num_bits": 8, "type": "int", diff --git a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py index 3ac91e851..720df208e 100644 --- a/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py +++ b/tests/test_quantization/lifecycle/test_dynamic_lifecycle.py @@ -86,7 +86,7 @@ def get_tinyllama_model(): def get_sample_dynamic_tinyllama_quant_config(): config_dict = { "quant_method": "compressed-tensors", - "format": "fakequant", + "format": "dense", "quantization_status": "calibration", "global_compression_ratio": None, "config_groups": { diff --git a/tests/test_quantization/test_quant_config.py b/tests/test_quantization/test_quant_config.py index c3830a02d..dd00720e4 100644 --- a/tests/test_quantization/test_quant_config.py +++ b/tests/test_quantization/test_quant_config.py @@ -14,7 +14,6 @@ import pytest from compressed_tensors.quantization import ( - DEFAULT_QUANTIZATION_FORMAT, DEFAULT_QUANTIZATION_METHOD, QuantizationConfig, QuantizationScheme, @@ -29,7 +28,7 @@ def test_basic_config(): assert config.config_groups == config_groups assert config.quant_method == DEFAULT_QUANTIZATION_METHOD - assert config.format == DEFAULT_QUANTIZATION_FORMAT + assert config.format == "dense" assert config.quantization_status == QuantizationStatus.INITIALIZED assert config.global_compression_ratio is None assert isinstance(config.ignore, list) and len(config.ignore) == 0 diff --git a/tests/test_quantization/test_quant_scheme.py b/tests/test_quantization/test_quant_scheme.py index 003d41241..8e565e13f 100644 --- a/tests/test_quantization/test_quant_scheme.py +++ b/tests/test_quantization/test_quant_scheme.py @@ -61,3 +61,39 @@ def test_defaults(): assert output.input_activations is None assert output.output_activations is None assert output.format is None + + +@pytest.mark.parametrize( + "a,b,exp", + [ + ( + QuantizationScheme( + targets=["Linear"], weights=QuantizationArgs(num_bits=4) + ), + QuantizationScheme( + targets=["Attention"], input_activations=QuantizationArgs() + ), + QuantizationScheme( + targets=["Attention", "Linear"], + weights=QuantizationArgs(num_bits=4), + input_activations=QuantizationArgs(), + ), + ), + ( + QuantizationScheme( + targets=["Linear"], input_activations=QuantizationArgs(num_bits=4) + ), + QuantizationScheme( + targets=["model.layer.0.self_attn.q_proj"], + input_activations=QuantizationArgs(), + ), + "error", + ), + ], +) +def test_merge(a, b, exp): + if exp == "error": + with pytest.raises(ValueError): + a.merge(b) + else: + assert a.merge(b) == exp
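
Two illustrative snippets follow. They are not part of the patch, only minimal sketches of the behavior it introduces, assuming the patched compressed-tensors package is installed. The first shows the new `format` normalization on `QuantizationConfig`: `None` (for example, what HFQuantizer passes through) now resolves to the dense format rather than the removed `"fakequant"` default, and a list of per-module formats collapses to a single value or to mixed-precision.

from compressed_tensors.config import CompressionFormat
from compressed_tensors.quantization import QuantizationConfig

# None now resolves to "dense" (DEFAULT_QUANTIZATION_FORMAT = "fakequant" was removed)
config = QuantizationConfig(config_groups={}, format=None)
assert config.format == CompressionFormat.dense.value

# a homogeneous list collapses to its single element ...
config = QuantizationConfig(config_groups={}, format=["int_quantized", "int_quantized"])
assert config.format == CompressionFormat.int_quantized.value

# ... while differing per-module formats become mixed-precision
config = QuantizationConfig(config_groups={}, format=["int_quantized", "dense"])
assert config.format == CompressionFormat.mixed_precision.value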
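
The second sketch walks the reworked apply/serialize flow end to end: `apply_quantization_config` no longer returns a name-to-scheme mapping; schemes are attached to the matched modules (merging with any existing scheme) and the config is attached to the model, which `QuantizationConfig.from_pretrained` now simply reads back. The tiny model and layer path are borrowed from the existing tests; treat this as a rough usage sketch under those assumptions, not library documentation.

from transformers import AutoModelForCausalLM
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationConfig,
    QuantizationScheme,
    apply_quantization_config,
)

model = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/tiny-random-LlamaForCausalLM", torch_dtype="auto"
)

config = QuantizationConfig(
    config_groups={
        "group_0": QuantizationScheme(
            targets=["Linear"], weights=QuantizationArgs(num_bits=8)
        )
    },
    ignore=["lm_head"],
)
apply_quantization_config(model, config)

# schemes now live on the modules themselves ...
scheme = model.model.layers[0].mlp.down_proj.quantization_scheme
assert scheme.weights.num_bits == 8

# ... and from_pretrained just reads the config attached to the model
serialized = QuantizationConfig.from_pretrained(model)
assert "group_0" in serialized.config_groups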