30 changes: 17 additions & 13 deletions llmc/__main__.py
@@ -15,7 +15,8 @@
from llmc.compression.quantization import *
from llmc.compression.sparsification import *
from llmc.data import BaseDataset, BaseTokenizer
-from llmc.eval import AccuracyEval, PerplexityEval, TokenConsistencyEval
+from llmc.eval import (AccuracyEval, PerplexityEval, TokenConsistencyEval,
+                       VLMEval)
from llmc.models import *
from llmc.utils import (check_config, mkdirs, print_important_package_version,
seed_all, update_autoawq_quant_config,
@@ -48,6 +49,9 @@ def main(config):
if config.eval.type == 'acc':
acc_eval = AccuracyEval(eval_config)
eval_list.append(acc_eval)
+elif config.eval.type == 'img_txt':
+    acc_eval = VLMEval(eval_config)
+    eval_list.append(acc_eval)
else:
ppl_eval = PerplexityEval(tokenizer.get_tokenizer(), eval_config)
eval_list.append(ppl_eval)
@@ -57,6 +61,10 @@ def main(config):
for acc_eval in eval_list:
acc = acc_eval.eval(model)
logger.info(f'{config.eval.name} acc : {acc}')
+elif config.eval.type == 'img_txt':
+    for vlm_eval in eval_list:
+        results = vlm_eval.eval(model, tokenizer)
+        logger.info(f'{config.eval.name} results : {results}')
else:
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
@@ -76,18 +84,6 @@ def main(config):
dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
calib_data, padding_mask = dataset.get_calib_dataset()
padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
-if config.calib.type == 'img_txt':
-    model.collect_first_encoder_block_input(calib_data, padding_mask,
-                                            padding_side, config.calib.type)
-    blockwise_opt = ALGO_REGISTRY[config.quant.method](
-        model,
-        config.quant,
-        model.get_first_block_input(),
-        model.get_padding_mask(),
-        config,
-        'vision'
-    )
-    blockwise_opt.run_block_loop()
model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
del calib_data
gc.collect()
@@ -118,6 +114,10 @@ def main(config):
for acc_eval in eval_list:
acc = acc_eval.eval(model)
logger.info(f'{config.eval.name} acc : {acc}')
+elif config.eval.type == 'img_txt':
+    for vlm_eval in eval_list:
+        results = vlm_eval.eval(model, tokenizer)
+        logger.info(f'{config.eval.name} results : {results}')
else:
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
@@ -142,6 +142,10 @@ def main(config):
for acc_eval in eval_list:
acc = acc_eval.eval(model)
logger.info(f'{config.eval.name} acc : {acc}')
+elif config.eval.type == 'img_txt':
+    for vlm_eval in eval_list:
+        results = vlm_eval.eval(model, tokenizer)
+        logger.info(f'{config.eval.name} results : {results}')
else:
for ppl_eval in eval_list:
ppl = ppl_eval.eval(model)
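The `img_txt` branch is added at three points in `main()`: before quantization, after the block-wise transform, and after deployment. Consolidated, the new evaluation dispatch looks like the sketch below; it only rearranges names that appear in the hunks above, and the construction of `config`, `eval_config`, `tokenizer`, `model`, and `logger` is assumed.

    # Sketch of the dispatch added above (context objects assumed, not shown).
    eval_list = []
    if config.eval.type == 'acc':
        eval_list.append(AccuracyEval(eval_config))
    elif config.eval.type == 'img_txt':                   # new vision-language path
        eval_list.append(VLMEval(eval_config))
    else:
        eval_list.append(PerplexityEval(tokenizer.get_tokenizer(), eval_config))

    for evaluator in eval_list:
        if config.eval.type == 'img_txt':
            results = evaluator.eval(model, tokenizer)    # VLMEval also takes the tokenizer
        else:
            results = evaluator.eval(model)
        logger.info(f'{config.eval.name} results : {results}')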
35 changes: 25 additions & 10 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -27,8 +27,8 @@


class BaseBlockwiseQuantization(BlockwiseOpt):
-def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
-    super().__init__(model, quant_config, input, padding_mask, config, modality)
+def __init__(self, model, quant_config, input, padding_mask, config):
+    super().__init__(model, quant_config, input, padding_mask, config)
self.set_quant_config()

def w_qdq(self, module, wquantizer):
@@ -63,6 +63,9 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
else:
return aquantizer.fake_quant_act_dynamic(act)

+def logit(self, x):
+    return torch.log(x / (1 - x))

def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
params_dict = {}
if mode == 'fake_quant':
@@ -268,6 +271,9 @@ def set_quant_config(self):
self.intermediate_size = self.model.model_config.intermediate_size
self.fp32_had = special_config.get('fp32_had', False)

+self.quant_objects = self.quant_config.get('quant_objects', ['language'])
+logger.info(f'self.quant_objects : {self.quant_objects}')

def replace_rotate_linears(self, block):
for n, m in block.named_modules():
if isinstance(m, nn.Linear) and ('down_proj' in n
@@ -433,8 +439,7 @@ def run(self, block, input_feat, handles):

def block_transform(self, block, input_feat, block_kwargs):
logger.info(f'Start transform the {self.block_idx}-th block')
-subsets = self.model.get_subsets_in_block(block) \
-    if self.modality == 'language' else self.model.get_encoder_subsets_in_block(block)
+subsets = self.model.get_subsets_in_block(block)

if self.act_static:
self.register_non_linear_qparams(block, input_feat)
@@ -804,12 +809,22 @@ def deploy(self, quant_format, keep_device=False):
)

module = module_mapping[quant_format]
-self.model.replace_module_all(
-    module,
-    self.get_replacement_params(mode=quant_format, w_only=self.w_only),
-    keep_device=keep_device
-)
-self.set_non_linear_mode(quant_format, self.model.model, False)
+if 'vision' in self.quant_objects:
+    self.model.replace_vision_module_all(
+        module,
+        self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+        keep_device=keep_device
+    )
+if 'language' in self.quant_objects:
+    self.model.replace_language_module_all(
+        module,
+        self.get_replacement_params(mode=quant_format, w_only=self.w_only),
+        keep_device=keep_device
+    )
+self.set_non_linear_mode(quant_format, self.model.model, False)

+if hasattr(self.model, 'vlm_model'):
+    logger.info(f'Now, the vlm_model is: {self.model.vlm_model}')

logger.info(f'-- deploy_{quant_format}_model done --')

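Two things stand out in the hunks above. The new `quant_objects` option (default `['language']`) now decides at deploy time whether `replace_vision_module_all`, `replace_language_module_all`, or both are applied, replacing the single `replace_module_all` call. And the new `logit` helper is the inverse of the sigmoid, logit(p) = log(p / (1 - p)), so applying `torch.sigmoid` to its output recovers the input for p in (0, 1). A quick stand-alone check of that identity, assuming only PyTorch (not part of the PR):

    import torch

    p = torch.tensor([0.25, 0.5, 0.75])
    y = torch.log(p / (1 - p))       # same expression as the new logit() method
    print(torch.sigmoid(y))          # tensor([0.2500, 0.5000, 0.7500])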
2 changes: 1 addition & 1 deletion llmc/compression/quantization/llmint8.py
@@ -66,7 +66,7 @@ def deploy(self, quant_format):
logger.info(f'-- deploy_{quant_format}_model start --')
logger.info(f'quant_config : {self.quant_config}')

-self.model.replace_module_all(
+self.model.replace_language_module_all(
FakeQuantLinear,
self.get_replacement_params(
mode='fake_quant', w_only=self.w_only, name=None
1 change: 1 addition & 0 deletions llmc/eval/__init__.py
@@ -1,3 +1,4 @@
from .eval_acc import AccuracyEval
from .eval_ppl import PerplexityEval
from .eval_token_consist import TokenConsistencyEval
+from .eval_vlm import VLMEval
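With `VLMEval` exported from `llmc.eval`, the new evaluator can also be driven directly. A minimal usage sketch based on how `llmc/__main__.py` calls it above; `eval_config`, `model`, and `tokenizer` are assumed to be prepared the same way as in `main()`:

    from llmc.eval import VLMEval

    # Assumed context: eval_config, model, tokenizer built as in llmc/__main__.py.
    vlm_eval = VLMEval(eval_config)
    results = vlm_eval.eval(model, tokenizer)
    print(results)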