From 98dcb538c152d9f1854e49d846728320701f1695 Mon Sep 17 00:00:00 2001
From: chengtao-lv <897674362@qq.com>
Date: Thu, 14 Nov 2024 13:59:50 +0800
Subject: [PATCH 1/2] update vlm

---
 llmc/__main__.py                           | 12 ++++
 llmc/compression/blockwise_optimization.py |  5 +-
 .../base_blockwise_quantization.py         |  7 +-
 llmc/compression/quantization/rtn.py       | 12 ++--
 llmc/data/dataset/base_dataset.py          |  7 +-
 llmc/models/base_model.py                  | 64 ++++++++++++++++--
 llmc/models/llava.py                       | 66 +++++++++++++++++++
 7 files changed, 153 insertions(+), 20 deletions(-)

diff --git a/llmc/__main__.py b/llmc/__main__.py
index 9c7a376e8..f00fdd8d6 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -76,6 +76,18 @@ def main(config):
     dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
     calib_data, padding_mask = dataset.get_calib_dataset()
     padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
+    if config.calib.type == 'img_txt':
+        model.collect_first_encoder_block_input(calib_data, padding_mask,
+                                                padding_side, config.calib.type)
+        blockwise_opt = ALGO_REGISTRY[config.quant.method](
+            model,
+            config.quant,
+            model.get_first_block_input(),
+            model.get_padding_mask(),
+            config,
+            'vision'
+        )
+        blockwise_opt.run_block_loop()
     model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
     del calib_data
     gc.collect()
diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py
index c00e83775..4174fa259 100644
--- a/llmc/compression/blockwise_optimization.py
+++ b/llmc/compression/blockwise_optimization.py
@@ -6,9 +6,10 @@


 class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality):
         self.model = model
-        self.blocks = model.get_blocks()
+        self.modality = modality
+        self.blocks = model.get_blocks(modality)
         self.quant_config = quant_config
         self.sparsity_config = quant_config
         self.input = input
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index a791b6161..743a34a47 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -27,8 +27,8 @@


 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.set_quant_config()

     def w_qdq(self, module, wquantizer):
@@ -436,7 +436,8 @@ def run(self, block, input_feat, handles):

     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block) \
+            if self.modality == 'language' else self.model.get_encoder_subsets_in_block(block)
         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)

diff --git a/llmc/compression/quantization/rtn.py b/llmc/compression/quantization/rtn.py
index 35ba840cf..599ea37e1 100644
--- a/llmc/compression/quantization/rtn.py
+++ b/llmc/compression/quantization/rtn.py
@@ -8,13 +8,13 @@

 @ALGO_REGISTRY
 class RTN(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)

-    @torch.no_grad()
-    def block_opt(self, *opt_kwargs):
-        if self.act_static:
-            super().block_opt(*opt_kwargs)
+    # @torch.no_grad()
+    # def block_opt(self, *opt_kwargs):
+    #     if self.act_static:
+    #         super().block_opt(*opt_kwargs)

     @torch.no_grad()
     def subset_transform(
diff --git a/llmc/data/dataset/base_dataset.py b/llmc/data/dataset/base_dataset.py
index 306f373aa..ad847a5b2 100644
--- a/llmc/data/dataset/base_dataset.py
+++ b/llmc/data/dataset/base_dataset.py
@@ -255,9 +255,10 @@ def get_calib_dataset(self):
         samples = self.get_calib_samples()
         if self.calib_dataset_type in ['txt', 'img', 'img_txt']:
             logger.info(f'len(samples) all : {len(samples)}')
-            assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
-            samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
-            logger.info(f'len(samples) rank : {len(samples)}')
+            if os.environ.get('WORLD_SIZE') is not None:
+                assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
+                samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
+                logger.info(f'len(samples) rank : {len(samples)}')
         calib_samples = []
         if self.calib_dataset_type == 'txt':
             if self.padding:
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index 057b46032..11993d750 100644
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -28,6 +28,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.build_model()
         self.model.eval()
         self.find_blocks()
+        self.find_encoder_blocks()
         self.find_embed_layers()
         self.find_block_name()
         self.add_layernorms_class()
@@ -36,14 +37,20 @@
     def find_blocks(self):
         pass

+    def find_encoder_blocks(self):
+        pass
+
+    def get_encoder_catcher(self, first_block_input):
+        pass
+
     def find_block_name(self):
         pass

     def get_model(self):
         return self.model

-    def get_blocks(self):
-        return self.blocks
+    def get_blocks(self, modality='language'):
+        return self.blocks if modality == 'language' else self.encoder_blocks

     @abstractmethod
     def find_embed_layers(self):
@@ -186,6 +193,43 @@ def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=
         self.blocks[0] = self.blocks[0].cpu()
         self.move_embed_to_device('cpu')

+    @torch.no_grad()
+    def collect_first_encoder_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'):  # noqa
+        first_block_input = defaultdict(list)
+
+        Catcher = self.get_encoder_catcher(first_block_input)
+
+        self.move_embed_to_device('cuda')
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.to('cuda')
+            self.projector = self.projector.to('cuda')
+        self.encoder_blocks[0] = self.encoder_blocks[0].cuda()
+        self.encoder_blocks[0] = Catcher(self.encoder_blocks[0])
+
+        for data in calib_data:
+            if isinstance(data, BatchFeature):
+                data = data.to(next(self.model.parameters()).device)
+            else:
+                data = {
+                    k: (v.to(next(self.model.parameters()).device) if torch.is_tensor(v) else v)
+                    for k, v in data.items()
+                }
+            try:
+                if data_type in ['txt', 'img']:
+                    self.model(**data)
+                elif data_type == 'img_txt':
+                    self.vlm_model.generate(**data, max_new_tokens=128, do_sample=False)
+            except ValueError:
+                pass
+        self.first_block_input = first_block_input
+        self.padding_mask = None
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.cpu()
+            self.projector = self.projector.cpu()
+        self.encoder_blocks[0] = self.encoder_blocks[0].module
+        self.encoder_blocks[0] = self.encoder_blocks[0].cpu()
+        self.move_embed_to_device('cpu')
+
     def get_one_pad_setting(self, padding_side, length):
         if padding_side == 'left':
             return [0, length]
@@ -280,10 +324,10 @@ def set_mix_bits_params_dict(self, block_idx, name, params_dict):
             params_mix_dict['a_qdq'] = None
         return params_mix_dict

-    def replace_module_all(self, module, params_dict, keep_device=False):
-        for block_idx in range(len(self.blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
-            block = self.blocks[block_idx]
+    def replace_modality_module_all(self, module, blocks, params_dict, keep_device=False):
+        for block_idx in range(len(blocks)):
+            logger.info(f'Replace block index: {block_idx}/{len(blocks)}')
+            block = blocks[block_idx]
             if keep_device:
                 self.replace_module_block(module, block, block_idx, params_dict)
             else:
@@ -291,6 +335,14 @@ def replace_module_all(self, module, params_dict, keep_device=False):
                 self.replace_module_block(module, block, block_idx, params_dict)
                 block = block.cpu()

+    def replace_module_all(self, module, params_dict, keep_device=False):
+        if hasattr(self, 'encoder_blocks'):
+            logger.info('start replace vision blocks')
+            self.replace_modality_module_all(module, self.encoder_blocks, params_dict, keep_device)
+
+        logger.info('start replace language blocks')
+        self.replace_modality_module_all(module, self.blocks, params_dict, keep_device)
+
         gc.collect()
         torch.cuda.empty_cache()
         logger.info(f'The Replaced model: {self.model}')
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index f203f457a..0211700de 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -1,3 +1,6 @@
+import inspect
+
+import torch.nn as nn
 from loguru import logger
 from PIL import Image
 from transformers import (AutoConfig, AutoProcessor,
@@ -31,6 +34,31 @@ def build_model(self):
         self.model = self.vlm_model.language_model
         self.model_config = self.vlm_model_config.text_config

+    def find_encoder_blocks(self):
+        self.encoder_blocks = self.vision_model.vision_model.encoder.layers
+
+    def get_encoder_catcher(self, first_block_input):
+
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+                self.signature = inspect.signature(module.forward)
+
+            def forward(self, *args, **kwargs):
+                params = list(self.signature.parameters.keys())
+                for i, arg in enumerate(args):
+                    if i > 0:
+                        kwargs[params[i]] = arg
+                first_block_input['data'].append(args[0])
+                if 'output_router_logits' in kwargs:
+                    assert kwargs['output_router_logits'] is False
+                    kwargs.pop('output_router_logits')
+                first_block_input['kwargs'].append(kwargs)
+                raise ValueError
+
+        return Catcher
+
     def batch_process(self, img_qas):
         if len(img_qas) == 1:
             return self.single_process(img_qas[0])
@@ -83,3 +111,41 @@ def single_process(self, img_qas):
             return_tensors='pt'
         ).to(next(self.vlm_model.parameters()).dtype)  # noqa
         return inputs
+
+    def get_encoder_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'self_attn.q_proj': block.self_attn.q_proj,
+                    'self_attn.k_proj': block.self_attn.k_proj,
+                    'self_attn.v_proj': block.self_attn.v_proj,
+                },
+                'prev_op': [block.layer_norm1],
+                'input': ['self_attn.q_proj'],
+                'inspect': block.self_attn,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'self_attn.out_proj': block.self_attn.out_proj},
+                'prev_op': [block.self_attn.v_proj],
+                'input': ['self_attn.out_proj'],
+                'inspect': block.self_attn.out_proj,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {'mlp.fc1': block.mlp.fc1},
+                'prev_op': [block.layer_norm2],
+                'input': ['mlp.fc1'],
+                'inspect': block.mlp.fc1,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.fc2': block.mlp.fc2},
+                'prev_op': [block.mlp.fc1],
+                'input': ['mlp.fc2'],
+                'inspect': block.mlp.fc2,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]

From 0028a28294086d0c8214c3a5314d74c210a5d309 Mon Sep 17 00:00:00 2001
From: chengtao-lv <897674362@qq.com>
Date: Thu, 14 Nov 2024 15:59:39 +0800
Subject: [PATCH 2/2] fix vlm bug

---
 .../quantization/base_blockwise_quantization.py | 3 ---
 llmc/compression/quantization/rtn.py            | 8 ++++----
 2 files changed, 4 insertions(+), 7 deletions(-)

diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 743a34a47..6bb0bf273 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -63,9 +63,6 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
         else:
             return aquantizer.fake_quant_act_dynamic(act)

-    def logit(self, x):
-        return torch.log(x / (1 - x))
-
     def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
         params_dict = {}
         if mode == 'fake_quant':
diff --git a/llmc/compression/quantization/rtn.py b/llmc/compression/quantization/rtn.py
index 599ea37e1..966524c87 100644
--- a/llmc/compression/quantization/rtn.py
+++ b/llmc/compression/quantization/rtn.py
@@ -11,10 +11,10 @@ class RTN(BaseBlockwiseQuantization):
     def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
         super().__init__(model, quant_config, input, padding_mask, config, modality)

-    # @torch.no_grad()
-    # def block_opt(self, *opt_kwargs):
-    #     if self.act_static:
-    #         super().block_opt(*opt_kwargs)
+    @torch.no_grad()
+    def block_opt(self, *opt_kwargs):
+        if self.act_static:
+            super().block_opt(*opt_kwargs)

     @torch.no_grad()
     def subset_transform(
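
Reviewer sketch, not part of the patches: the control flow added to `llmc/__main__.py` runs two block loops back to back, vision tower first, language model second. A minimal sketch of that flow under an `img_txt` calib config; `quantize_vlm` is a hypothetical wrapper name, `BaseDataset` and `ALGO_REGISTRY` are the existing llmc objects used in `__main__.py`, and the explicit `'language'` argument in the second pass just makes the new default modality visible:

```python
# Sketch only: mirrors the patched main(), not a drop-in replacement.
def quantize_vlm(model, tokenizer, config):
    dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
    calib_data, padding_mask = dataset.get_calib_dataset()
    padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)

    if config.calib.type == 'img_txt':
        # Pass 1: catch the first vision-encoder block's inputs, then run
        # the block loop over the CLIP tower with modality='vision'.
        model.collect_first_encoder_block_input(calib_data, padding_mask,
                                                padding_side, config.calib.type)
        vision_opt = ALGO_REGISTRY[config.quant.method](
            model, config.quant, model.get_first_block_input(),
            model.get_padding_mask(), config, 'vision')
        vision_opt.run_block_loop()

    # Pass 2: the pre-existing language path; modality defaults to 'language'.
    model.collect_first_block_input(calib_data, padding_mask,
                                    padding_side, config.calib.type)
    language_opt = ALGO_REGISTRY[config.quant.method](
        model, config.quant, model.get_first_block_input(),
        model.get_padding_mask(), config, 'language')
    language_opt.run_block_loop()
```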
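The `Catcher` added in `llava.py` reuses llmc's first-block capture trick: wrap block 0, fold positional arguments into kwargs by name via the forward signature, record the hidden states and kwargs, then raise `ValueError` so the rest of `generate()` never runs (the caller swallows the exception). A self-contained toy version of the same pattern, independent of llmc and transformers:

```python
import inspect
from collections import defaultdict

import torch
import torch.nn as nn


class Catcher(nn.Module):
    """Toy reproduction of the capture pattern used in the patch."""

    def __init__(self, module, store):
        super().__init__()
        self.module = module
        self.store = store
        self.signature = inspect.signature(module.forward)

    def forward(self, *args, **kwargs):
        params = list(self.signature.parameters.keys())
        for i, arg in enumerate(args):
            if i > 0:  # fold extra positional args into kwargs by name
                kwargs[params[i]] = arg
        self.store['data'].append(args[0])
        self.store['kwargs'].append(kwargs)
        raise ValueError  # abort the forward pass after the first block


model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
store = defaultdict(list)
model[0] = Catcher(model[0], store)
try:
    model(torch.randn(2, 8))
except ValueError:
    pass
print(store['data'][0].shape)  # torch.Size([2, 8])
```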
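`get_encoder_subsets_in_block` hard-codes the HF `CLIPEncoderLayer` layout (`layer_norm1`, then `self_attn.{q,k,v,out}_proj`, then `layer_norm2`, then `mlp.fc1`/`mlp.fc2`). A quick way to sanity-check that mapping against a real vision tower; the checkpoint id is illustrative, any CLIP vision checkpoint works:

```python
from transformers import CLIPVisionModel

# Confirm the module names that get_encoder_subsets_in_block relies on.
vision = CLIPVisionModel.from_pretrained('openai/clip-vit-large-patch14-336')
block = vision.vision_model.encoder.layers[0]
for name in ('layer_norm1', 'self_attn.q_proj', 'self_attn.k_proj',
             'self_attn.v_proj', 'self_attn.out_proj',
             'layer_norm2', 'mlp.fc1', 'mlp.fc2'):
    print(name, type(block.get_submodule(name)).__name__)
```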