Commit 2d96f20

Author: gushiqiao

Support load from quantized weights

Parent: 9de6db6

File tree

1 file changed: llmc/models/base_model.py (4 additions, 7 deletions)
@@ -15,8 +15,8 @@

 from llmc.compression.quantization.module_utils import (
     _LLMC_LINEAR_TYPES_, _LLMC_LN_TYPES_, _TRANSFORMERS_LINEAR_TYPES_,
-    _TRANSFORMERS_LN_TYPES_, AutoawqQuantLinearInt4, LlmcFp8Linear,
-    VllmQuantLinearFp8, VllmQuantLinearInt8)
+    _TRANSFORMERS_LN_TYPES_, LlmcFp8Linear, VllmQuantLinearFp8,
+    VllmQuantLinearInt8)


 class BaseModel(metaclass=ABCMeta):
@@ -27,7 +27,7 @@ def __init__(self, config, device_map=None, use_cache=False):
         self.tokenizer_mode = self.config.model.get('tokenizer_mode', 'fast')
         self.use_cpu_to_save_cuda_mem_for_catcher = self.config.model.get('use_cpu_to_save_cuda_mem_for_catcher', False)  # noqa
         torch_dtype = self.config.model.torch_dtype
-        self.torch_dtype = torch_dtype if torch_dtype in ['auto', 'int4'] else eval(torch_dtype)
+        self.torch_dtype = torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)
         self.block_wise_quant = self.config.model.get('block_wise_quant', False)
         if self.block_wise_quant:
             assert self.torch_dtype == torch.float8_e4m3fn
@@ -202,7 +202,7 @@ def build_model(self):
         if hasattr(self.model_config, 'use_cache'):
             self.model_config.use_cache = False
         logger.info(f'self.model_config : {self.model_config}')
-        if self.torch_dtype in [torch.float8_e4m3fn, torch.int8, 'int4']:
+        if self.torch_dtype in [torch.float8_e4m3fn, torch.int8]:
             with init_empty_weights():
                 self.model = AutoModelForCausalLM.from_config(config=self.model_config,
                                                               torch_dtype=torch.float16,
@@ -220,9 +220,6 @@
             elif self.torch_dtype == torch.int8:
                 params_dict = {}
                 quant_linear_cls = VllmQuantLinearInt8
-            elif self.torch_dtype == 'int4':
-                params_dict = {}
-                quant_linear_cls = AutoawqQuantLinearInt4

             for block_idx, block in enumerate(self.blocks):
                 self.replace_module_block(quant_linear_cls,
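
For context, here is a minimal runnable sketch of the dtype handling this diff leaves behind. The helper names resolve_torch_dtype and select_quant_linear_cls are hypothetical, written only to isolate the two changed branches; in the real code this logic sits inline in BaseModel.__init__ and BaseModel.build_model, and the fp8 branch is assumed from the import list rather than shown in these hunks.

import torch

# Hypothetical helpers mirroring BaseModel after this commit (sketch only).

def resolve_torch_dtype(torch_dtype):
    # Mirrors __init__: 'auto' passes through as a string; any other value
    # is eval'ed, e.g. 'torch.float8_e4m3fn' -> torch.float8_e4m3fn.
    # The old 'int4' passthrough is gone, so 'int4' now fails inside eval
    # with a NameError instead of being handled downstream.
    return torch_dtype if torch_dtype in ['auto'] else eval(torch_dtype)

def select_quant_linear_cls(resolved):
    # Mirrors the branch in build_model: only fp8 and int8 checkpoints get
    # their Linear modules swapped for quantized kernels; the int4 branch
    # (AutoawqQuantLinearInt4) was removed. Class names are returned as
    # strings here; the real classes come from
    # llmc.compression.quantization.module_utils.
    if resolved == torch.float8_e4m3fn:
        return 'VllmQuantLinearFp8'  # fp8 branch assumed from the imports
    elif resolved == torch.int8:
        return 'VllmQuantLinearInt8'
    raise ValueError(f'no quantized linear for {resolved}')

print(select_quant_linear_cls(resolve_torch_dtype('torch.int8')))
# -> VllmQuantLinearInt8

Because the config string goes through eval, torch_dtype values must be valid Python expressions such as 'torch.float16' or 'torch.int8'; after this commit, 'int4' is no longer special-cased anywhere on this path.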

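For completeness, a hypothetical sketch of the model-config fields read by the code touched here. The keys come straight from the attribute reads in __init__; the real llmc config file layout is not part of this commit, so treat the dict as illustrative only.

# Illustrative only: keys mirror the config.model reads in BaseModel.__init__.
model_section = {
    'torch_dtype': 'torch.float8_e4m3fn',  # string, later eval'ed to a dtype
    'tokenizer_mode': 'fast',              # default used by the .get() call
    'block_wise_quant': True,              # asserted to require float8_e4m3fn
    'use_cpu_to_save_cuda_mem_for_catcher': False,
}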