diff --git a/auto_round/__init__.py b/auto_round/__init__.py index 5bb4a396c..d7be4984c 100644 --- a/auto_round/__init__.py +++ b/auto_round/__init__.py @@ -13,9 +13,8 @@ # limitations under the License. from auto_round.autoround import AutoRound -# support for old api from auto_round.autoround import AutoRoundLLM, AutoRoundMLLM, AutoRoundAdam -from auto_round.schemes import QuantizationScheme +from auto_round.schemes import QuantizationScheme, AutoScheme from auto_round.utils import LazyImport diff --git a/auto_round/__main__.py b/auto_round/__main__.py index 07bc3f273..a25fc5421 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -110,7 +110,7 @@ def __init__(self, *args, **kwargs): self.add_argument( "--scale_dtype", - default="fp16", + default=None, choices=["fp16", "float16", "bf16", "bfloat16", "fp32", "float32"], help="scale data type to use for quantization", ) @@ -470,6 +470,14 @@ def tune(args): extra_config.scheme_config = scheme_config extra_config.mllm_config = mllm_config + layer_config = {} + # from auto_round.auto_schemes.haha import get_mixed_config_layer_config + # layer_config = {} + # best_path = get_mixed_config_layer_config(model_name, target_bits=3) + # for item in best_path: + # layer_config[item[0]] = {} + # layer_config[item[0]]["bits"] = item[1] + autoround: BaseCompressor = AutoRound( model=model_name, scheme=scheme, @@ -486,6 +494,7 @@ def tune(args): not_use_best_mse=args.not_use_best_mse, enable_adam=args.adam, extra_config=extra_config, + layer_config=layer_config, ) model_name = args.model.rstrip("/") diff --git a/auto_round/auto_schemes/__init__.py b/auto_round/auto_schemes/__init__.py new file mode 100644 index 000000000..d3b055be2 --- /dev/null +++ b/auto_round/auto_schemes/__init__.py @@ -0,0 +1,39 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +AUTO_SCHEMES_ALGS = {} + + +def register_dtype(names): + """Class decorator to register a mixed precision algorithm to the registry. + + Decorator function used before a Pattern subclass. + + Args: + names: A string. Define the export type. + + Returns: + cls: The class of register. + """ + + def register(alg): + if isinstance(names, (tuple, list)): + for name in names: + AUTO_SCHEMES_ALGS[name] = alg + else: + AUTO_SCHEMES_ALGS[names] = alg + + return alg + + return register diff --git a/auto_round/auto_schemes/gen_scheme.py b/auto_round/auto_schemes/gen_scheme.py new file mode 100644 index 000000000..e009e12de --- /dev/null +++ b/auto_round/auto_schemes/gen_scheme.py @@ -0,0 +1,84 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Iterable, Union + +import torch + +from auto_round import AutoScheme +from auto_round.utils import get_layer_features + + +class GenScheme: + def __init__( + self, + auto_scheme: AutoScheme, + model: torch.nn.Module, + quant_layer_names: Iterable[str], + fixed_layer_scheme: dict[str, dict], + scale_dtype: str = "fp16", + dataset="pile-10k", + ): + self.auto_scheme = auto_scheme + self.model = model + self.quant_layer_names = quant_layer_names + self.fixed_layer_scheme = fixed_layer_scheme + self.scale_dtype = scale_dtype + self.dataset = dataset + + def _get_min_max_avg_bits(self) -> tuple[float, float]: + pass + + # not validate yet + def get_layer_bits(self, layer): + weight = layer.weight + n_param = weight.numel() + weight_bits = getattr(layer, "bits", 16) + group_size = getattr(layer, "group_size", 128) + super_group_size = getattr(layer, "super_group_size", None) + super_weight_bits = getattr(layer, "super_bits", None) + + # Main quantization cost + weight_total_bits = weight_bits * n_param + if weight_bits >= 16: # Unquantized layer + return weight_total_bits, 16 + + in_features, output_features = get_layer_features(layer) + # Determine number of groups + if group_size > 0: # group-wise + n_group = output_features * (in_features + group_size - 1) // group_size + elif group_size == 0: # per-tensor + n_group = 1 + elif group_size == -1: # per-channel + n_group = output_features # out_channels + else: + raise ValueError(f"Invalid group_size {group_size}") + aux_total_bits = 0 + if not super_group_size: + # Scale and zero point bitwidths + scale_bits = 16 + zp_bits = weight_bits if not super_group_size else 32 # default: same as weight_bits + # Overhead from scales and zero points + aux_total_bits = n_group * (scale_bits + zp_bits) + + # Double quantization case + if super_group_size: + # Number of super-groups + aux_total_bits += n_group * super_weight_bits * 2 # sclae and min int count + n_super_group = (n_group + super_group_size - 1) // super_group_size + aux_total_bits += n_super_group * 32 * 2 # double quant scale and min_v + + total_bits = weight_total_bits + aux_total_bits + avg_bits = total_bits / n_param + return total_bits, avg_bits diff --git a/auto_round/auto_schemes/utils.py b/auto_round/auto_schemes/utils.py new file mode 100644 index 000000000..e01da9913 --- /dev/null +++ b/auto_round/auto_schemes/utils.py @@ -0,0 +1,21 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
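Note on the bit accounting in `GenScheme.get_layer_bits` above: for the common case of group-wise integer quantization without double quantization, the auxiliary term reduces to one fp16 scale plus one integer zero point per group. A minimal standalone sketch of that simple case follows; `estimate_avg_bits` is illustrative only and not part of this patch.

```python
# Standalone sketch of the bit accounting for plain group-wise INT-k quantization
# (fp16 scale + k-bit zero point per group, no super/double quantization).
import torch


def estimate_avg_bits(layer: torch.nn.Linear, bits: int = 4, group_size: int = 128) -> float:
    n_param = layer.weight.numel()
    out_features, in_features = layer.weight.shape
    # one (scale, zero-point) pair per group of `group_size` input channels
    n_group = out_features * ((in_features + group_size - 1) // group_size)
    weight_total_bits = bits * n_param
    aux_total_bits = n_group * (16 + bits)  # per-group scale + zero point
    return (weight_total_bits + aux_total_bits) / n_param


layer = torch.nn.Linear(4096, 4096)
print(f"{estimate_avg_bits(layer):.3f} bits/weight")  # ≈ 4.156 for W4 with group_size=128
```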
+ + +def get_total_bits(model, layer_config): + pass + + +def get_bits(layer): + pass diff --git a/auto_round/autoround.py b/auto_round/autoround.py index 4074213a9..ccdca1f09 100644 --- a/auto_round/autoround.py +++ b/auto_round/autoround.py @@ -25,7 +25,7 @@ MLLMCompressor, ) from auto_round.logger import deprecated, logger -from auto_round.schemes import QuantizationScheme +from auto_round.schemes import AutoScheme, QuantizationScheme from auto_round.utils import is_mllm_model @@ -63,7 +63,7 @@ def __new__( cls, model: Union[torch.nn.Module, str], tokenizer=None, - scheme: Union[str, dict, QuantizationScheme] = "W4A16", + scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16", layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, @@ -77,7 +77,6 @@ def __new__( seed: int = 42, # for adam enable_adam: bool = False, - # for MLLM extra_config: ExtraConfig = None, **kwargs, ) -> BaseCompressor: diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index f56109525..f1036ecac 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -35,7 +35,7 @@ from auto_round.export.export_to_gguf.config import GGUF_CONFIG, GGUF_INNER_CONFIG, ModelType from auto_round.logger import logger from auto_round.low_cpu_mem.utils import get_layers_before_block -from auto_round.schemes import QuantizationScheme, preset_name_to_scheme +from auto_round.schemes import AutoScheme, QuantizationScheme, get_gguf_scheme, preset_name_to_scheme from auto_round.sign_sgd import SignSGD from auto_round.special_model_handler import _handle_moe_model from auto_round.utils import ( @@ -75,7 +75,6 @@ get_lm_head_name, get_max_vram, get_module, - get_quant_keys, get_shared_keys, htcore, infer_bits_by_data_type, @@ -91,6 +90,7 @@ mv_module_from_gpu, reset_params, set_amax_for_all_moe_layers, + set_layer_config, set_module, to_device, to_dtype, @@ -129,7 +129,7 @@ def __init__( self, model: Union[torch.nn.Module, str], tokenizer=None, - scheme: Union[str, dict, QuantizationScheme] = "W4A16", + scheme: Union[str, dict, QuantizationScheme, AutoScheme] = "W4A16", layer_config: dict[str, Union[str, dict, QuantizationScheme]] = None, dataset: Union[str, list, tuple, torch.utils.data.DataLoader] = "NeelNanda/pile-10k", iters: int = 200, @@ -202,9 +202,14 @@ def __init__( ... # ... ... 
} """ - self.scheme = None - self._parse_and_set_scheme(scheme, kwargs) + if isinstance(scheme, AutoScheme): # TODO AutoScheme could also be patched by group_size, etc + self.scheme = self._parse_and_set_scheme(scheme.options[0], kwargs) + else: + self.scheme = self._parse_and_set_scheme(scheme, kwargs) + gguf_scheme_name = get_gguf_scheme(self.scheme) + # GGUF uses fp32 scale dtype as default + scale_dtype = kwargs.pop("scale_dtype", "fp32") if gguf_scheme_name else kwargs.pop("scale_dtype", "fp16") # Extra/legacy kwargs for backward compatibility # Major version releases may pack them with extra configuration options amp = kwargs.pop("amp", True) @@ -217,7 +222,6 @@ def __init__( sampler = kwargs.pop("sampler", "rand") not_use_best_mse = kwargs.pop("not_use_best_mse", False) dynamic_max_gap = kwargs.pop("dynamic_max_gap", -1) - scale_dtype = kwargs.pop("scale_dtype", "fp16") nblocks = kwargs.pop("nblocks", 1) low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", False) to_quant_block_names: Union[str, list, None] = kwargs.pop("to_quant_block_names", None) @@ -231,13 +235,17 @@ def __init__( self.mllm = kwargs.pop("mllm") if "mllm" in kwargs else False # Scale factor for RAM usage per parameter. self.mem_per_param_scale = kwargs.pop("mem_per_param_scale", None) - fp_layers = kwargs.pop("fp_layers", None) + self.fp_layers = kwargs.pop("fp_layers", "") + self.layer_config = layer_config + self.supported_types = SUPPORTED_LAYER_TYPES + self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES + self.scale_dtype = convert_dtype_str2torch(scale_dtype) if kwargs: logger.warning(f"unrecognized keys {list(kwargs.keys())} were passed. Please check them.") if "CUBLAS_WORKSPACE_CONFIG" not in os.environ: os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" - # deprecated, default not to use torch.use_deterministic_algorithms + # Deprecated, default not to use torch.use_deterministic_algorithms if not disable_deterministic_algorithms or enable_deterministic_algorithms: if not disable_deterministic_algorithms: logger.warning( @@ -254,8 +262,8 @@ def __init__( if isinstance(model, str): model, tokenizer, low_cpu_mem_usage = llm_load_model( model, - device="cpu", - low_cpu_mem_mode=low_cpu_mem_usage, # always load cpu first + device="cpu", # always load cpu first + low_cpu_mem_mode=low_cpu_mem_usage, ) elif tokenizer is None and iters > 0: raise ValueError("A tokenizer must be set for non-str model input") @@ -281,6 +289,32 @@ def __init__( if device_map is None: device_map = 0 + if isinstance(scheme, AutoScheme): + if self.mllm: + logger.info("AutoScheme with MLLM is not supported yet.") + sys.exit(1) + layer_config, _ = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=False, + is_mllm=self.mllm, + ) + quant_layer_names = layer_config.keys() + fixed_layer_scheme = {k: v for k, v in layer_config.items() if v.get("fixed_by_user", False)} + # mainly using quant_layers and fixed by users + from auto_round.auto_schemes.gen_scheme import GenScheme + + gen_scheme = GenScheme( + scheme, self.model, quant_layer_names, fixed_layer_scheme, self.scale_dtype, self.dataset + ) + # Set device, must place after model loading self._set_device(device_map) @@ -294,15 +328,6 @@ def __init__( self.device_map = None self._set_device_map_in_blocks(self.device_map) - not_quantize_layer_names = get_fp_layer_names(self.model, fp_layers) - if 
len(not_quantize_layer_names) > 0: - logger.info(f"{not_quantize_layer_names} will not be quantized.") - if layer_config is None: - layer_config = {} - for name in not_quantize_layer_names: - layer_config[name] = {"bits": 16, "act_bits": 16, "data_type": "float", "act_data_type": "float"} - self._parse_layer_config(layer_config) # must place after model init - # Tuning hyperparameters self.seed = seed set_seed(self.seed) @@ -339,7 +364,6 @@ def __init__( if self.static_kv_dtype is not None: logger.warning("The static kv is experimental and currently has limited support.") - self.scale_dtype = convert_dtype_str2torch(scale_dtype) self._set_amp_dtype() self.cache_device = torch.device("cpu") if self.low_gpu_mem_usage else self.device if self.act_bits <= 8 and self.amp_dtype == torch.float16: @@ -351,8 +375,6 @@ def __init__( logger.info(f"using {self.model.dtype} for quantization tuning") # Some helpers - self.supported_types = SUPPORTED_LAYER_TYPES - self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES if "hpu" in str(self.device): self.inner_supported_types = tuple(x for x in INNER_SUPPORTED_LAYER_TYPES if x != "FP8Linear") self.batch_dim = None @@ -367,7 +389,7 @@ def __init__( import habana_frameworks.torch.core as htcore # pylint: disable=E0401 import habana_frameworks.torch.hpu as hthpu # pylint: disable=E0401] - def _set_device(self, device_map): + def _set_device(self, device_map: Union[str, torch.device, int, dict]) -> None: if hasattr(self, "device") and self.device is not None: return if isinstance(device_map, (str, torch.device, int)): @@ -391,65 +413,16 @@ def _set_device(self, device_map): else: raise TypeError(f"device_map should be [str, torch.device, int, dict], but got {type(device_map)}") - def _parse_layer_config(self, layer_config: dict[str, Union[str, dict, QuantizationScheme]]) -> None: - """Parse and set the layer-wise quantization configuration.""" - # Some other quantization configs - self.layer_config = copy.deepcopy(layer_config) if layer_config is not None else {} - scheme_keys = {f.name for f in fields(QuantizationScheme)} - - for key, item in self.layer_config.items(): - if isinstance(item, str): - config = asdict(preset_name_to_scheme(item.upper())) - elif isinstance(item, QuantizationScheme): - config = asdict(item) - elif isinstance(item, dict): - invalid_keys = set(item) - scheme_keys - if invalid_keys: - raise ValueError( - f"Invalid keys {invalid_keys} in layer_config for layer '{key}', " - f"only {scheme_keys} are supported" - ) - config = dict(item) - - # Drop None values - config = {k: v for k, v in config.items() if v is not None} - self.layer_config[key] = config - - if not self.quant_lm_head or (isinstance(self.scheme, str) and self.scheme.lower().startswith("gguf")): - return - for n, _ in self.model.named_modules(): - lm_head_layer_name = n - - if ( - hasattr(self.model, "config") - and self.model.config.tie_word_embeddings - and hasattr(self.model, "_tied_weights_keys") - ): - tied_keys = self.model._tied_weights_keys - for item in tied_keys: - if lm_head_layer_name in item: # TODO extend to encoder-decoder layer, seq classification model - self.quant_lm_head = False - logger.warning( - "reset `quant_lm_head` to `False` as quantizing lm_head with tied weights has not been " - "supported currently" - ) - break - - lm_head_layer_config = self.layer_config[lm_head_layer_name] if lm_head_layer_name in self.layer_config else {} - - for key in scheme_keys: - if key not in lm_head_layer_config: - lm_head_layer_config[key] = getattr(self, key) - - 
def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> None: + def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kwargs) -> QuantizationScheme: """Parse and set the quantization scheme.""" + res = "" if isinstance(scheme, QuantizationScheme): scheme = asdict(scheme) elif isinstance(scheme, dict): scheme = scheme elif isinstance(scheme, str): + res = scheme # gguf:q4_k_s and gguf_q4_k_m has the same dict scheme, but the result is different scheme = scheme.upper() - self.scheme = scheme scheme = asdict(preset_name_to_scheme(scheme)) scheme_keys = [f.name for f in fields(QuantizationScheme)] for key in scheme_keys: @@ -495,6 +468,9 @@ def _parse_and_set_scheme(self, scheme: Union[str, dict, QuantizationScheme], kw if supported_dtype + str(tmp_act_bits) == self.act_data_type: # could not replace FP8_e4m3 self.act_data_type = supported_dtype break + for key in scheme_keys: + scheme[key] = getattr(self, key) + return res if res else QuantizationScheme.from_dict(scheme) def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: """Sets the torch compile configuration for the tuning.""" @@ -738,20 +714,20 @@ def _check_compatibility(self) -> None: " We are likely to release new algorithm for certain configurations in the future." ) - # Check group_size 32 for auto_round - if ( - self.data_type == "int" - and hasattr(self, "formats") - and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq")) - ): - for n, m in self.model.named_modules(): - if type(m) in self.supported_types: - if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: - self.layer_config[n] = {"bits": 16} - logger.info( - f"{n} will not be quantized due to its shape not being divisible by 32," - " resulting in an exporting issue to autogptq" - ) + # # Check group_size 32 for auto_round + # if ( + # self.data_type == "int" + # and hasattr(self, "formats") + # and any(key in fmt for fmt in self.formats for key in ("auto_round", "auto_gptq", "auto_awq")) + # ): + # for n, m in self.model.named_modules(): + # if type(m) in self.supported_types: + # if m.weight.shape[0] % 32 != 0 or m.weight.shape[1] % 32 != 0: + # self.layer_config[n] = {"bits": 16} + # logger.info( + # f"{n} will not be quantized due to its shape not being divisible by 32," + # " resulting in an exporting issue to autogptq" + # ) if ( self.seqlen is not None @@ -800,19 +776,26 @@ def remove_duplicates(lst): formats = format.replace("q*_", f"q{self.bits}_").replace(" ", "").split(",") formats = remove_duplicates(formats) # need the keep origin order - if isinstance(self.scheme, str) and self.scheme.lower().startswith("gguf"): + gguf_format_name = get_gguf_scheme(self.scheme) + + if gguf_format_name: for i in range(len(formats)): - if formats[i] != "fake" and formats[i] != self.scheme.lower(): + if formats[i] != "fake" and formats[i] != gguf_format_name.lower(): logger.warning( - f"reset format {formats[i]} to {self.scheme.lower()} " - f"since scheme {self.scheme} can only be exported to format {self.scheme.lower()}" + f"reset format {formats[i]} to {gguf_format_name.lower()} " + f"since scheme {gguf_format_name} can only be exported to format {gguf_format_name.lower()}" ) - formats[i] = self.scheme.lower() + formats[i] = gguf_format_name.lower() _gguf_args_check(self, formats, model_type=ModelType.TEXT) if self.mllm: _gguf_args_check(self, formats, model_type=ModelType.MMPROJ) + for f in formats: + if f.startswith("gguf"): + self.scheme = f.upper() + break 
+ for format_ in formats: if format_ not in SUPPORTED_FORMATS: logger.error(f"Unsupported format {format_}, please choose from {SUPPORTED_FORMATS}") @@ -1296,92 +1279,6 @@ def get_imatrix_hook(module, input, output): for hook in hooks: hook.remove() - def _check_need_to_quantize_lm_head_embedding(self) -> bool: - """Checks if LM head and embedding layers need quantization for GGUF format. - - This function inspects the current model's formats and determines whether - it needs to apply quantization settings to the embedding and LM head layers. - The function modifies `self.layer_config` in-place and updates the model modules. - - Returns: - bool: True if the LM head needs quantization, otherwise False. - - Raises: - NotImplementedError: If multiple non-fake GGUF formats are specified. - """ - gguf_scheme = False - if isinstance(self.scheme, str) and "gguf" in self.scheme.lower(): - gguf_scheme = True - - if not hasattr(self, "formats") and not gguf_scheme: - return False - - has_gguf: bool = gguf_scheme or any("gguf" in fmt for fmt in self.formats) - if not has_gguf: - return False - if hasattr(self, "formats"): - formats: list[str] = [fmt for fmt in self.formats if "fake" not in fmt] - if not (len(formats) == 1 and "gguf" in formats[0]): - raise NotImplementedError("Only one GGUF format can be set at a time.") - target_format: str = formats[0] - - else: - target_format = self.scheme.lower() - - tie_word_embeddings: bool = getattr(getattr(self.model, "config", None), "tie_word_embeddings", True) - for name, module in self.model.named_modules(): - if isinstance(module, torch.nn.Embedding): - key: str = "lm_head" if tie_word_embeddings else "embedding" - config: dict[str, Any] = GGUF_INNER_CONFIG[GGUF_CONFIG[target_format][key]] - self._apply_config_to_layer(name, config, True) - - if not tie_word_embeddings: - lm_head_name: str = get_lm_head_name(self.model) - config: dict[str, Any] = GGUF_CONFIG[GGUF_CONFIG[target_format]["lm_head"]] - check_fixed_by_user = ( - self.layer_config[lm_head_name].get("fixed_by_user", False) - if lm_head_name in self.layer_config - else None - ) - self._apply_config_to_layer(lm_head_name, config, check_fixed_by_user=check_fixed_by_user) - return True - - return False - - def _apply_config_to_layer( - self, - layer_name: str, - config: dict[str, Any], - check_fixed_by_user: bool = False, - ) -> None: - """Applies GGUF quantization configuration to a given layer. - - Args: - layer_name (str): Name of the layer to configure. - config (dict[str, Any]): GGUF layer configuration. - check_fixed_by_user (bool): If True, preserve user-defined settings. - """ - act_bits: int = 16 - scale_dtype: Any = self.scale_dtype - keys: list[str] = ["bits", "group_size", "super_bits", "super_group_size", "data_type", "sym"] - - self.layer_config[layer_name] = self.layer_config.get(layer_name, {}) - - for key in keys: - if ( - key in self.layer_config[layer_name] - and check_fixed_by_user - # and self.layer_config[layer_name].get("fixed_by_user", False) - ): - continue - self.layer_config[layer_name][key] = config.get(key) - setattr(get_module(self.model, layer_name), key, config.get(key)) - - self.layer_config[layer_name]["act_bits"] = act_bits - self.layer_config[layer_name]["scale_dtype"] = scale_dtype - setattr(get_module(self.model, layer_name), "act_bits", act_bits) - setattr(get_module(self.model, layer_name), "scale_dtype", scale_dtype) - def _quantize_layer_via_rtn(self, name: str) -> None: """Quantizes a layer using RTN (Round-To-Nearest) if available. 
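Reviewer note on the `_parse_and_set_scheme` change above: GGUF preset names are now returned as plain strings because several GGUF variants (e.g. `gguf:q4_k_s` vs `gguf:q4_k_m`) expand to identical field values, so only the name distinguishes them; every other input is normalized to a `QuantizationScheme`. A minimal sketch of that contract, with `resolve_scheme` as an illustrative stand-in for the real method:

```python
from dataclasses import asdict

from auto_round.schemes import QuantizationScheme, preset_name_to_scheme


def resolve_scheme(scheme):
    # Keep GGUF names verbatim: the exact sub-variant matters even though the
    # expanded field values can be identical across variants.
    if isinstance(scheme, QuantizationScheme):
        return scheme
    if isinstance(scheme, str):
        if scheme.lower().startswith("gguf"):
            return scheme
        return QuantizationScheme.from_dict(asdict(preset_name_to_scheme(scheme.upper())))
    raise TypeError(f"unsupported scheme type {type(scheme)}")


print(resolve_scheme("W4A16"))        # QuantizationScheme(...)
print(resolve_scheme("gguf:q4_k_s"))  # "gguf:q4_k_s"
```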
@@ -1681,32 +1578,32 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: Returns: The quantized model and layer configurations. """ - for n, m in self.model.named_modules(): + for n, m in self.model.named_modules(): # TODO check if could removed m.tmp_name = n self._check_compatibility() formats = self.formats if hasattr(self, "formats") else None # It is best to modify the model structure in the quantize function and check the format, # because it may cause the gguf format to not be exported normally. self.model = _handle_moe_model(self.model, formats=formats) - self.has_qlayer_outside_block = self._set_layerwise_config(self.layer_config) + + # TODO check scale_dtype + self.layer_config, self.has_qlayer_outside_block = set_layer_config( + self.model, + self.layer_config, + self.scheme, + self.scale_dtype, + self.supported_types, + self.inner_supported_types, + self.quant_block_list, + self.fp_layers, + self.quant_lm_head, + enable_gguf_official_mixed=True, + is_mllm=self.mllm, + ) + if not hasattr(self, "formats"): logger.warning("this API is deprecated, please use `quantize_and_save` instead") else: - only_gguf = True - for format_ in self.formats: - if not ("gguf" in format_ or "fake" in format_): - only_gguf = False - break - if len(self.formats) == 1 and self.formats[0] == "fake": - only_gguf = False - if only_gguf: - self.layer_config, gguf_format_config = get_layer_config_by_gguf_format( - self.layer_config, self.formats, self.model, model_type=ModelType.TEXT - ) - if self.mllm: - self.layer_config, gguf_format_config = get_layer_config_by_gguf_format( - self.layer_config, self.formats, self.model, model_type=ModelType.MMPROJ - ) # Determine if immediate packing is required formats = self.formats if ( @@ -1818,7 +1715,7 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: cost_time = end_time - self.start_time logger.info(f"quantization tuning time {cost_time}") - ## dump a summary + # Dump a summary quantized_layers = [] unquantized_layers = [] for n, m in self.model.named_modules(): @@ -1914,141 +1811,6 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: del layer_input clear_memory(q_layer_input) - def _set_layerwise_config(self, layer_config: dict) -> bool: - """ - Sets the layer-wise configuration based on the provided `layer_config`. - By default, only quantize layers in blocks. - - Args: - layer_config (dict): The configuration dictionary for each layer containing various configuration options. - - Returns: - bool: Returns True if there are quantized layers outside the blocks (e.g., lm-head), - otherwise returns False. 
- """ - # Get the names of layers in quantization blocks - supported_types = self.supported_types - layers_in_blocks = get_layer_names_in_block( - self.model, supported_types, self.quant_block_list, self.inner_supported_types - ) - ##process regex in layer_config - all_supported_layer_names = [] - # List of configuration keys - keys = get_quant_keys() - - for n, m in self.model.named_modules(): - # Delete previous configuration to avoid conflicts with prior tuning - for key in keys: - if hasattr(m, key): - delattr(m, key) - - if not isinstance(m, supported_types) and m.__class__.__name__ not in self.inner_supported_types: - continue - all_supported_layer_names.append(n) - - names_in_layer_config = list(layer_config.keys()) - for name in names_in_layer_config: - if name in all_supported_layer_names: - continue - matched_names = [] - for layer_name in all_supported_layer_names: - if re.search(re.compile(name), layer_name) is not None: - matched_names.append(layer_name) - if len(matched_names) > 0: - val = layer_config[name] - layer_config.pop(name) - for match_name in matched_names: - layer_config[match_name] = val - else: - tmp_m = get_module(self.model, name) - if not isinstance(tmp_m, torch.nn.Embedding): # TODO not good code style - raise ValueError(f"key {name} in layer_config is invalid, please have a double check") - - has_qlayer_outside_block = False # Flag to track if there are quantized layers outside blocks (e.g., lm-head) - - # Iterate through all modules in the model - is_gguf = hasattr(self, "formats") and any("gguf" in format_ for format_ in self.formats) - for n, m in self.model.named_modules(): - # Skip unsupported types - if type(m) not in supported_types and m.__class__.__name__ not in self.inner_supported_types: - if n in self.layer_config: - if not isinstance(m, torch.nn.Embedding): - logger.warning(f"{n} is not supported, layer_config {n}: {layer_config[n]} will be ignored.") - self.layer_config.pop(n) - continue - if not is_gguf: - if not check_to_quantized(layer_config[n]): - self.layer_config.pop(n) - continue - else: - continue - - # If the layer is not in the config and is part of a quantization block, use default configuration - if n not in layer_config.keys() and n in layers_in_blocks: - layer_config[n] = {} - for key in keys: - layer_config[n][key] = getattr(self, key) - - # If the layer is partially configured, fill in missing values - elif n in layer_config.keys(): - if "data_type" in layer_config[n] and "bits" not in layer_config[n]: - tmp_bits = infer_bits_by_data_type(layer_config[n]["data_type"]) - if tmp_bits is not None and tmp_bits != self.bits: - logger.warning( - f"'data_type' do not match the specified 'bits' setting for {n}." - f" Resetting 'bits' to {tmp_bits}." - ) - layer_config[n]["bits"] = tmp_bits - if "act_data_type" in layer_config[n] and "act_bits" not in layer_config[n]: - tmp_bits = infer_bits_by_data_type(layer_config[n]["act_data_type"]) - if tmp_bits is not None and tmp_bits != self.act_bits: - logger.warning( - f"'act_data_type' do not match the specified 'act_bits' setting for {n}." - f" Resetting 'act_bits' to {tmp_bits}." 
- ) - layer_config[n]["act_bits"] = tmp_bits - - for key in keys: - if key not in layer_config[n].keys(): - layer_config[n][key] = getattr(self, key) - layer_config[n]["fixed_by_user"] = True - - # If the layer is not in the config and not part of a quantization block, - # use default configuration and set specific values - else: - layer_config[n] = {} - for key in keys: - layer_config[n][key] = getattr(self, key) - layer_config[n]["bits"] = 16 - layer_config[n]["act_bits"] = 16 - - if n in layers_in_blocks: - layer_config[n]["in_blocks"] = True - else: - layer_config[n]["in_blocks"] = False - - # If the layer is outside a block and requires quantization, mark it as a quantized layer outside the block - if ( - n not in layers_in_blocks - and check_to_quantized(layer_config[n]) - and not isinstance(m, torch.nn.Embedding) - ): - has_qlayer_outside_block = True - - in_features, out_features = get_layer_features(m) - if in_features <= layer_config[n]["group_size"]: - layer_config[n]["group_size"] = -1 - - # Apply the configuration to the corresponding layer in the model - for key in keys: - setattr(m, key, layer_config[n][key]) - need_to_quantize_lm_head = self._check_need_to_quantize_lm_head_embedding() - if need_to_quantize_lm_head: - has_qlayer_outside_block = True - - # Return whether there are quantized layers outside the blocks - return has_qlayer_outside_block - @torch.no_grad() def _get_block_outputs( self, diff --git a/auto_round/data_type/register.py b/auto_round/data_type/register.py index 12c4406a4..fca259ed6 100644 --- a/auto_round/data_type/register.py +++ b/auto_round/data_type/register.py @@ -22,8 +22,7 @@ def register_dtype(names): Decorator function used before a Pattern subclass. Args: - cls (class): The subclass of register. - name: A string. Define the export type. + names: A string. Define the export type. Returns: cls: The class of register. 
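The `register_dtype` decorator added in `auto_round/auto_schemes/__init__.py` mirrors the one in `auto_round/data_type/register.py` (whose docstring is fixed above) but fills the new `AUTO_SCHEMES_ALGS` registry; presumably `AutoScheme.method` is meant to select an entry from it. A hedged usage sketch, with a made-up placeholder algorithm:

```python
from auto_round.auto_schemes import AUTO_SCHEMES_ALGS, register_dtype


@register_dtype(("default", "naive"))
def naive_mixed_precision(model, quant_layer_names, fixed_layer_scheme, avg_bits):
    """Toy stand-in: give the same option to every layer the user has not fixed."""
    return {name: {"bits": 4} for name in quant_layer_names if name not in fixed_layer_scheme}


# Both registered names resolve to the same callable.
assert AUTO_SCHEMES_ALGS["default"] is AUTO_SCHEMES_ALGS["naive"] is naive_mixed_precision
```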
diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index c4a02f673..eaf3ad9ae 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -174,7 +174,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): for n, m in model.named_modules(): if type(m) in SUPPORTED_LAYER_TYPES: layer = m - if layer.act_bits < 8 and not getattr(layer, "input_global_scale", None): + if hasattr(layer, "act_bits") and layer.act_bits < 8 and not getattr(layer, "input_global_scale", None): assert hasattr(layer, "act_max") from auto_round.data_type.nvfp import calculate_gparam @@ -198,7 +198,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): for layer_name in layer_config: if ( not layer_config[layer_name]["in_blocks"] and layer_config[layer_name]["bits"] <= 8 - ): ##lm head ##TODO fix act and so on + ): ##lm head # TODO fix act and so on extra_config[layer_name] = {} extra_config[layer_name]["bits"] = layer_config[layer_name]["bits"] extra_config[layer_name]["data_type"] = layer_config[layer_name]["data_type"] diff --git a/auto_round/schemes.py b/auto_round/schemes.py index a5c5975c9..cde37a0c9 100644 --- a/auto_round/schemes.py +++ b/auto_round/schemes.py @@ -14,9 +14,9 @@ import copy from copy import deepcopy from dataclasses import dataclass, fields -from typing import Generator, List, Optional +from typing import Iterable, Optional, Union -__all__ = ["QuantizationScheme", "preset_name_to_scheme"] +__all__ = ["QuantizationScheme", "get_gguf_scheme", "preset_name_to_scheme", "AutoScheme"] @dataclass @@ -38,7 +38,7 @@ def from_dict(cls, config: dict): return cls(**config) @classmethod - def get_attributes(cls: "QuantizationScheme") -> List[str]: + def get_attributes(cls: "QuantizationScheme") -> list[str]: return [field.name for field in fields(cls)] def __getitem__(self, key: str): @@ -180,6 +180,7 @@ def is_preset_scheme(name: str) -> bool: } ) + # FP8 = asdict(QuantArgs.from_dict({ # "bits": 8, # "group_size": 128, @@ -201,6 +202,18 @@ def is_preset_scheme(name: str) -> bool: } ) +# For AutoScheme 16 bits options +BF16 = QuantizationScheme.from_dict( + { + "bits": 16, + "group_size": 0, + "data_type": "fp", + "act_bits": 16, + "act_data_type": "fp", + } +) + + PRESET_SCHEMES = { "W4A16": W4A16, "W2A16": W2A16, @@ -211,6 +224,7 @@ def is_preset_scheme(name: str) -> bool: "NVFP4": NVFP4, "FPW8A16": FPW8A16, "FP8_STATIC": FP8_STATIC, + "BF16": BF16, } from auto_round.export.export_to_gguf.config import GGUF_CONFIG @@ -220,3 +234,29 @@ def is_preset_scheme(name: str) -> bool: value.pop("embedding", None) value.pop("lm_head", None) PRESET_SCHEMES[key.upper()] = QuantizationScheme.from_dict(value) + + +def get_gguf_scheme(scheme: Union[str, QuantizationScheme]) -> str: + if isinstance(scheme, str) and scheme.upper().startswith("GGUF"): + return scheme + if isinstance(scheme, str): + return "" + for key, val in PRESET_SCHEMES.items(): + if not key.upper().startswith("GGUF"): + continue + equal = True + for scheme_key in val.keys(): + if val[scheme_key] is not None and val[scheme_key] != scheme.get(scheme_key, None): + equal = False + break + if equal: + return key + return "" + + +@dataclass +class AutoScheme: + options: Optional[Iterable[QuantizationScheme | str]] + avg_bits: float + shared_layers: Optional[Iterable[Iterable[str]]] = None + method: str = "default" diff --git a/auto_round/utils.py b/auto_round/utils.py 
index 26ec5f996..a1c411373 100644 --- a/auto_round/utils.py +++ b/auto_round/utils.py @@ -21,6 +21,7 @@ import re import sys from collections import UserDict +from dataclasses import asdict, fields from enum import Enum from functools import lru_cache from pathlib import Path @@ -34,7 +35,7 @@ from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, GGUF_CONFIG, GGUF_INNER_CONFIG, QK_K, ModelType from auto_round.logger import logger -from auto_round.schemes import QuantizationScheme +from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme SHARED_CACHE_KEYS = ("position_ids", "cache_position", "position_embeddings") @@ -765,8 +766,11 @@ def check_memory_availability(device, inputs, weight, org_seqlen, org_bs): def get_layer_names_in_block( - model, supported_types=(torch.nn.Linear, transformers.pytorch_utils.Conv1D), quant_block_list=None, class_names=None -): + model: torch.nn.Module, + supported_types=(torch.nn.Linear, transformers.pytorch_utils.Conv1D), + quant_block_list: list = None, + class_names: tuple = None, +) -> list[str]: """Retrieves the names of layers within each block of the model. Returns: @@ -790,7 +794,6 @@ def get_layer_names_in_block( if hasattr(m, "bk_tmp_name"): layers_in_block.append(m.bk_tmp_name) delattr(m, "bk_tmp_name") - return layers_in_block @@ -1047,7 +1050,7 @@ def can_pack_with_numba(): # pragma: no cover return True -def get_fp_layer_names(model, fp_layers): +def get_fp_layer_names(model: torch.nn.Module, fp_layers: str): """Identifies and returns layers in the model to exclude from quantization. This function processes a comma-separated list of fully precision (FP) layers, @@ -1845,9 +1848,9 @@ def _gguf_type_fallback(gguf_type): ##https://github.com/ggml-org/llama.cpp/blob/9e31bec4fd53634c9e5b04650488a09a055f5dab/src/llama-quant.cpp#L129 -def get_layer_config_by_gguf_format(layer_config, gguf_format, model, model_type=ModelType.TEXT): - # TODO: support for other format later - target_gguf_format = next((fmt for fmt in gguf_format if fmt != "fake"), None) +def get_layer_config_by_gguf_format(layer_config, target_gguf_format: str, model, model_type=ModelType.TEXT): + # # TODO: support for other format later + # target_gguf_format = next((fmt for fmt in gguf_format if fmt != "fake"), None) import gguf # pylint: disable=E0401 @@ -1941,6 +1944,34 @@ def _set_config(config, target_config): ) new_type = new_type[:bits_index] + target_bits + new_type[bits_index + 1 :] else: + config_tmp = config.copy() + scheme_keys = [f.name for f in fields(QuantizationScheme)] + for key in config.keys(): + if key not in scheme_keys: + config_tmp.pop(key, None) + matched_scheme = get_gguf_scheme(QuantizationScheme.from_dict(config_tmp)) # check matched + if not matched_scheme: + if config.get("super_group_size", None) is not None: + new_type = new_type[:bits_index] + str(config["bits"]) + "_k" + if config.get("super_group_size", None) is None or new_type not in GGUF_INNER_CONFIG: + if config.get("sym", True): + new_type = new_type[:bits_index] + str(config["bits"]) + "_0" + if new_type not in GGUF_INNER_CONFIG: + new_type = new_type[:bits_index] + str(config["bits"]) + "_1" + if not config.get("sym", True): + new_type = new_type[:bits_index] + str(config["bits"]) + "_1" + if new_type not in GGUF_INNER_CONFIG: + new_type = new_type[:bits_index] + str(config["bits"]) + "_0" + if new_type not in GGUF_INNER_CONFIG: + raise ValueError( + f"the setting in layer_config {layer_name} " + f"could not match any supported gguf format, please 
have a check." + ) + else: + logger.warning_once( + f"the setting in layer_config {layer_name} " + f"could not match any supported gguf format, reset to {new_type}" + ) new_type = new_type[:bits_index] + str(config["bits"]) + new_type[bits_index + 1 :] new_type = _search_gguf_type(new_type) if new_type is None: @@ -2284,8 +2315,8 @@ def convert_fp8_layer_to_linear(layer, dtype=torch.bfloat16): new_layer = torch.nn.Linear(layer.in_features, layer.out_features, bias=layer.bias is not None, dtype=dtype) if layer.bias is not None: new_layer.bias.data.copy_(layer.bias.data.to(dtype=dtype)) - - keys = get_quant_keys() + ["tmp_name"] + scheme_keys = (f.name for f in fields(QuantizationScheme)) + keys = tuple(scheme_keys) + ("tmp_name", "scale_dtype") for key in keys: setattr(new_layer, key, getattr(layer, key, None)) @@ -2314,24 +2345,6 @@ def convert_fp8_model_to_16b_model(model, dtype=torch.bfloat16): return model -def get_quant_keys(): - keys = [ - "bits", - "group_size", - "sym", - "data_type", - "scale_dtype", - "act_bits", - "act_group_size", - "act_sym", - "act_dynamic", - "act_data_type", - "super_bits", - "super_group_size", - ] - return keys - - def out_of_vram(error_msg): error_msg = str(error_msg) # CUDA @@ -2761,3 +2774,190 @@ def is_mllm_model(model_or_path: Union[str, torch.nn.Module]): return True return False + + +def set_layer_config( + model: torch.nn.Module, + layer_config: dict[str, Union[str, dict, "QuantizationScheme"]], + default_scheme: Union[str, "QuantizationScheme"], + default_scale_dtype: torch.dtype | str, + supported_types: tuple, + inner_supported_types: tuple, + quant_block_list=None, + fp_layers: str = "", + quant_lm_head: bool = False, + enable_gguf_official_mixed: bool = True, + is_mllm: bool = False, +) -> tuple[dict, bool]: + """ + Normalize, validate, and expand layer-specific quantization configs. + Returns (final_layer_config, has_quant_layer_outside_block) + """ + + from auto_round.schemes import get_gguf_scheme + + # ---- helpers ------------------------------------------------- + def dispatch_layer_config(layer_config: dict[str, dict]) -> None: + """Assign scheme values as attributes to matched modules.""" + for layer_name, scheme in layer_config.items(): + module = get_module(model, layer_name) + for attr, value in scheme.items(): + setattr(module, attr, value) + + def normalize_item(item: Union[str, dict, "QuantizationScheme"], layer_name: str) -> dict: + """Convert config entry into dict and validate keys.""" + if isinstance(item, str): + config = asdict(preset_name_to_scheme(item.upper())) + elif isinstance(item, QuantizationScheme): + config = asdict(item) + elif isinstance(item, dict): + invalid = set(item) - set(scheme_keys) + if invalid: + raise ValueError( + f"Invalid keys {invalid} in layer_config for '{layer_name}'. " f"Allowed keys: {scheme_keys}" + ) + config = dict(item) + else: + raise TypeError( + f"Unsupported type for layer_config[{layer_name}]: {type(item)}. " + f"Expected str, dict, or QuantizationScheme." + ) + # Clean up + config = {k: v for k, v in config.items() if v is not None} + config["fixed_by_user"] = True + return config + + # ---- main logic ---------------------------------------------- + scheme_keys = tuple(f.name for f in fields(QuantizationScheme)) + ("scale_dtype",) + layer_config = copy.deepcopy(layer_config) or {} + + # 1. 
fp_layers -> force 16 + for name in get_fp_layer_names(model, fp_layers): + layer_config[name] = { + "bits": 16, + "act_bits": 16, + "data_type": "float", + "act_data_type": "float", + "fixed_by_user": True, + } + + # 2. normalize + layer_config = {k: normalize_item(v, k) for k, v in layer_config.items()} + + # 3. infer missing bits + for cfg in layer_config.values(): + if "data_type" in cfg and "bits" not in cfg: + if (b := infer_bits_by_data_type(cfg["data_type"])) is not None: + cfg["bits"] = b + if "act_data_type" in cfg and "act_bits" not in cfg: + if (b := infer_bits_by_data_type(cfg["act_data_type"])) is not None: + cfg["act_bits"] = b + + # 4. fill defaults + if isinstance(default_scheme, str): + default_dict = asdict(preset_name_to_scheme(default_scheme.upper())) + else: + default_dict = asdict(default_scheme) + default_dict["scale_dtype"] = default_scale_dtype + for cfg in layer_config.values(): + for key in scheme_keys: + cfg.setdefault(key, default_dict.copy().get(key)) + + # 5. collect supported modules + gguf_name = get_gguf_scheme(default_scheme) + if gguf_name and torch.nn.Embedding not in supported_types: + supported_types = (*supported_types, torch.nn.Embedding) + + all_layer_names, embedding_layer_names = [], [] + all_module_names = [] + for n, m in model.named_modules(): + all_module_names.append(n) + # cleanup stale attributes + for key in scheme_keys: + if hasattr(m, key): + delattr(m, key) + if type(m) not in supported_types and m.__class__.__name__ not in inner_supported_types: + continue + all_layer_names.append(n) + if isinstance(m, torch.nn.Embedding): + embedding_layer_names.append(n) + + # 6. expand regex configs + for name in list(layer_config.keys()): + if name in all_layer_names: + continue + if name in all_module_names: + logger.warning_once(f"the type of `{name}` is not supported in your scheme, ignore it for now.") + continue + regex = re.compile(name) + matched = [ln for ln in all_layer_names if regex.search(ln)] + if not matched: + raise ValueError(f"Invalid '{name}' in layer_config, no match found.") + val = layer_config.pop(name) + for match in matched: + layer_config[match] = val + + # 7. lm_head + lm_head_name = get_lm_head_name(model) + tie_word_embeddings = False + if hasattr(model, "config") and hasattr(model.config, "tie_word_embeddings"): + tie_word_embeddings = model.config.tie_word_embeddings + + if quant_lm_head and tie_word_embeddings: + quant_lm_head = False + logger.warning( + "reset `quant_lm_head` to false as quantizing " "lm_head with tied weights has not been supported currently" + ) + + if lm_head_name not in layer_config and quant_lm_head: + layer_config[lm_head_name] = default_dict.copy() + + # 8. enforce shape divisibility for int weight-only + if default_dict["data_type"] == "int" and default_dict["act_bits"] >= 16 and not gguf_name: + for n, m in model.named_modules(): + if type(m) in supported_types or m.__class__.__name__ in inner_supported_types: + if m.weight.shape[0] % 32 or m.weight.shape[1] % 32: + layer_config.setdefault(n, default_dict.copy()) + layer_config[n].update({"bits": 16, "data_type": "fp", "fixed_by_user": True}) + logger.warning_once(f"{n} skipped quantization (shape not divisible by 32).") + + # 9. 
block layers: mark as in_blocks=True + for name in get_layer_names_in_block(model, supported_types, quant_block_list, inner_supported_types): + if name not in layer_config: + layer_config[name] = default_dict.copy() + layer_config[name]["fixed_by_user"] = False + layer_config[name]["in_blocks"] = True + + # ---- restore: ensure missing in_blocks are set to False and compute flag ---- + has_qlayer_outside_block = False + for cfg in layer_config.values(): + if "in_blocks" not in cfg: + cfg["in_blocks"] = False + # mark layer outside block + if not cfg["in_blocks"] and check_to_quantized(cfg): + has_qlayer_outside_block = True + + # 10. GGUF handling + if not gguf_name: + dispatch_layer_config(layer_config) + return layer_config, has_qlayer_outside_block + + # embed + lm_head defaults for gguf + if lm_head_name not in layer_config and not tie_word_embeddings: + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["lm_head"]] + cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype} + layer_config[lm_head_name] = cfg + has_qlayer_outside_block = True + for emd_name in embedding_layer_names: + if emd_name in layer_config: + continue + cfg = GGUF_INNER_CONFIG[GGUF_CONFIG[gguf_name.lower()]["embedding"]] + cfg = {**cfg, "fixed_by_user": False, "scale_dtype": default_scale_dtype} + layer_config[emd_name] = cfg + + if enable_gguf_official_mixed: + model_type = ModelType.MMPROJ if is_mllm else ModelType.TEXT + layer_config, _ = get_layer_config_by_gguf_format(layer_config, gguf_name.lower(), model, model_type) + + dispatch_layer_config(layer_config) + return layer_config, has_qlayer_outside_block diff --git a/test/test_cpu/test_autoround.py b/test/test_cpu/test_autoround.py index 9511f0cf8..aac524800 100644 --- a/test/test_cpu/test_autoround.py +++ b/test/test_cpu/test_autoround.py @@ -720,6 +720,7 @@ def test_invalid_layer_config(self): iters=1, layer_config=layer_config, ) + ar.quantize() def test_quant_lm_head(self): model_name = "/tf_dataset/auto_round/models/Qwen/Qwen3-8B" diff --git a/test/test_cuda/test_auto_scheme.py b/test/test_cuda/test_auto_scheme.py new file mode 100644 index 000000000..b9fffdee9 --- /dev/null +++ b/test/test_cuda/test_auto_scheme.py @@ -0,0 +1,27 @@ +import copy +import re +import shutil +import sys +import unittest + +sys.path.insert(0, "../..") + +from auto_round import AutoRound, AutoRoundConfig, AutoScheme + + +class TestAutoScheme(unittest.TestCase): + @classmethod + def setUpClass(self): + self.save_dir = "./saved" + self.tasks = "lambada_openai" + + @classmethod + def tearDownClass(self): + shutil.rmtree("./saved", ignore_errors=True) + shutil.rmtree("runs", ignore_errors=True) + + def test_auto_scheme(self): + model_name = "facebook/opt-125m" + scheme = AutoScheme(avg_bits=3, options=("W2A16", "W4A16", "BF16")) + ar = AutoRound(model=model_name, scheme=scheme, iters=1, nsamples=1) + ar.quantize_and_save(self.save_dir)
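For reference, the end-to-end usage this PR enables matches the new test: pick a set of per-layer scheme options plus a target average bit-width and let AutoRound assign schemes. A minimal sketch (model name, iteration and sample counts are illustrative and taken directly from the test):

```python
from auto_round import AutoRound, AutoScheme

# Mix W2A16 / W4A16 / BF16 layers so the model averages about 3 bits per weight.
scheme = AutoScheme(avg_bits=3, options=("W2A16", "W4A16", "BF16"))
ar = AutoRound(model="facebook/opt-125m", scheme=scheme, iters=1, nsamples=1)
ar.quantize_and_save("./saved")
```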