diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 64a210d07..0f6adfd73 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -54,8 +54,8 @@ ) from auto_round.data_type import QUANT_FUNC_WITH_DTYPE from auto_round.data_type.utils import reshape_pad_tensor_by_group_size -from auto_round.export.export_to_autoround import AutoRoundFormat from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG, ModelType +from auto_round.formats import OutputFormat from auto_round.logger import logger from auto_round.schemes import ( SPECIAL_SCHEMES, @@ -69,7 +69,6 @@ from auto_round.utils import ( INNER_SUPPORTED_LAYER_TYPES, SUPPORTED_DTYPES, - SUPPORTED_FORMATS, SUPPORTED_LAYER_TYPES, TORCH_VERSION_AT_LEAST_2_6, CpuInfo, @@ -735,34 +734,18 @@ def _check_configs(self) -> None: def _check_compatibility(self) -> None: """Checks compatibility of the configurations and model.""" - # Check gguf and others has_gguf = False if hasattr(self, "formats"): - has_besides_gguf = False - for format_ in self.formats: - if "gguf" in format_: - has_gguf = True - elif format_ != "fake": - has_besides_gguf = True - if has_gguf and has_besides_gguf: - raise ValueError("GGUF format is not compatible with other formats, please choose only one of them") - if has_gguf: - from transformers.utils.versions import require_version - - require_version( - "sentencepiece", - "GGUF format requires SentencePiece to be installed. " - "Please install it with `pip install sentencepiece`", - ) - if has_gguf and self.iters != 0 and self.bits != 3 and not self.enable_alg_ext: - logger.warning( - "`iters=0` is recommended when exporting to current GGUF format" - " or add `enable_alg_ext` for better accuracy with much more tuning cost." - " Please refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md" - " for the accuracy results." - ) - elif self.bits >= 8 and self.iters != 0: - logger.warning("`iters=0` is recommended for bits>=8") + has_gguf = any([f.is_gguf() for f in self.formats]) + if has_gguf and self.iters != 0 and self.bits != 3 and not self.enable_alg_ext: + logger.warning( + "`iters=0` is recommended when exporting to current GGUF format" + " or add `enable_alg_ext` for better accuracy with much more tuning cost." + " Please refer to https://github.com/intel/auto-round/tree/main/docs/gguf_alg_ext_acc.md" + " for the accuracy results." + ) + elif has_gguf and self.bits >= 8 and self.iters != 0: + logger.warning("`iters=0` is recommended for bits>=8") if ( self.seqlen is not None @@ -788,230 +771,6 @@ def _check_compatibility(self) -> None: if self.group_size == 0 and "fp8" not in self.data_type: logger.warning("`group_size==0` is not supported for data_type other than fp8 ") - def _parse_format_to_list(self, format: str) -> list: - """Parses the format string into a list of formats. - - This method checks the requested format(s) against the model's - quantization settings and adjusts them if necessary. It ensures that - the formats are compatible with the model's data type, bit width, - and activation quantization settings. - - Args: - format (str): The requested format(s) for quantization, separated by commas. - - Returns: - list: A list of validated and updated formats. 
- """ - - # Remove duplicates from formats list - def remove_duplicates(lst): - seen = set() - return [x for x in lst if not (x in seen or seen.add(x))] - - formats = format.replace("q*_", f"q{self.bits}_").replace(" ", "").split(",") - formats = remove_duplicates(formats) # need the keep origin order - - gguf_format_name = get_gguf_scheme(self.scheme) - - if gguf_format_name: - for i in range(len(formats)): - if gguf_format_name.lower().endswith("mixed"): - gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s") - if formats[i] != "fake" and formats[i] != gguf_format_name.lower(): - logger.warning( - f"reset format {formats[i]} to {gguf_format_name.lower()} " - f"since scheme {gguf_format_name} can only be exported to format {gguf_format_name.lower()}" - ) - formats[i] = gguf_format_name.lower() - - gguf_args_check(self, formats, model_type=ModelType.TEXT) - if self.mllm: - gguf_args_check(self, formats, model_type=ModelType.MMPROJ) - - for f in formats: - if f.startswith("gguf"): - self.scheme = f.upper() - break - - for format_ in formats: - if format_ not in SUPPORTED_FORMATS: - logger.error(f"Unsupported format {format_}, please choose from {SUPPORTED_FORMATS}") - exit(-1) - if self.scale_dtype != torch.float32: - only_gguf = True - for format_ in formats: - if not ("gguf" in format_ or "fake" in format_): - only_gguf = False - break - if len(formats) == 1 and "fake" == formats[0]: - only_gguf = False - if only_gguf: - self.scale_dtype = torch.float32 - logger.info("change `scale_dtype` to `torch.float32`") - - # Adjust format settings based on compatibility - for index in range(len(formats)): - format = formats[index] - if format == "auto_round": - if self.sym and "int" in self.data_type: - format = "auto_round:auto_gptq" - elif self.bits == 4 and not self.sym and "int" in self.data_type: - enable_awq = all( - config["bits"] == self.bits or config["bits"] >= 16 for config in self.layer_config.values() - ) - if enable_awq: - format = "auto_round:auto_awq" - elif is_nv_fp(self.data_type) or is_mx_fp(self.data_type): - format = f"auto_round:{self.data_type}" - elif is_static_wfp8afp8(self): # static wfp8afp8 - format = f"auto_round:{AutoRoundFormat.FP8_STATIC.value}" - elif self.data_type.startswith("fp") and self.bits == 8 and self.act_bits >= 16: # woq fp8 - format = f"auto_round:{AutoRoundFormat.FP8.value}" - elif self.act_bits < 16: - raise ValueError( - "AutoRound format does not support exporting " - "for the current quantization configuration, " - "please change to `fake` format for research purpose" - ) - formats[index] = format - elif format == "llm_compressor": - from auto_round.export.export_to_llmcompressor import check_compressed_tensors_supported - - if is_nv_fp(self.data_type) or is_mx_fp(self.data_type): - check_compressed_tensors_supported() - format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}") - formats[index] = format - elif is_static_wfp8afp8(self): - format = f"llm_compressor:{AutoRoundFormat.FP8_STATIC.value}" - formats[index] = format - if self.act_group_size != 0: - logger.warning( - f"scheme FP8_STATIC export to llm_compressor format only support for act_group_size 0," - f" ,but got act_group_size={self.act_group_size}, reset = 0" - ) - self.act_group_size = 0 - if self.group_size > 0: - logger.warning( - f"please note that group_size={self.group_size}" - " may not be supported for llm_compressor format, and cannot be loaded in llm_compressor" - ) - elif not is_wfp8afp8(self): - logger.error( - "Currently, the llm_compressor 
format only supports MXFP/NVFP/FP8. " - "Please change format to fake or auto_round etc." - ) - elif "auto_awq" in format: - from auto_round.compressors.utils import check_awq_gemm_compatibility - - awq_supported, info = check_awq_gemm_compatibility( - self.model, self.bits, self.group_size, self.sym, self.layer_config - ) - if not awq_supported: - logger.warning(f"The AutoAWQ format may not be supported due to {info}") - else: - if (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)) and format != "fake": - logger.warning(f"nv_fp and mx_fp dtypes are not supported for export format: {format}") - - formats = remove_duplicates(formats) - for i in range(len(formats)): - formats[i] = self._check_supported_format(formats[i]) - formats = remove_duplicates(formats) - return formats - - def _check_supported_format(self, format: str) -> bool: - """Checks if the specified format is supported. - - This method validates the requested format against the model's bit width, - group size, symmetry, and activation quantization settings. It raises an - error if the format is incompatible with the current model configuration. - - Args: - format (str): The requested format for quantization. - - Returns: - bool: True if the format is supported, False otherwise. - """ - if format == "fake": - return format - format = format.replace("q*_", f"q{self.bits}_") - - # format check for fp8 - w_fp8 = self.data_type.startswith("fp") and self.bits == 8 - act_fp8 = self.act_data_type.startswith("fp") and self.act_bits == 8 - if (w_fp8 or act_fp8) and re.search("^auto_round|^llm_compressor", format) is None: - error_msg = ( - f"is only supported to export auto_round or llm_compressor format," f" but got {format}, please check." - ) - error_msg = ("act_data_type " + error_msg) if act_fp8 else error_msg - error_msg = ("data_type " + error_msg) if w_fp8 else error_msg - logger.error(error_msg) - sys.exit(-1) - - # Only support to export afp8/nv_fp/mx_fp - if self.act_bits <= 8: - if not is_standard_fp(self.act_data_type) or self.act_dynamic: - if "llm_compressor" in format: - if (is_nv_fp(self.act_data_type) and "static_gs" in self.act_data_type) or ( - is_mx_fp(self.act_data_type) - ): - return format - bits, group_size, sym, act_bits = 8, -1, True, 8 - assert ( - self.bits == bits - and self.group_size == group_size - and self.sym == sym - and self.act_bits == act_bits - and self.act_dynamic - ), ( - f"Currently only support to export llm_compressor format for sym dynamic quantized" - f" W{self.bits}A{self.act_bits} model with group_size={group_size}," - f" but got bits={self.bits}, group_size={self.group_size}, sym={self.sym}," - f" act_bits={self.act_bits}" - ) - elif "auto_round" in format and ( - is_mx_fp(self.act_data_type) or (is_nv_fp(self.act_data_type) and "static_gs" in self.act_data_type) - ): - pass - elif format != "fake": - logger.warning( - "Currently only support to export auto_round format quantized model" - " with fp8, mx_fp and nv_fp4 dtype activation for activation quantization." - " Change format to fake and save." 
- ) - format = "fake" - else: - if format not in [ - "auto_round", - f"auto_round:{AutoRoundFormat.FP8_STATIC.value}", - f"llm_compressor:{AutoRoundFormat.FP8_STATIC.value}", - "auto_round:llm_compressor", - ]: - logger.warning( - f"Currently only support to export auto_round or fake format for static W{self.bits}AFP8 model," - f" change format {format} to auto_round" - ) - if is_static_wfp8afp8(self): - format = f"auto_round:{AutoRoundFormat.FP8_STATIC.value}" - else: - format = f"auto_round:{AutoRoundFormat.FP8.value}" - if ( - self.act_group_size != 0 - and not self.act_dynamic - and format == f"auto_round:{AutoRoundFormat.FP8.value}" - ): - logger.warning( - f"Please note that quantize activation with act_group_size={self.act_group_size}" - " may result in failure to export or import normally." - ) - if re.search(r"q\d_k", format) and not self.data_type.endswith("_dq"): - logger.error( - f"datatype<{self.data_type}> not support to export {format} format." - " Please change export format or `data_type`." - ) - sys.exit(-1) - - return format - def quantize_and_save( self, output_dir: str = "tmp_autoround", format: str = "auto_round", inplace: bool = True, **kwargs ) -> tuple[torch.nn.Module, dict[str, Any]]: @@ -1041,7 +800,7 @@ def quantize_and_save( self.orig_output_dir = output_dir # check and update the format based on the current configuration - format_list = self._parse_format_to_list(format) + format_list = OutputFormat.get_formats(format, self) self.formats = format_list # If multiple formats are specified, enforce inplace=False @@ -1064,24 +823,12 @@ def quantize_and_save( else: model, _ = self.quantize() # Save the quantized model in the specified format_list - folders = [] - for format in format_list: - if "gptq" in format and not self.sym: - logger.warning( - "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," - " particularly for 2-bit quantization and smaller models." - " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " - "the AutoRound format(2/3/4/8 bits)." - ) - save_folder = self._get_save_folder_name(format) - self.save_quantized(save_folder, format=format, inplace=inplace, **kwargs) - - folders.append(save_folder) + model, folders = self.save_quantized(output_dir, format=format, inplace=inplace, return_folders=True, **kwargs) memory_monitor.log_summary() return model, folders - def _get_save_folder_name(self, format_str: str) -> str: + def _get_save_folder_name(self, format: OutputFormat) -> str: """Generates the save folder name based on the provided format string. If there are multiple formats to handle, the function creates a subfolder @@ -1095,7 +842,7 @@ def _get_save_folder_name(self, format_str: str) -> str: str: The path to the folder where results should be saved. 
""" # Replace special characters to make the folder name filesystem-safe - sanitized_format = format_str.replace(":", "-").replace("_", "-") + sanitized_format = format.get_backend_name().replace(":", "-").replace("_", "-") # Use a subfolder only if there are multiple formats if len(self.formats) > 1: @@ -1306,7 +1053,7 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T set_module(self.model, name, m) tuning_device = m.tuning_device if hasattr(m, "tuning_device") else self.device # Step 1: Try quantization on GPU first, fall back to CPU if OOM - if self.immediate_packing and self.iters == 0 and "gguf" in self.formats[0] and not self.disable_opt_rtn: + if self.immediate_packing and self.iters == 0 and self.formats[0].is_gguf() and not self.disable_opt_rtn: m = m.to(tuning_device) m.scale = None m.zp = None @@ -1366,8 +1113,8 @@ def _immediate_pack(self, name: str): return from auto_round.export import PACKING_LAYER_WITH_FORMAT - target_backend = self.formats[0].split(":")[0] if ":" in self.formats[0] else self.formats[0] - has_gguf = any("gguf" in fmt for fmt in self.formats) + target_backend = self.formats[0].output_format + has_gguf = any(fmt.is_gguf() for fmt in self.formats) if has_gguf: from auto_round.export.export_to_gguf.export import pack_gguf_layer @@ -1377,7 +1124,7 @@ def _immediate_pack(self, name: str): pack_gguf_layer( name, self.model, - self.formats[0], + self.formats[0].output_format, output_dir, self.layer_config, self.tokenizer, @@ -1387,7 +1134,9 @@ def _immediate_pack(self, name: str): device=self.device, ) else: - PACKING_LAYER_WITH_FORMAT[target_backend](name, self.model, self.formats[0], device=self.device) + PACKING_LAYER_WITH_FORMAT[target_backend]( + name, self.model, self.formats[0].output_format, device=self.device + ) @torch.inference_mode() def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: @@ -1421,7 +1170,7 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: for module in tqdm(modules, desc="Update weight global scale for fuse module"): update_fused_layer_global_scales(module) - if not (any("gguf" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None): + if not (any(fmt.is_gguf() for fmt in getattr(self, "formats", [])) or self.super_bits is not None): self._quantize_embedding_layer() # leave to gguf itself to handle self.model.to("cpu") @@ -1431,7 +1180,8 @@ def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: enable_imatrix = False if not self.disable_opt_rtn: has_gguf_k = ( - any("gguf" in fmt and "k" in fmt for fmt in getattr(self, "formats", [])) or self.super_bits is not None + any(fmt.is_gguf() and "k" in fmt.output_format for fmt in getattr(self, "formats", [])) + or self.super_bits is not None ) if has_gguf_k: enable_imatrix = True @@ -1670,20 +1420,11 @@ def quantize(self) -> tuple[torch.nn.Module, dict[str, Any]]: else: # Determine if immediate packing is required formats = self.formats - if ( - len(formats) == 1 - and ( - "awq" in formats[0] - or "gptq" in formats[0] - or "auto_round" in formats[0] - or "gguf" in formats[0] - or "llm_compressor" in formats[0] - ) - and self.inplace - ): + if len(formats) == 1 and not formats[0].is_fake() and self.inplace: self.immediate_packing = True - if "gguf" not in formats[0] and self.low_cpu_mem_usage: + if not formats[0].is_gguf() and self.low_cpu_mem_usage: self.immediate_saving = True + if self.immediate_saving and "int" not in self.data_type: logger.warning("immediate_saving is only supported for 
int quantization, set to False") self.immediate_saving = False @@ -1860,7 +1601,7 @@ def _quantize_layers(self, layer_names: list, layer_inputs: dict) -> None: has_gguf = False if hasattr(self, "formats"): - has_gguf = any("gguf" in format_ for format_ in self.formats) + has_gguf = any(format_.is_gguf() for format_ in self.formats) if has_gguf and self.immediate_packing: enable_quanted_input = False @@ -3131,7 +2872,12 @@ def _quantize_blocks( clear_memory(device_list=self.device_list) def save_quantized( - self, output_dir: str = None, format: str = "auto_round", inplace: bool = True, **kwargs + self, + output_dir: str = None, + format: Union[str, list[OutputFormat]] = "auto_round", + inplace: bool = True, + return_folders=False, + **kwargs, ) -> torch.nn.Module: """Save the quantized model to the specified output directory in the specified format. @@ -3144,122 +2890,130 @@ def save_quantized( Returns: object: The compressed model object. """ - format = self._check_supported_format(format) + self.orig_output_dir = output_dir + if isinstance(format, str): + formats = OutputFormat.get_formats(format, self) + if not hasattr(self, "formats"): + self.formats = formats if not self.quantized: logger.warning("please run autoround.quantize first") return - if format == "fake" or format == "qdq": # TODO fix act quantization later - self.model = self.model.to("cpu") - self.model.save_pretrained(output_dir) - if self.tokenizer is not None and hasattr(self.tokenizer, "save_pretrained"): - self.tokenizer.save_pretrained(output_dir) - processor = kwargs.get("processor", None) - if processor is not None: - processor.save_pretrained(output_dir) - try: - copy_python_files_from_model_cache(self.model, output_dir) - except Exception as e: - logger.warning("Skipping source model Python file copy due to error: %s", e) - return - if self.act_bits <= 8 and format == "qdq": - logger.warning( - "Support for exporting activation quantization is limited. " - "Please ensure that your configuration is supported." - ) - # if format == "llm_compressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)): - # format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}") - if format == "llm_compressor" and (is_nv_fp(self.data_type) or is_mx_fp(self.data_type)): - format = format.replace("llm_compressor", f"llm_compressor:{self.data_type}") - if format == "llm_compressor" and is_static_wfp8afp8(self): - format = format.replace("llm_compressor", "llm_compressor:{AutoRoundFormat.FP8_STATIC.value}") - - from auto_round.export import EXPORT_FORMAT - - backend = format - format = format.split(":")[0] - if format not in EXPORT_FORMAT: - logger.error(f"export format only supports {EXPORT_FORMAT.keys()}") - raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {format}") - save_quantized_as_format = EXPORT_FORMAT.get(format) - if "gptq" in format and not self.sym: - logger.warning( - "the asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," - " particularly for 2-bit quantization and smaller models." - " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " - "the AutoRound format(2/3/4/8 bits)." + folders = [] + for format in formats: + if format.is_gptq() and not self.sym: + logger.warning( + "The asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." 
+ " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/3/4/8 bits)." + ) + save_folder = self._get_save_folder_name(format) + if format.is_fake(): # TODO fix act quantization later + self.model = self.model.to("cpu") + self.model.save_pretrained(output_dir) + if self.tokenizer is not None and hasattr(self.tokenizer, "save_pretrained"): + self.tokenizer.save_pretrained(output_dir) + processor = kwargs.get("processor", None) + if processor is not None: + processor.save_pretrained(output_dir) + try: + copy_python_files_from_model_cache(self.model, output_dir) + except Exception as e: + logger.warning("Skipping source model Python file copy due to error: %s", e) + compressed_model = self.model + continue + if self.act_bits <= 8 and format.is_fake(): + logger.warning( + "Support for exporting activation quantization is limited. " + "Please ensure that your configuration is supported." + ) + from auto_round.export import EXPORT_FORMAT + + backend = format.get_backend_name() + output_format = format.output_format + if output_format not in EXPORT_FORMAT: + logger.error(f"export format only supports {EXPORT_FORMAT.keys()}") + raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {output_format}") + save_quantized_as_format = EXPORT_FORMAT.get(output_format) + if format.is_gptq() and not self.sym: + logger.warning( + "the asymmetrical kernel of the GPTQ format may result in a noticeable accuracy drop," + " particularly for 2-bit quantization and smaller models." + " We recommend exporting to either the AutoAWQ format ( only 4 bits) or " + "the AutoRound format(2/3/4/8 bits)." + ) + if format.is_awq() and not self.bits == 4: + raise ValueError("The AWQ format only supports W4 quantization ") + serialization_keys = [ + "bits", + "group_size", + "sym", + "data_type", + "enable_quanted_input", + "enable_minmax_tuning", + "seqlen", + "batch_size", + "scale_dtype", + "lr", + "minmax_lr", + "gradient_accumulate_steps", + "iters", + "amp", + "nsamples", + "low_gpu_mem_usage", + "to_quant_block_names", + "enable_norm_bias_tuning", + "act_bits", + "act_group_size", + "act_sym", + "act_dynamic", + "act_data_type", + "super_bits", + "super_group_size", + "regex_config", + ] + if isinstance(self.dataset, str): + serialization_keys.append("dataset") + serialization_dict = {} + for key in serialization_keys: + serialization_dict[key] = getattr(self, key) + from auto_round.version import __version__ + + serialization_dict["autoround_version"] = __version__ + if "scale_dtype" in serialization_dict.keys(): + serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) + compressed_model = save_quantized_as_format( # TODO refine the code + save_folder, + model=self.model, + layer_config=self.layer_config, + inplace=inplace, + bits=self.bits, + act_bits=self.act_bits, + group_size=self.group_size, + sym=self.sym, + iters=self.iters, + lr=self.lr, + minmax_lr=self.minmax_lr, + enable_minmax_tuning=self.enable_minmax_tuning, + enable_quanted_input=self.enable_quanted_input, + scale_dtype=self.scale_dtype, + tokenizer=self.tokenizer, + supported_types=self.supported_types, + data_type=self.data_type, + act_data_type=self.act_data_type, + serialization_dict=serialization_dict, + backend=backend, + to_quant_block_names=self.to_quant_block_names, + quant_block_list=self.quant_block_list, + device=self.device, + static_kv_dtype=self.static_kv_dtype, + static_attention_dtype=self.static_attention_dtype, + **kwargs, ) - if "awq" in format and 
not self.bits == 4: - raise ValueError("The AWQ format only supports W4 quantization ") - serialization_keys = [ - "bits", - "group_size", - "sym", - "data_type", - "enable_quanted_input", - "enable_minmax_tuning", - "seqlen", - "batch_size", - "scale_dtype", - "lr", - "minmax_lr", - "gradient_accumulate_steps", - "iters", - "amp", - "nsamples", - "low_gpu_mem_usage", - "to_quant_block_names", - "enable_norm_bias_tuning", - "act_bits", - "act_group_size", - "act_sym", - "act_dynamic", - "act_data_type", - "super_bits", - "super_group_size", - "regex_config", - "static_kv_dtype", - "static_attention_dtype", - ] - if isinstance(self.dataset, str): - serialization_keys.append("dataset") - serialization_dict = {} - for key in serialization_keys: - serialization_dict[key] = getattr(self, key) - from auto_round.version import __version__ - - serialization_dict["autoround_version"] = __version__ - if "scale_dtype" in serialization_dict.keys(): - serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) - compressed_model = save_quantized_as_format( # TODO refine the code - output_dir, - model=self.model, - layer_config=self.layer_config, - inplace=inplace, - bits=self.bits, - act_bits=self.act_bits, - group_size=self.group_size, - sym=self.sym, - iters=self.iters, - lr=self.lr, - minmax_lr=self.minmax_lr, - enable_minmax_tuning=self.enable_minmax_tuning, - enable_quanted_input=self.enable_quanted_input, - scale_dtype=self.scale_dtype, - tokenizer=self.tokenizer, - supported_types=self.supported_types, - data_type=self.data_type, - act_data_type=self.act_data_type, - serialization_dict=serialization_dict, - backend=backend, - to_quant_block_names=self.to_quant_block_names, - quant_block_list=self.quant_block_list, - device=self.device, - static_kv_dtype=self.static_kv_dtype, - static_attention_dtype=self.static_attention_dtype, - **kwargs, - ) - return compressed_model + folders.append(save_folder) + + return compressed_model, folders if return_folders else compressed_model def _get_quantized_layer_names_outside_blocks(self) -> list: """Gets the names of quantized layers outside blocks in the model. diff --git a/auto_round/export/export_to_autoround/__init__.py b/auto_round/export/export_to_autoround/__init__.py index 6cdcd5aed..216139f42 100644 --- a/auto_round/export/export_to_autoround/__init__.py +++ b/auto_round/export/export_to_autoround/__init__.py @@ -12,4 +12,4 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from .export import save_quantized_as_autoround, AutoRoundFormat +from .export import save_quantized_as_autoround, AutoRoundExportFormat diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index a67dc7c21..c98303698 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -45,7 +45,7 @@ ) -class AutoRoundFormat(str, Enum): +class AutoRoundExportFormat(str, Enum): # Weight: FP8, per-channel, may be extended to per-tensor in future # Activation: FP8, per-tensor FP8_STATIC = "fp8_static" @@ -165,8 +165,8 @@ def pack_layer(layer_name, model, backend, device=None): return pack_layer(layer_name, model, backend, device) if ( - backend == f"auto_round:{AutoRoundFormat.FP8.value}" - or backend == f"auto_round:{AutoRoundFormat.FP8_STATIC.value}" + backend == f"auto_round:{AutoRoundExportFormat.FP8.value}" + or backend == f"auto_round:{AutoRoundExportFormat.FP8_STATIC.value}" ): from auto_round.export.export_to_autoround.export_to_fp8 import pack_layer @@ -298,7 +298,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex if ( (kwargs.get("sym") is None or kwargs.get("sym")) and ("gptq" not in backend and "awq" not in backend) - and (AutoRoundFormat.FP8_STATIC.value not in backend) + and (AutoRoundExportFormat.FP8_STATIC.value not in backend) ): backend = backend.replace("auto_round", "auto_round:auto_gptq") diff --git a/auto_round/formats.py b/auto_round/formats.py new file mode 100644 index 000000000..db3416779 --- /dev/null +++ b/auto_round/formats.py @@ -0,0 +1,419 @@ +# Copyright (c) 2025 Intel Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +import re +import sys +from typing import TYPE_CHECKING, Callable, Union + +import torch + +from auto_round.compressors.utils import ( + gguf_args_check, + is_mx_fp, + is_nv_fp, + is_standard_fp, + is_static_wfp8afp8, + is_wfp8afp8, +) +from auto_round.export.export_to_autoround import AutoRoundExportFormat +from auto_round.export.export_to_gguf.config import ModelType +from auto_round.schemes import ( + PRESET_SCHEMES, + QuantizationScheme, + get_gguf_scheme, +) +from auto_round.utils import SUPPORTED_FORMATS, logger + +if TYPE_CHECKING: + from auto_round.compressors.base import BaseCompressor + + +def _check_gguf_compatibility(ar: BaseCompressor, formats: list): + if len([f for f in formats if f.lower() != "fake"]) > 1: + raise ValueError( + f"GGUF format is not compatible with other formats, but got {formats}, please choose only one of them" + ) + gguf_format_name = get_gguf_scheme(ar.scheme) + if gguf_format_name: + if gguf_format_name.lower().endswith("mixed"): + gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s") + if any([f.lower() not in ["fake", gguf_format_name.lower()] for f in formats]): + tmp_format_name = gguf_format_name.lower() if "fake" not in formats else f"{gguf_format_name.lower()},fake" + logger.warning( + f"reset format {','.join(formats)} to {tmp_format_name} " + f"since scheme {gguf_format_name} can only be exported to format {gguf_format_name.lower()} or fake" + ) + formats = tmp_format_name.split(",") + return formats + + +def _check_act_compatibility(ar: BaseCompressor, formats: list[str]) -> list[str]: + for i in range(len(formats)): + if formats[i] == "fake": + continue + + # format check for fp8 + w_fp8 = ar.data_type.startswith("fp") and ar.bits == 8 + act_fp8 = ar.act_data_type.startswith("fp") and ar.act_bits == 8 + if (w_fp8 or act_fp8) and re.search("^auto_round|^llm_compressor", formats[i]) is None: + error_msg = ( + f"is only supported to export auto_round or llm_compressor format," + f" but got {formats[i]}, please check." + ) + error_msg = ("act_data_type " + error_msg) if act_fp8 else error_msg + error_msg = ("data_type " + error_msg) if w_fp8 else error_msg + logger.error(error_msg) + sys.exit(-1) + + # Only support to export afp8/nv_fp/mx_fp + if ar.act_bits <= 8: + if not is_standard_fp(ar.act_data_type) or ar.act_dynamic: + if "llm_compressor" in formats[i]: + if (is_nv_fp(ar.act_data_type) and "static_gs" in ar.act_data_type) or (is_mx_fp(ar.act_data_type)): + continue + bits, group_size, sym, act_bits = 8, -1, True, 8 + assert ( + ar.bits == bits + and ar.group_size == group_size + and ar.sym == sym + and ar.act_bits == act_bits + and ar.act_dynamic + ), ( + f"Currently only support to export llm_compressor format for sym dynamic quantized" + f" W{ar.bits}A{ar.act_bits} model with group_size={group_size}," + f" but got bits={ar.bits}, group_size={ar.group_size}, sym={ar.sym}," + f" act_bits={ar.act_bits}" + ) + elif "auto_round" in formats[i] and ( + is_mx_fp(ar.act_data_type) or (is_nv_fp(ar.act_data_type) and "static_gs" in ar.act_data_type) + ): + pass + elif formats[i] != "fake": + logger.warning( + "Currently only support to export auto_round format quantized model" + " with fp8, mx_fp and nv_fp4 dtype activation for activation quantization." + f" Change format <{formats[i]}> to fake and save." 
+                    )
+                    formats[i] = "fake"
+        else:
+            if (
+                ar.act_group_size != 0
+                and not ar.act_dynamic
+                and formats[i] == f"auto_round:{AutoRoundExportFormat.FP8.value}"
+            ):
+                logger.warning(
+                    f"Please note that quantize activation with act_group_size={ar.act_group_size}"
+                    " may result in failure to export or import normally."
+                )
+        if re.search(r"q\d_k", formats[i]) and not ar.data_type.endswith("_dq"):
+            logger.error(
+                f"datatype<{ar.data_type}> not support to export {formats[i]} format."
+                " Please change export format or `data_type`."
+            )
+            sys.exit(-1)
+
+    return formats
+
+
+class OutputFormat:
+    support_schemes: list = []
+    _format_list: dict[str, OutputFormat] = {}
+    format_name = "base"
+
+    def __init__(self, format: str, ar: BaseCompressor):
+        if not self.is_support_scheme(ar.scheme):
+            logger.error(
+                f"Currently, the {self.format_name} format only supports {self.support_schemes}, "
+                f"but got scheme {ar.scheme}, please change to fake or auto_round etc."
+            )
+            exit(-1)
+        self.output_format = format
+        self.backend = None
+
+    @classmethod
+    def register(cls, *names: str) -> Callable[[OutputFormat], OutputFormat]:
+        assert names
+
+        def func(output_format: OutputFormat) -> OutputFormat:
+            for name in names:
+                cls._format_list[name] = output_format
+            return output_format
+
+        return func
+
+    @classmethod
+    def get_formats(
+        cls,
+        format: str,
+        ar: BaseCompressor,
+    ) -> list[OutputFormat]:
+        """Get the list of OutputFormat instances resolved from the provided format string."""
+
+        def remove_duplicates(lst):
+            seen = set()
+            return [x for x in lst if not (x in seen or seen.add(x))]
+
+        formats = format.replace("q*_", f"q{ar.bits}_").replace(" ", "").split(",")
+        formats = remove_duplicates(formats)  # need to keep the original order
+
+        # check gguf scheme compatibility
+        formats = _check_gguf_compatibility(ar, formats)
+
+        # check activation quantization compatibility
+        formats = _check_act_compatibility(ar, formats)
+
+        formats = remove_duplicates(formats)
+
+        for i in range(len(formats)):
+            if formats[i].startswith("gguf:"):
+                formats[i] = GGUFFormat(formats[i], ar)
+            elif formats[i] not in cls._format_list:
+                raise KeyError(f"Unsupported format {formats[i]}, please choose from {SUPPORTED_FORMATS}")
+            else:
+                formats[i] = cls._format_list[formats[i]](formats[i], ar)
+
+        if len(formats) == 1 and formats[0].is_gguf() and ar.scale_dtype != torch.float32:
+            ar.scale_dtype = torch.float32
+            logger.info("change `scale_dtype` to `torch.float32` for gguf format")
+
+        return formats
+
+    @classmethod
+    def get_support_matrix(cls: OutputFormat) -> str:
+        output_str = ""
+        for k, v in cls._format_list.items():
+            support_scheme = ", ".join(v.support_schemes).rstrip(",")
+            output_str += f"\x1b[31;1m{k}\x1b[0m support scheme:\n\t{support_scheme}\n"
+        return output_str
+
+    def get_backend_name(self) -> str:
+        if self.backend is None:
+            return self.output_format
+        # for format like auto_round:fp8, auto_round:fp8_static
+        if self.backend.output_format.startswith("auto_round"):
+            return self.backend.output_format if self.backend else self.output_format
+
+        return f"{self.output_format}:{self.backend.output_format}"
+
+    @classmethod
+    def is_support_scheme(cls: OutputFormat, scheme: Union[str, QuantizationScheme]) -> bool:
+        if isinstance(scheme, str) and scheme in cls.support_schemes:
+            return True
+        if isinstance(scheme, QuantizationScheme):
+            return True
+        return False
+
+    def is_gguf(self) -> bool:
+        return "gguf" in self.output_format
+
+    def is_fake(self) -> bool:
+        return self.output_format == "fake"
+
+    def is_gptq(self) -> bool:
+        return "gptq" in self.output_format or (self.backend is not None and self.backend.is_gptq())
+
+    def is_awq(self) -> bool:
+        return "awq" in self.output_format or (self.backend is not None and self.backend.is_awq())
+
+    def is_llm_compressor(self) -> bool:
+        return "llm_compressor" in self.output_format or (self.backend is not None and self.backend.is_llm_compressor())
+
+
+@OutputFormat.register("fake")
+class FakeFormat(OutputFormat):
+    support_schemes = [
+        "W4A16",
+        "W2A16",
+        "W3A16",
+        "W8A16",
+        "MXFP4",
+        "MXFP8",
+        "NVFP4",
+        "FPW8A16",
+        "W2A16G64",
+        "W2A16G32",
+        "FP8_STATIC",
+        "BF16",
+        "GGUF:Q4_0",
+        "GGUF:Q4_1",
+        "GGUF:Q5_0",
+        "GGUF:Q5_1",
+        "GGUF:Q2_K_S",
+        "GGUF:Q3_K_S",
+        "GGUF:Q3_K_M",
+        "GGUF:Q3_K_L",
+        "GGUF:Q4_K_S",
+        "GGUF:Q4_K_M",
+        "GGUF:Q5_K_S",
+        "GGUF:Q5_K_M",
+        "GGUF:Q6_K",
+        "GGUF:Q8_0",
+    ]
+    format_name = "fake"
+
+
+@OutputFormat.register("llm_compressor")
+class LLMCompressorFormat(OutputFormat):
+    support_schemes = ["MXFP4", "MXFP8", "NVFP4", "FPW8A16", "FP8_STATIC"]
+    format_name = "llm_compressor"
+
+    def __init__(self, format, ar):
+        if not self.is_support_scheme(ar.scheme):
+            logger.error(
+                f"Currently, the llm_compressor format only supports {self.support_schemes}, "
+                f"but got scheme {ar.scheme}, please change to fake or auto_round etc."
+            )
+            exit(-1)
+        if is_nv_fp(ar.data_type) or is_mx_fp(ar.data_type):
+            from auto_round.export.export_to_llmcompressor import check_compressed_tensors_supported
+
+            check_compressed_tensors_supported()
+            format = format.replace("llm_compressor", f"llm_compressor:{ar.data_type}")
+        elif is_static_wfp8afp8(ar):
+            format = f"llm_compressor:{AutoRoundExportFormat.FP8_STATIC.value}"
+            if ar.act_group_size != 0:
+                logger.warning(
+                    f"scheme FP8_STATIC export to llm_compressor format only supports act_group_size=0,"
+                    f" but got act_group_size={ar.act_group_size}, resetting it to 0"
+                )
+                ar.act_group_size = 0
+            if ar.group_size > 0:
+                logger.warning(
+                    f"please note that group_size={ar.group_size}"
+                    " may not be supported for llm_compressor format, and cannot be loaded in llm_compressor"
+                )
+        elif not is_wfp8afp8(ar):
+            logger.error(
+                "Currently, the llm_compressor format only supports MXFP/NVFP/FP8. "
+                "Please change format to fake or auto_round etc."
+ ) + self.output_format = format + self.backend = None + + +@OutputFormat.register("auto_gptq") +class AutoGPTQFormat(OutputFormat): + support_schemes = ["W4A16", "W2A16", "W3A16", "W8A16", "BF16", "W2A16G64", "W2A16G32"] + format_name = "auto_gptq" + + +@OutputFormat.register("auto_awq") +class AutoAWQFormat(OutputFormat): + support_schemes = ["W4A16", "W2A16", "W3A16", "W8A16", "BF16", "W2A16G64", "W2A16G32"] + format_name = "auto_awq" + + def __init__(self, format, ar): + from auto_round.compressors.utils import check_awq_gemm_compatibility + + awq_supported, info = check_awq_gemm_compatibility(ar.model, ar.bits, ar.group_size, ar.sym, ar.layer_config) + if not awq_supported: + logger.warning(f"The AutoAWQ format may not be supported due to {info}") + super().__init__(format, ar) + + +@OutputFormat.register("itrex") +@OutputFormat.register("itrex_xpu") +class ITREXFormat(OutputFormat): + support_schemes = ["W4A16", "W2A16", "W3A16", "W8A16", "BF16", "W2A16G64", "W2A16G32"] + format_name = "itrex" + + +@OutputFormat.register("gguf") +class GGUFFormat(OutputFormat): + support_schemes = [ + "GGUF:Q4_0", + "GGUF:Q4_1", + "GGUF:Q5_0", + "GGUF:Q5_1", + "GGUF:Q2_K_S", + "GGUF:Q3_K_S", + "GGUF:Q3_K_M", + "GGUF:Q3_K_L", + "GGUF:Q4_K_S", + "GGUF:Q4_K_M", + "GGUF:Q5_K_S", + "GGUF:Q5_K_M", + "GGUF:Q6_K", + "GGUF:Q8_0", + ] + format_name = "gguf" + + def __init__(self, format: str, ar: BaseCompressor): + gguf_args_check(ar, format, model_type=ModelType.TEXT) + if ar.mllm: + gguf_args_check(ar, format, model_type=ModelType.MMPROJ) + ar.scheme = format.upper() + + self.output_format = format + self.backend_cls = GGUFFormat + self.backend = None + + +@OutputFormat.register("auto_round") +@OutputFormat.register("auto_round:auto_awq") +@OutputFormat.register("auto_round:llm_compressor") +@OutputFormat.register("auto_round:gptqmodel", "auto_round:auto_gptq") +class AutoRoundFormat(OutputFormat): + support_schemes = [ + "W4A16", + "W2A16", + "W3A16", + "W8A16", + "MXFP4", + "MXFP8", + "NVFP4", + "FPW8A16", + "W2A16G64", + "W2A16G32", + "FP8_STATIC", + "BF16", + ] + format_name = "auto_round" + + def __init__(self, format: str, ar: BaseCompressor): + self.output_format = "auto_round" + self.backend = None + + if format == "auto_round": + if ar.sym and "int" in ar.data_type: + self.backend = AutoGPTQFormat("auto_gptq", ar) + elif ar.bits == 4 and not ar.sym and "int" in ar.data_type: + enable_awq = all( + config["bits"] == ar.bits or config["bits"] >= 16 for config in ar.layer_config.values() + ) + if enable_awq: + self.backend = AutoAWQFormat("auto_awq", ar) + elif is_nv_fp(ar.data_type) or is_mx_fp(ar.data_type): + self.backend = AutoRoundFormat(ar.data_type, ar) + elif is_static_wfp8afp8(ar): # static wfp8afp8 + self.backend = AutoRoundFormat(AutoRoundExportFormat.FP8_STATIC.value, ar) + elif ar.data_type.startswith("fp") and ar.bits == 8 and ar.act_bits >= 16: # woq fp8 + self.backend = AutoRoundFormat(AutoRoundExportFormat.FP8.value, ar) + elif ar.act_bits < 16: + raise ValueError( + "AutoRound format does not support exporting " + "for the current quantization configuration, " + "please change to `fake` format for research purpose" + ) + elif not format.startswith("auto_round"): + self.output_format = f"auto_round:{format}" + self.backend = None + else: + backend = format.split(":")[1] if ":" in format else None + self.backend = self._format_list.get(backend)(format, ar) if backend else None + + if self.backend is not None: + self.support_schemes = self.backend.support_schemes diff --git 
a/auto_round/inference/convert_model.py b/auto_round/inference/convert_model.py index 7a55eda03..c1c7b650a 100644 --- a/auto_round/inference/convert_model.py +++ b/auto_round/inference/convert_model.py @@ -20,7 +20,7 @@ from tqdm import tqdm from transformers.pytorch_utils import Conv1D -from auto_round.export.export_to_autoround import AutoRoundFormat +from auto_round.export.export_to_autoround import AutoRoundExportFormat from auto_round.inference.backend import ( BackendInfos, dynamic_import_inference_linear, @@ -414,10 +414,10 @@ def _create_quant_layer(layer, layer_backend, config, in_features, out_features) bias=bias, ) elif ( - AutoRoundFormat.FP8_STATIC.value in layer_backend - or AutoRoundFormat.MXFP8.value in layer_backend - or AutoRoundFormat.MXFP4.value in layer_backend - or AutoRoundFormat.NVFP4.value in layer_backend + AutoRoundExportFormat.FP8_STATIC.value in layer_backend + or AutoRoundExportFormat.MXFP8.value in layer_backend + or AutoRoundExportFormat.MXFP4.value in layer_backend + or AutoRoundExportFormat.NVFP4.value in layer_backend ): return QuantLinear.from_original(config, layer) diff --git a/auto_round/special_model_handler.py b/auto_round/special_model_handler.py index fe3865b85..ef688634f 100644 --- a/auto_round/special_model_handler.py +++ b/auto_round/special_model_handler.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import auto_round.modelling as auto_round_modelling +from auto_round.formats import OutputFormat from auto_round.utils import LazyImport, logger, unsupported_meta_device mllms_with_limited_bs = ("llava", "qwen2_vl", "phi3_v", "mllama") # Limitations on batch_size @@ -67,8 +68,8 @@ def _handle_special_model(model): return model -def _handle_moe_model(model, formats=None): - if formats is not None and any(["gguf" in format_ for format_ in formats]): +def _handle_moe_model(model, formats: list[OutputFormat] = None): + if formats is not None and any([format_.is_gguf() for format_ in formats]): return model if hasattr(model.config, "model_type") and model.config.model_type in CONVERT_EXPERT_TO_LINEAR_MODELS: from tqdm import tqdm diff --git a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py index 0ff5241ab..75e298acd 100644 --- a/auto_round_extension/vllm_ext/moe_impl_mxfp4.py +++ b/auto_round_extension/vllm_ext/moe_impl_mxfp4.py @@ -283,7 +283,6 @@ def revert_interleaved_bias(bias): return revert_bias - # breakpoint() if self.has_bias: if envs.VLLM_AR_POST_PROCESS_GPTOSS: w13_bias_swapped = revert_interleaved_bias(layer.w13_bias) diff --git a/test/test_cpu/test_auto_scheme.py b/test/test_cpu/test_auto_scheme.py index a0eb21001..af997675d 100644 --- a/test/test_cpu/test_auto_scheme.py +++ b/test/test_cpu/test_auto_scheme.py @@ -23,3 +23,7 @@ def test_auto_scheme_export(self): ar = AutoRound(model=model_name, scheme=scheme) ar.quantize_and_save(self.save_dir) shutil.rmtree(self.save_dir, ignore_errors=True) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/test_cpu/test_mx_quant_linear.py b/test/test_cpu/test_mx_quant_linear.py index e8e18c3bb..a9a7095f7 100644 --- a/test/test_cpu/test_mx_quant_linear.py +++ b/test/test_cpu/test_mx_quant_linear.py @@ -3,14 +3,14 @@ from auto_round.data_type.utils import get_quant_func from auto_round.experimental import qmodules as ar_qmodules -from auto_round.export.export_to_autoround import AutoRoundFormat +from auto_round.export.export_to_autoround import 
AutoRoundExportFormat from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear as _MXFPLinear from auto_round.schemes import PRESET_SCHEMES -mx_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value] +mx_schemes = [AutoRoundExportFormat.MXFP8.value, AutoRoundExportFormat.MXFP4.value] QMODULE_MAPPING = { - AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, - AutoRoundFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, + AutoRoundExportFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, + AutoRoundExportFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, } diff --git a/test/test_cpu/test_mxfp_save_load.py b/test/test_cpu/test_mxfp_save_load.py index aca5c7592..5e6e89093 100644 --- a/test/test_cpu/test_mxfp_save_load.py +++ b/test/test_cpu/test_mxfp_save_load.py @@ -9,22 +9,22 @@ from auto_round import AutoRound from auto_round import schemes as ar_schemes from auto_round.experimental import qmodules as ar_qmodules -from auto_round.export.export_to_autoround import AutoRoundFormat +from auto_round.export.export_to_autoround import AutoRoundExportFormat from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp from auto_round.inference.backend import MX_TENSOR_DATA_TYPES from auto_round.testing_utils import has_module testing_scheme_name_lst = [ - AutoRoundFormat.MXFP8.value, - AutoRoundFormat.MXFP4.value, + AutoRoundExportFormat.MXFP8.value, + AutoRoundExportFormat.MXFP4.value, ] QMODULE_MAPPING = { - AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, - AutoRoundFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, + AutoRoundExportFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, + AutoRoundExportFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, } SCHEMES_MAPPING = { - AutoRoundFormat.MXFP8.value: ar_schemes.MXFP8, - AutoRoundFormat.MXFP4.value: ar_schemes.MXFP4, + AutoRoundExportFormat.MXFP8.value: ar_schemes.MXFP8, + AutoRoundExportFormat.MXFP4.value: ar_schemes.MXFP4, } diff --git a/test/test_cpu/test_nvfp4_quant_linear.py b/test/test_cpu/test_nvfp4_quant_linear.py index 0a42f009a..33c32f466 100644 --- a/test/test_cpu/test_nvfp4_quant_linear.py +++ b/test/test_cpu/test_nvfp4_quant_linear.py @@ -4,12 +4,12 @@ from auto_round.data_type.nvfp import calculate_gparam from auto_round.data_type.utils import get_quant_func from auto_round.experimental import qmodules as ar_qmodules -from auto_round.export.export_to_autoround import AutoRoundFormat +from auto_round.export.export_to_autoround import AutoRoundExportFormat from auto_round.export.export_to_autoround.qlinear_fp import QuantLinear as _FPLinear from auto_round.schemes import PRESET_SCHEMES QMODULE_MAPPING = { - AutoRoundFormat.NVFP4.value: ar_qmodules.NVFP4QuantLinear, + AutoRoundExportFormat.NVFP4.value: ar_qmodules.NVFP4QuantLinear, } @@ -26,7 +26,7 @@ def fixed_seed(): # (Optional) cleanup or reset after test -@pytest.mark.parametrize("scheme", [AutoRoundFormat.NVFP4.value]) +@pytest.mark.parametrize("scheme", [AutoRoundExportFormat.NVFP4.value]) @torch.inference_mode() def test_nvfp4_quantlinear_from_original_and_forward(scheme): """ diff --git a/test/test_cuda/test_mxfp_and_nvfp_quant.py b/test/test_cuda/test_mxfp_and_nvfp_quant.py index 0dc43b093..829954e9d 100644 --- a/test/test_cuda/test_mxfp_and_nvfp_quant.py +++ b/test/test_cuda/test_mxfp_and_nvfp_quant.py @@ -8,15 +8,19 @@ from auto_round import AutoRound from auto_round import schemes as ar_schemes from auto_round.experimental import qmodules as ar_qmodules -from auto_round.export.export_to_autoround import 
AutoRoundFormat +from auto_round.export.export_to_autoround import AutoRoundExportFormat from auto_round.export.export_to_autoround import qlinear_fp as ar_qlinear_fp from auto_round.testing_utils import has_module -testing_schemes = [AutoRoundFormat.MXFP8.value, AutoRoundFormat.MXFP4.value, AutoRoundFormat.NVFP4.value] +testing_schemes = [ + AutoRoundExportFormat.MXFP8.value, + AutoRoundExportFormat.MXFP4.value, + AutoRoundExportFormat.NVFP4.value, +] QMODULE_MAPPING = { - AutoRoundFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, - AutoRoundFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, - AutoRoundFormat.NVFP4.value: ar_qmodules.NVFP4QuantLinear, + AutoRoundExportFormat.MXFP8.value: ar_qmodules.MXFP8QuantLinear, + AutoRoundExportFormat.MXFP4.value: ar_qmodules.MXFP4QuantLinear, + AutoRoundExportFormat.NVFP4.value: ar_qmodules.NVFP4QuantLinear, }
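
Usage sketch (illustrative, not part of the patch): with this change, quantize_and_save() resolves the format string through OutputFormat.get_formats() and then delegates saving to save_quantized(..., return_folders=True), so several formats can be requested in one call and each export lands in its own sub-folder. The model id and scheme below are placeholders.

    from auto_round import AutoRound

    # "facebook/opt-125m" and "W4A16" are illustrative; any supported model/scheme follows the same path.
    ar = AutoRound(model="facebook/opt-125m", scheme="W4A16")

    # Two formats in one call: "auto_round" resolves to AutoRoundFormat (backed by AutoGPTQFormat for
    # this sym int4 scheme) and "fake" to FakeFormat; `folders` lists the save locations in request order.
    model, folders = ar.quantize_and_save("tmp_autoround", format="auto_round,fake")
    print(folders)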
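
Extension sketch (illustrative): exporters now self-register with the OutputFormat registry in auto_round/formats.py instead of being special-cased in _parse_format_to_list. The class below is hypothetical; a real exporter additionally needs a matching entry in auto_round.export.EXPORT_FORMAT so save_quantized() can locate a packing/saving routine.

    from auto_round.formats import OutputFormat

    @OutputFormat.register("my_backend")  # name accepted in the `format=` string
    class MyBackendFormat(OutputFormat):  # hypothetical, for illustration only
        support_schemes = ["W4A16", "W8A16"]
        format_name = "my_backend"

    # The registry and the schemes each exporter claims can be inspected without a compressor:
    print(OutputFormat.get_support_matrix())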