 from auto_round.export.export_to_autoround import AutoRoundFormat
 from auto_round.export.export_to_gguf.config import GGUF_INNER_CONFIG, ModelType
 from auto_round.logger import logger
-from auto_round.schemes import QuantizationScheme, get_gguf_scheme, preset_name_to_scheme
+from auto_round.schemes import (
+    SPECIAL_SCHEMES,
+    QuantizationScheme,
+    _handle_special_schemes,
+    get_gguf_scheme,
+    preset_name_to_scheme,
+)
 from auto_round.sign_sgd import SignSGD
 from auto_round.special_model_handler import _handle_moe_model
 from auto_round.utils import (
@@ -214,6 +220,33 @@ def __init__(
                 ... }
         """
 
+        # Model related
+        model_dtype = kwargs.pop("model_dtype", None)
+        self.mllm = kwargs.pop("mllm") if "mllm" in kwargs else False
+        self.diffusion = kwargs.pop("diffusion") if "diffusion" in kwargs else False
+        self.quantized = False
+        if isinstance(model, str):
+            model, tokenizer = llm_load_model(
+                model,
+                platform=platform,
+                device="cpu",  # always load cpu first
+                model_dtype=model_dtype,
+            )
+        elif tokenizer is None and not self.diffusion and iters > 0:
+            raise ValueError("A tokenizer must be set for non-str model input")
+        if unsupported_meta_device(model):
+            raise RuntimeError(
+                "AutoRound does not support parameters on meta device. "
+                "Please use more GPUs by setting `--device 0,1,2,3` or just place the model on CPU."
+            )
+        check_and_mark_fp8_model(model)
+        self.model = model.eval()
+        self.tokenizer = tokenizer
+        self.shared_cache_keys = get_shared_keys(self.model)
+
+        self.layer_config = layer_config
+
+        # should be set after loading the model and setting layer_config, since some special schemes need them.
         self.scheme, self.is_auto_scheme = self._parse_and_set_scheme(scheme, kwargs)
 
         gguf_scheme_name = get_gguf_scheme(self.scheme)
@@ -244,11 +277,8 @@ def __init__(
             platform = "model_scope"
         self.platform = platform
         self.quant_lm_head = kwargs.pop("quant_lm_head", False)
-        self.mllm = kwargs.pop("mllm") if "mllm" in kwargs else False
-        self.diffusion = kwargs.pop("diffusion") if "diffusion" in kwargs else False
 
         self.fp_layers = kwargs.pop("fp_layers", "")
-        self.layer_config = layer_config
         self.supported_types = SUPPORTED_LAYER_TYPES
         self.inner_supported_types = INNER_SUPPORTED_LAYER_TYPES
         self.scale_dtype = convert_dtype_str2torch(scale_dtype)
@@ -270,27 +300,6 @@ def __init__(
         else:
             torch.use_deterministic_algorithms(True, warn_only=True)
 
-        # Model related
-        self.quantized = False
-        if isinstance(model, str):
-            model, tokenizer = llm_load_model(
-                model,
-                platform=platform,
-                device="cpu",  # always load cpu first
-                model_dtype=model_dtype,
-            )
-        elif tokenizer is None and not self.diffusion and iters > 0:
-            raise ValueError("A tokenizer must be set for non-str model input")
-        if unsupported_meta_device(model):
-            raise RuntimeError(
-                "AutoRound does not support parameters on meta device. "
-                "Please use more GPUs by setting `--device 0,1,2,3` or just place the model on CPU."
-            )
-        check_and_mark_fp8_model(model)
-        self.model = model.eval()
-        self.tokenizer = tokenizer
-        self.shared_cache_keys = get_shared_keys(self.model)
-
         self.to_quant_block_names = to_quant_block_names
         if not hasattr(self, "quant_block_list"):
             all_blocks = get_block_names(model)
@@ -524,6 +533,8 @@ def _parse_and_set(scheme, kwargs):
             scheme = scheme.strip("'\" ")
             res = scheme
             scheme = scheme.upper()
+            if scheme in SPECIAL_SCHEMES:
+                self.layer_config = _handle_special_schemes(scheme, self.layer_config, self.model)
             scheme = asdict(preset_name_to_scheme(scheme))
             scheme_keys = [f.name for f in fields(QuantizationScheme)]
            for key in scheme_keys:
@@ -776,6 +787,8 @@ def remove_duplicates(lst):
 
         if gguf_format_name:
             for i in range(len(formats)):
+                if gguf_format_name.lower().endswith("mixed"):
+                    gguf_format_name = gguf_format_name.lower().replace("_mixed", "_s")
                 if formats[i] != "fake" and formats[i] != gguf_format_name.lower():
                     logger.warning(
                         f"reset format {formats[i]} to {gguf_format_name.lower()} "