
Commit cde839b

update vlm (#186)
Co-authored-by: chengtao-lv <[email protected]>
Parent: aecd6c4

7 files changed: +153 additions, -20 deletions

llmc/__main__.py

Lines changed: 12 additions & 0 deletions
@@ -76,6 +76,18 @@ def main(config):
     dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
     calib_data, padding_mask = dataset.get_calib_dataset()
     padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
+    if config.calib.type == 'img_txt':
+        model.collect_first_encoder_block_input(calib_data, padding_mask,
+                                                padding_side, config.calib.type)
+        blockwise_opt = ALGO_REGISTRY[config.quant.method](
+            model,
+            config.quant,
+            model.get_first_block_input(),
+            model.get_padding_mask(),
+            config,
+            'vision'
+        )
+        blockwise_opt.run_block_loop()
     model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
     del calib_data
     gc.collect()
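
Note: with this change, an 'img_txt' calibration set triggers a separate blockwise pass over the vision encoder before the usual language-model pass. Below is a minimal, self-contained sketch of that two-pass flow; DummyModel and DummyOpt are hypothetical stand-ins, not llmc classes.

    # Hypothetical stand-ins for illustration only; not llmc code.
    class DummyModel:
        def __init__(self):
            self.encoder_blocks = ['vision_block_0', 'vision_block_1']
            self.blocks = ['language_block_0', 'language_block_1']

        def get_blocks(self, modality='language'):
            # Mirrors the modality switch added to BaseModel.get_blocks below.
            return self.blocks if modality == 'language' else self.encoder_blocks

    class DummyOpt:
        def __init__(self, model, modality='language'):
            self.blocks = model.get_blocks(modality)

        def run_block_loop(self):
            for block in self.blocks:
                print(f'quantizing {block}')

    model = DummyModel()
    calib_type = 'img_txt'
    if calib_type == 'img_txt':
        DummyOpt(model, 'vision').run_block_loop()    # vision encoder pass first
    DummyOpt(model, 'language').run_block_loop()      # then the language blocks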

llmc/compression/blockwise_optimization.py

Lines changed: 3 additions & 2 deletions
@@ -6,9 +6,10 @@
 
 
 class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality):
         self.model = model
-        self.blocks = model.get_blocks()
+        self.modality = modality
+        self.blocks = model.get_blocks(modality)
         self.quant_config = quant_config
         self.sparsity_config = quant_config
         self.input = input

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 4 additions & 3 deletions
@@ -27,8 +27,8 @@
 
 
 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.set_quant_config()
 
     def w_qdq(self, module, wquantizer):
@@ -436,7 +436,8 @@ def run(self, block, input_feat, handles):
 
     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block) \
+            if self.modality == 'language' else self.model.get_encoder_subsets_in_block(block)
 
         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)
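
Style note (not part of this commit): the backslash continuation above can equivalently be written with parentheses, which many linters prefer:

    subsets = (self.model.get_subsets_in_block(block)
               if self.modality == 'language'
               else self.model.get_encoder_subsets_in_block(block))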

llmc/compression/quantization/rtn.py

Lines changed: 6 additions & 6 deletions
@@ -8,13 +8,13 @@
 
 @ALGO_REGISTRY
 class RTN(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
 
-    @torch.no_grad()
-    def block_opt(self, *opt_kwargs):
-        if self.act_static:
-            super().block_opt(*opt_kwargs)
+    # @torch.no_grad()
+    # def block_opt(self, *opt_kwargs):
+    #     if self.act_static:
+    #         super().block_opt(*opt_kwargs)
 
     @torch.no_grad()
     def subset_transform(
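
The @ALGO_REGISTRY decorator is what lets __main__.py instantiate the algorithm via ALGO_REGISTRY[config.quant.method]. A minimal sketch of such a registry pattern follows; llmc's actual Registry implementation may differ in its details.

    # Minimal registry sketch; llmc's real ALGO_REGISTRY may be implemented differently.
    class Registry(dict):
        def __call__(self, cls):
            self[cls.__name__] = cls   # register under the class name, e.g. 'RTN'
            return cls

    ALGO_REGISTRY = Registry()

    @ALGO_REGISTRY
    class RTN:
        pass

    assert ALGO_REGISTRY['RTN'] is RTN   # lookup by the name given in the config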

llmc/data/dataset/base_dataset.py

Lines changed: 4 additions & 3 deletions
@@ -255,9 +255,10 @@ def get_calib_dataset(self):
         samples = self.get_calib_samples()
         if self.calib_dataset_type in ['txt', 'img', 'img_txt']:
             logger.info(f'len(samples) all : {len(samples)}')
-            assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
-            samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
-            logger.info(f'len(samples) rank : {len(samples)}')
+            if os.environ.get('WORLD_SIZE') is not None:
+                assert len(samples) % int(os.environ['WORLD_SIZE']) == 0
+                samples = samples[int(os.environ['RANK'])::int(os.environ['WORLD_SIZE'])]
+                logger.info(f'len(samples) rank : {len(samples)}')
         calib_samples = []
         if self.calib_dataset_type == 'txt':
             if self.padding:
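
The added guard keeps single-process runs working when torchrun's WORLD_SIZE/RANK variables are unset. A self-contained sketch of the rank-strided split (the sample list here is a toy placeholder):

    import os

    samples = list(range(8))                      # toy stand-in for calibration samples
    if os.environ.get('WORLD_SIZE') is not None:
        world_size = int(os.environ['WORLD_SIZE'])
        rank = int(os.environ['RANK'])
        assert len(samples) % world_size == 0     # every rank gets an equal share
        samples = samples[rank::world_size]       # strided split: rank, rank + ws, ...
    print(len(samples))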

llmc/models/base_model.py

Lines changed: 58 additions & 6 deletions
@@ -28,6 +28,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.build_model()
         self.model.eval()
         self.find_blocks()
+        self.find_encoder_blocks()
         self.find_embed_layers()
         self.find_block_name()
         self.add_layernorms_class()
@@ -36,14 +37,20 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
     def find_blocks(self):
         pass
 
+    def find_encoder_blocks(self):
+        pass
+
+    def get_encoder_catcher(self, first_block_input):
+        pass
+
     def find_block_name(self):
         pass
 
     def get_model(self):
         return self.model
 
-    def get_blocks(self):
-        return self.blocks
+    def get_blocks(self, modality='language'):
+        return self.blocks if modality == 'language' else self.encoder_blocks
 
     @abstractmethod
     def find_embed_layers(self):
@@ -186,6 +193,43 @@ def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=
         self.blocks[0] = self.blocks[0].cpu()
         self.move_embed_to_device('cpu')
 
+    @torch.no_grad()
+    def collect_first_encoder_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'):  # noqa
+        first_block_input = defaultdict(list)
+
+        Catcher = self.get_encoder_catcher(first_block_input)
+
+        self.move_embed_to_device('cuda')
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.to('cuda')
+            self.projector = self.projector.to('cuda')
+        self.encoder_blocks[0] = self.encoder_blocks[0].cuda()
+        self.encoder_blocks[0] = Catcher(self.encoder_blocks[0])
+
+        for data in calib_data:
+            if isinstance(data, BatchFeature):
+                data = data.to(next(self.model.parameters()).device)
+            else:
+                data = {
+                    k: (v.to(next(self.model.parameters()).device) if torch.is_tensor(v) else v)
+                    for k, v in data.items()
+                }
+            try:
+                if data_type in ['txt', 'img']:
+                    self.model(**data)
+                elif data_type == 'img_txt':
+                    self.vlm_model.generate(**data, max_new_tokens=128, do_sample=False)
+            except ValueError:
+                pass
+        self.first_block_input = first_block_input
+        self.padding_mask = None
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.cpu()
+            self.projector = self.projector.cpu()
+        self.encoder_blocks[0] = self.encoder_blocks[0].module
+        self.encoder_blocks[0] = self.encoder_blocks[0].cpu()
+        self.move_embed_to_device('cpu')
+
     def get_one_pad_setting(self, padding_side, length):
         if padding_side == 'left':
             return [0, length]
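
collect_first_encoder_block_input relies on a "Catcher" wrapper: it records the inputs that reach the first encoder block and then raises ValueError so the rest of the forward pass is skipped; the except ValueError above swallows that signal. A self-contained sketch of the trick on a toy module (not llmc's classes):

    from collections import defaultdict

    import torch
    import torch.nn as nn

    class Catcher(nn.Module):
        # Records the first block's input, then aborts the forward pass early.
        def __init__(self, module, first_block_input):
            super().__init__()
            self.module = module
            self.first_block_input = first_block_input

        def forward(self, hidden_states, **kwargs):
            self.first_block_input['data'].append(hidden_states)
            self.first_block_input['kwargs'].append(kwargs)
            raise ValueError  # stop before running this block and everything after it

    toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))   # stand-in "encoder"
    first_block_input = defaultdict(list)
    toy[0] = Catcher(toy[0], first_block_input)
    try:
        toy(torch.randn(2, 4))
    except ValueError:
        pass                                                 # expected early exit
    print(first_block_input['data'][0].shape)                # torch.Size([2, 4])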
@@ -280,17 +324,25 @@ def set_mix_bits_params_dict(self, block_idx, name, params_dict):
             params_mix_dict['a_qdq'] = None
         return params_mix_dict
 
-    def replace_module_all(self, module, params_dict, keep_device=False):
-        for block_idx in range(len(self.blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
-            block = self.blocks[block_idx]
+    def replace_modality_module_all(self, module, blocks, params_dict, keep_device=False):
+        for block_idx in range(len(blocks)):
+            logger.info(f'Replace block index: {block_idx}/{len(blocks)}')
+            block = blocks[block_idx]
             if keep_device:
                 self.replace_module_block(module, block, block_idx, params_dict)
             else:
                 block = block.cuda()
                 self.replace_module_block(module, block, block_idx, params_dict)
                 block = block.cpu()
 
+    def replace_module_all(self, module, params_dict, keep_device=False):
+        if hasattr(self, 'encoder_blocks'):
+            logger.info('start replace vision blocks')
+            self.replace_modality_module_all(module, self.encoder_blocks, params_dict, keep_device)
+
+        logger.info('start replace language blocks')
+        self.replace_modality_module_all(module, self.blocks, params_dict, keep_device)
+
         gc.collect()
         torch.cuda.empty_cache()
         logger.info(f'The Replaced model: {self.model}')

llmc/models/llava.py

Lines changed: 66 additions & 0 deletions
@@ -1,3 +1,6 @@
+import inspect
+
+import torch.nn as nn
 from loguru import logger
 from PIL import Image
 from transformers import (AutoConfig, AutoProcessor,
@@ -31,6 +34,31 @@ def build_model(self):
         self.model = self.vlm_model.language_model
         self.model_config = self.vlm_model_config.text_config
 
+    def find_encoder_blocks(self):
+        self.encoder_blocks = self.vision_model.vision_model.encoder.layers
+
+    def get_encoder_catcher(self, first_block_input):
+
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+                self.signature = inspect.signature(module.forward)
+
+            def forward(self, *args, **kwargs):
+                params = list(self.signature.parameters.keys())
+                for i, arg in enumerate(args):
+                    if i > 0:
+                        kwargs[params[i]] = arg
+                first_block_input['data'].append(args[0])
+                if 'output_router_logits' in kwargs:
+                    assert kwargs['output_router_logits'] is False
+                    kwargs.pop('output_router_logits')
+                first_block_input['kwargs'].append(kwargs)
+                raise ValueError
+
+        return Catcher
+
     def batch_process(self, img_qas):
         if len(img_qas) == 1:
             return self.single_process(img_qas[0])
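
Besides capturing inputs, this Catcher uses inspect.signature to fold positional arguments into kwargs so the recorded kwargs can later be replayed against the block. A small self-contained sketch of that mapping (the toy forward signature below is an assumption, loosely modeled on a CLIP encoder layer):

    import inspect

    def forward(hidden_states, attention_mask=None, causal_attention_mask=None):
        pass  # toy signature for illustration only

    params = list(inspect.signature(forward).parameters.keys())
    # ['hidden_states', 'attention_mask', 'causal_attention_mask']

    args, kwargs = ('HS', 'MASK'), {}
    for i, arg in enumerate(args):
        if i > 0:                       # args[0] is stored separately as the block input
            kwargs[params[i]] = arg
    print(kwargs)                       # {'attention_mask': 'MASK'}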
@@ -83,3 +111,41 @@ def single_process(self, img_qas):
             return_tensors='pt'
         ).to(next(self.vlm_model.parameters()).dtype)  # noqa
         return inputs
+
+    def get_encoder_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'self_attn.q_proj': block.self_attn.q_proj,
+                    'self_attn.k_proj': block.self_attn.k_proj,
+                    'self_attn.v_proj': block.self_attn.v_proj,
+                },
+                'prev_op': [block.layer_norm1],
+                'input': ['self_attn.q_proj'],
+                'inspect': block.self_attn,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'self_attn.out_proj': block.self_attn.out_proj},
+                'prev_op': [block.self_attn.v_proj],
+                'input': ['self_attn.out_proj'],
+                'inspect': block.self_attn.out_proj,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {'mlp.fc1': block.mlp.fc1},
+                'prev_op': [block.layer_norm2],
+                'input': ['mlp.fc1'],
+                'inspect': block.mlp.fc1,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.fc2': block.mlp.fc2},
+                'prev_op': [block.mlp.fc1],
+                'input': ['mlp.fc2'],
+                'inspect': block.mlp.fc2,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]
