
 from llmc.compression.quantization.module_utils import (
     _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
-    _TRANSFORMERS_LN_TYPES_, AutoawqQuantLinearInt4, LlmcFp8Linear,
-    VllmQuantLinearFp8, VllmQuantLinearInt8)
+    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear, VllmQuantLinearFp8,
+    VllmQuantLinearInt8)


 class BaseModel(metaclass=ABCMeta):
@@ -27,7 +27,7 @@ def __init__(self, config, device_map=None, use_cache=False):
         self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
         self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False)  # noqa
         torch_dtype = self.config.model.torch_dtype
-        self.torch_dtype = torch_dtype if torch_dtype in ['auto', 'int4'] else eval(torch_dtype)
+        self.torch_dtype = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
         self.block_wise_quant = self.config.model.get('block_wise_quant', False)
         if self.block_wise_quant:
             assert self.torch_dtype == torch.float8_e4m3fn
@@ -202,7 +202,7 @@ def build_model(self):
         if hasattr(self.model_config, 'use_cache'):
             self.model_config.use_cache = False
         logger.info(f'self.model_config : {self.model_config}')
-        if self.torch_dtype in [torch.float8_e4m3fn, torch.int8, 'int4']:
+        if self.torch_dtype in [torch.float8_e4m3fn, torch.int8]:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(config=self.model_config,
                                                               torch_dtype=torch.float16,
@@ -220,9 +220,6 @@ def build_model(self):
             elif self.torch_dtype == torch.int8:
                 params_dict = {}
                 quant_linear_cls = VllmQuantLinearInt8
-            elif self.torch_dtype == 'int4':
-                params_dict = {}
-                quant_linear_cls = AutoawqQuantLinearInt4

             for block_idx, block in enumerate(self.blocks):
                 self.replace_module_block(quant_linear_cls,
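
For reference, the dtype-resolution behavior that remains after this change can be illustrated with the minimal sketch below. It is only an illustration, not code from the repository: resolve_torch_dtype is a hypothetical standalone helper, and the example dtype strings are assumptions about typical config values.

import torch

def resolve_torch_dtype(torch_dtype_str):
    # Mirrors the retained line in __init__: 'auto' passes through unchanged,
    # any other config string (e.g. 'torch.float16', 'torch.float8_e4m3fn')
    # is eval'd into an actual torch dtype. The 'int4' string is no longer
    # accepted after this change.
    return torch_dtype_str if torch_dtype_str in ['auto'] else eval(torch_dtype_str)

assert resolve_torch_dtype('auto') == 'auto'
assert resolve_torch_dtype('torch.float16') is torch.float16
assert resolve_torch_dtype('torch.float8_e4m3fn') is torch.float8_e4m3fn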