Vlm #187 (Merged)

12 changes: 12 additions & 0 deletions llmc/__main__.py
@@ -76,6 +76,18 @@ def main(config):
     dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
     calib_data, padding_mask = dataset.get_calib_dataset()
     padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
+    if config.calib.type == 'img_txt':
+        model.collect_first_encoder_block_input(calib_data, padding_mask,
+                                                padding_side, config.calib.type)
+        blockwise_opt = ALGO_REGISTRY[config.quant.method](
+            model,
+            config.quant,
+            model.get_first_block_input(),
+            model.get_padding_mask(),
+            config,
+            'vision'
+        )
+        blockwise_opt.run_block_loop()
     model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
     del calib_data
     gc.collect()
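Note: read together with the base_model.py changes below, this hunk gives img_txt calibration a two-pass structure: a vision pass over the encoder blocks, then the pre-existing language pass. A minimal sketch of the control flow (names from the diff; the language-pass optimizer construction sits outside this hunk and is assumed symmetric):

    # Sketch only; condensed from the hunk above, not verbatim PR code.
    if config.calib.type == 'img_txt':
        # Pass 1 (vision): capture first-encoder-block inputs, then run the
        # blockwise algorithm over the encoder blocks with modality 'vision'.
        model.collect_first_encoder_block_input(calib_data, padding_mask,
                                                padding_side, config.calib.type)
        ALGO_REGISTRY[config.quant.method](
            model, config.quant, model.get_first_block_input(),
            model.get_padding_mask(), config, 'vision',
        ).run_block_loop()
    # Pass 2 (language): the unchanged path captures first language-block
    # inputs and runs the same algorithm with the default modality.
    model.collect_first_block_input(calib_data, padding_mask, padding_side,
                                    config.calib.type)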
5 changes: 3 additions & 2 deletions llmc/compression/blockwise_optimization.py
@@ -6,9 +6,10 @@
 
 
 class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality):
         self.model = model
-        self.blocks = model.get_blocks()
+        self.modality = modality
+        self.blocks = model.get_blocks(modality)
         self.quant_config = quant_config
         self.sparsity_config = quant_config
         self.input = input
10 changes: 4 additions & 6 deletions llmc/compression/quantization/base_blockwise_quantization.py
@@ -27,8 +27,8 @@
 
 
 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.set_quant_config()
 
     def w_qdq(self, module, wquantizer):
@@ -63,9 +63,6 @@ def a_qdq(self, act, module, aquantizer, input_index=0):
         else:
             return aquantizer.fake_quant_act_dynamic(act)
 
-    def logit(self, x):
-        return torch.log(x / (1 - x))
-
     def get_replacement_params(self, mode='fake_quant', w_only=False, name=None):
         params_dict = {}
         if mode == 'fake_quant':
@@ -436,7 +433,8 @@ def run(self, block, input_feat, handles):
 
     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block) \
+            if self.modality == 'language' else self.model.get_encoder_subsets_in_block(block)
 
         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)
4 changes: 2 additions & 2 deletions llmc/compression/quantization/rtn.py
@@ -8,8 +8,8 @@
 
 @ALGO_REGISTRY
 class RTN(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
 
     @torch.no_grad()
     def block_opt(self, *opt_kwargs):
7 changes: 4 additions & 3 deletions llmc/data/dataset/base_dataset.py
@@ -255,9 +255,10 @@ def get_calib_dataset(self):
         samples = self.get_calib_samples()
         if self.calib_dataset_type in ['txt', 'img', 'img_txt']:
             logger.info(f'len(samples) all : {len(samples)}')
-            assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
-            samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
-            logger.info(f'len(samples) rank : {len(samples)}')
+            if os.environ.get('WORLD_SIZE') is not None:
+                assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
+                samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
+                logger.info(f'len(samples) rank : {len(samples)}')
         calib_samples = []
         if self.calib_dataset_type == 'txt':
             if self.padding:
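Note: the stride slice above shards the calibration set evenly across data-parallel ranks, and the new guard makes single-process runs (no WORLD_SIZE in the environment) skip sharding entirely. A toy illustration with assumed values:

    # Each rank takes every WORLD_SIZE-th sample, offset by its RANK.
    samples = list(range(8))
    WORLD_SIZE, RANK = 4, 1
    assert len(samples) % WORLD_SIZE == 0
    print(samples[RANK::WORLD_SIZE])  # [1, 5]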
64 changes: 58 additions & 6 deletions llmc/models/base_model.py
@@ -28,6 +28,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.build_model()
         self.model.eval()
         self.find_blocks()
+        self.find_encoder_blocks()
         self.find_embed_layers()
         self.find_block_name()
         self.add_layernorms_class()
@@ -36,14 +37,20 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
     def find_blocks(self):
         pass
 
+    def find_encoder_blocks(self):
+        pass
+
+    def get_encoder_catcher(self, first_block_input):
+        pass
+
     def find_block_name(self):
         pass
 
     def get_model(self):
         return self.model
 
-    def get_blocks(self):
-        return self.blocks
+    def get_blocks(self, modality='language'):
+        return self.blocks if modality == 'language' else self.encoder_blocks
 
     @abstractmethod
     def find_embed_layers(self):
@@ -186,6 +193,43 @@ def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=
         self.blocks[0] = self.blocks[0].cpu()
         self.move_embed_to_device('cpu')
 
+    @torch.no_grad()
+    def collect_first_encoder_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'):  # noqa
+        first_block_input = defaultdict(list)
+
+        Catcher = self.get_encoder_catcher(first_block_input)
+
+        self.move_embed_to_device('cuda')
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.to('cuda')
+            self.projector = self.projector.to('cuda')
+        self.encoder_blocks[0] = self.encoder_blocks[0].cuda()
+        self.encoder_blocks[0] = Catcher(self.encoder_blocks[0])
+
+        for data in calib_data:
+            if isinstance(data, BatchFeature):
+                data = data.to(next(self.model.parameters()).device)
+            else:
+                data = {
+                    k: (v.to(next(self.model.parameters()).device) if torch.is_tensor(v) else v)
+                    for k, v in data.items()
+                }
+            try:
+                if data_type in ['txt', 'img']:
+                    self.model(**data)
+                elif data_type == 'img_txt':
+                    self.vlm_model.generate(**data, max_new_tokens=128, do_sample=False)
+            except ValueError:
+                pass
+        self.first_block_input = first_block_input
+        self.padding_mask = None
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.cpu()
+            self.projector = self.projector.cpu()
+        self.encoder_blocks[0] = self.encoder_blocks[0].module
+        self.encoder_blocks[0] = self.encoder_blocks[0].cpu()
+        self.move_embed_to_device('cpu')
+
     def get_one_pad_setting(self, padding_side, length):
         if padding_side == 'left':
             return [0, length]
@@ -280,17 +324,25 @@ def set_mix_bits_params_dict(self, block_idx, name, params_dict):
                 params_mix_dict['a_qdq'] = None
         return params_mix_dict
 
-    def replace_module_all(self, module, params_dict, keep_device=False):
-        for block_idx in range(len(self.blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
-            block = self.blocks[block_idx]
+    def replace_modality_module_all(self, module, blocks, params_dict, keep_device=False):
+        for block_idx in range(len(blocks)):
+            logger.info(f'Replace block index: {block_idx}/{len(blocks)}')
+            block = blocks[block_idx]
             if keep_device:
                 self.replace_module_block(module, block, block_idx, params_dict)
             else:
                 block = block.cuda()
                 self.replace_module_block(module, block, block_idx, params_dict)
                 block = block.cpu()
 
+    def replace_module_all(self, module, params_dict, keep_device=False):
+        if hasattr(self, 'encoder_blocks'):
+            logger.info('start replace vision blocks')
+            self.replace_modality_module_all(module, self.encoder_blocks, params_dict, keep_device)
+
+        logger.info('start replace language blocks')
+        self.replace_modality_module_all(module, self.blocks, params_dict, keep_device)
+
         gc.collect()
         torch.cuda.empty_cache()
         logger.info(f'The Replaced model: {self.model}')
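Note: since the base-class find_encoder_blocks() is a no-op, language-only models never set self.encoder_blocks and keep their old behavior; VLM subclasses opt in by populating it. A minimal sketch of the dispatch (toy stand-in object, not llmc code):

    class ToyModel:
        # Mirrors the attribute names used in the diff.
        def __init__(self):
            self.blocks = ['lang_0', 'lang_1']
            self.encoder_blocks = ['vis_0']

        def get_blocks(self, modality='language'):
            return self.blocks if modality == 'language' else self.encoder_blocks

    m = ToyModel()
    print(m.get_blocks())          # ['lang_0', 'lang_1']
    print(m.get_blocks('vision'))  # ['vis_0']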
66 changes: 66 additions & 0 deletions llmc/models/llava.py
@@ -1,3 +1,6 @@
+import inspect
+
+import torch.nn as nn
 from loguru import logger
 from PIL import Image
 from transformers import (AutoConfig, AutoProcessor,
@@ -31,6 +34,31 @@ def build_model(self):
         self.model = self.vlm_model.language_model
         self.model_config = self.vlm_model_config.text_config
 
+    def find_encoder_blocks(self):
+        self.encoder_blocks = self.vision_model.vision_model.encoder.layers
+
+    def get_encoder_catcher(self, first_block_input):
+
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+                self.signature = inspect.signature(module.forward)
+
+            def forward(self, *args, **kwargs):
+                params = list(self.signature.parameters.keys())
+                for i, arg in enumerate(args):
+                    if i > 0:
+                        kwargs[params[i]] = arg
+                first_block_input['data'].append(args[0])
+                if 'output_router_logits' in kwargs:
+                    assert kwargs['output_router_logits'] is False
+                    kwargs.pop('output_router_logits')
+                first_block_input['kwargs'].append(kwargs)
+                raise ValueError
+
+        return Catcher
+
     def batch_process(self, img_qas):
         if len(img_qas) == 1:
             return self.single_process(img_qas[0])
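Note: the Catcher reuses llmc's capture pattern: wrap the first encoder block, normalize positional args into kwargs via the wrapped forward's signature, record the block's input, then raise ValueError so the surrounding generate() call aborts before running the rest of the model (collect_first_encoder_block_input swallows the exception). A self-contained toy version of the pattern:

    import torch
    import torch.nn as nn

    captured = {'data': [], 'kwargs': []}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, hidden_states, **kwargs):
            captured['data'].append(hidden_states)  # first-block input
            captured['kwargs'].append(kwargs)
            raise ValueError  # abort the forward; the caller catches this

    wrapped = Catcher(nn.Linear(4, 4))
    try:
        wrapped(torch.randn(1, 4), attention_mask=None)
    except ValueError:
        pass
    print(len(captured['data']))  # 1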
Expand Down Expand Up @@ -83,3 +111,41 @@ def single_process(self, img_qas):
return_tensors='pt'
).to(next(self.vlm_model.parameters()).dtype) # noqa
return inputs

def get_encoder_subsets_in_block(self, block):
return [
{
'layers': {
'self_attn.q_proj': block.self_attn.q_proj,
'self_attn.k_proj': block.self_attn.k_proj,
'self_attn.v_proj': block.self_attn.v_proj,
},
'prev_op': [block.layer_norm1],
'input': ['self_attn.q_proj'],
'inspect': block.self_attn,
'has_kwargs': True,
},
{
'layers': {'self_attn.out_proj': block.self_attn.out_proj},
'prev_op': [block.self_attn.v_proj],
'input': ['self_attn.out_proj'],
'inspect': block.self_attn.out_proj,
'has_kwargs': False,
},
{
'layers': {'mlp.fc1': block.mlp.fc1},
'prev_op': [block.layer_norm2],
'input': ['mlp.fc1'],
'inspect': block.mlp.fc1,
'has_kwargs': False,
'is_mlp': True,
},
{
'layers': {'mlp.fc2': block.mlp.fc2},
'prev_op': [block.mlp.fc1],
'input': ['mlp.fc2'],
'inspect': block.mlp.fc2,
'has_kwargs': False,
'is_mlp': True,
},
]
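Note: field meanings above are inferred from how llmc's language models define subsets, not stated in this PR: 'layers' groups the linears quantized together, 'prev_op' is the module feeding them (used e.g. for AWQ-style scale migration), 'input' names the layer whose input activations are hooked, and 'inspect' is the module whose forward is observed. A toy check of the q/k/v grouping:

    import torch.nn as nn

    # Toy stand-ins for one CLIP-style encoder layer (dims illustrative).
    class ToyAttn(nn.Module):
        def __init__(self):
            super().__init__()
            self.q_proj, self.k_proj, self.v_proj = (nn.Linear(8, 8) for _ in range(3))

    attn, ln1 = ToyAttn(), nn.LayerNorm(8)
    subset = {
        'layers': {'self_attn.q_proj': attn.q_proj,
                   'self_attn.k_proj': attn.k_proj,
                   'self_attn.v_proj': attn.v_proj},
        'prev_op': [ln1],
    }
    print(list(subset['layers']), '<-', type(subset['prev_op'][0]).__name__)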