Skip to content

Commit 2bdc342

Browse files
update vlm models (#214)
* Support InternVL and Qwen-VL models.
* Follow-up fixes.

Co-authored-by: chengtao-lv <[email protected]>
1 parent b41b26a commit 2bdc342

File tree

3 files changed

+111
-1
lines changed

3 files changed

+111
-1
lines changed

llmc/compression/quantization/base_blockwise_quantization.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,9 @@ def block_forward(self, block, input_data=None):
375375
self.input['kwargs'][i][key] = \
376376
self.input['kwargs'][i][key].to(device=next(block.parameters()).device)
377377
with torch.no_grad():
378-
out = block(input_data[i], **self.input['kwargs'][i])[0]
378+
out = block(input_data[i], **self.input['kwargs'][i])
379+
if isinstance(out, tuple):
380+
out = out[0]
379381
output.append(out)
380382
return output
381383

llmc/models/internvl2.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,43 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
214214
**generation_config
215215
}
216216
return inputs
217+
218+
def find_blocks(self, modality='language'):
    """Select the transformer blocks to quantize for the given modality.

    Stores the chosen block list on ``self.blocks``. An unrecognized
    modality leaves ``self.blocks`` untouched, matching prior behavior.
    """
    sources = {
        'language': lambda: self.model.model.layers,
        'vision': lambda: self.vision_model.encoder.layers,
    }
    picker = sources.get(modality)
    if picker is not None:
        self.blocks = picker()
223+
224+
def get_vision_subsets_in_block(self, block):
    """Describe the quantizable layer subsets of one InternVL2 vision block.

    Returns a list of subset dicts (layers, prev_op, input, inspect,
    has_kwargs, and 'is_mlp' for the MLP entries) consumed by the
    blockwise quantization pipeline.
    """
    attn = block.attn
    mlp = block.mlp

    def _subset(layers, prev_op, input_names, inspect_mod, is_mlp=False):
        # Assemble one subset entry; only MLP subsets carry the 'is_mlp' key.
        entry = {
            'layers': layers,
            'prev_op': prev_op,
            'input': input_names,
            'inspect': inspect_mod,
            'has_kwargs': False,
        }
        if is_mlp:
            entry['is_mlp'] = True
        return entry

    return [
        _subset({'attn.qkv': attn.qkv}, [block.norm1], ['attn.qkv'], attn),
        _subset({'attn.proj': attn.proj}, [attn.qkv], ['attn.proj'], attn.proj),
        _subset({'mlp.fc1': mlp.fc1}, [block.norm2], ['mlp.fc1'], mlp.fc1,
                is_mlp=True),
        _subset({'mlp.fc2': mlp.fc2}, [mlp.fc1], ['mlp.fc2'], mlp.fc2,
                is_mlp=True),
    ]

llmc/models/qwen2vl.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import inspect
2+
3+
import torch.nn as nn
14
from loguru import logger
25
from transformers import AutoConfig, AutoProcessor
36

@@ -111,3 +114,68 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
111114
return_tensors='pt',
112115
).to(next(self.vlm_model.parameters()).dtype)
113116
return inputs
117+
118+
def find_blocks(self, modality='language'):
    """Select the transformer blocks to quantize for the given modality.

    Sets ``self.blocks`` to the language decoder layers or the vision
    tower blocks; any other modality is a silent no-op, as before.
    """
    if modality == 'vision':
        self.blocks = self.vision_model.blocks
    elif modality == 'language':
        self.blocks = self.model.model.layers
123+
124+
def get_vision_subsets_in_block(self, block):
    """Describe the quantizable layer subsets of one Qwen2-VL vision block.

    Returns a list of subset dicts (layers, prev_op, input, inspect,
    has_kwargs, and 'is_mlp' for the MLP entries). Only the fused
    attention subset is flagged ``has_kwargs=True``.
    """
    subsets = [
        {
            'layers': {'attn.qkv': block.attn.qkv},
            'prev_op': [block.norm1],
            'input': ['attn.qkv'],
            'inspect': block.attn,
            # Attention forward receives extra keyword arguments.
            'has_kwargs': True,
        },
        {
            'layers': {'attn.proj': block.attn.proj},
            'prev_op': [block.attn.qkv],
            'input': ['attn.proj'],
            'inspect': block.attn.proj,
            'has_kwargs': False,
        },
    ]
    # The two MLP projections share the same subset shape; build them in a loop.
    for name, prev, mod in (
        ('mlp.fc1', block.norm2, block.mlp.fc1),
        ('mlp.fc2', block.mlp.fc1, block.mlp.fc2),
    ):
        subsets.append({
            'layers': {name: mod},
            'prev_op': [prev],
            'input': [name],
            'inspect': mod,
            'has_kwargs': False,
            'is_mlp': True,
        })
    return subsets
159+
160+
def get_vision_catcher(self, first_block_input):
    """Build a Catcher class that records the first vision block's inputs.

    The returned ``nn.Module`` subclass wraps a block; on every forward
    call it appends the first positional argument to
    ``first_block_input['data']`` and the remaining arguments (normalized
    into keyword form) to ``first_block_input['kwargs']``, then raises
    ``ValueError`` so the caller can abort the full model forward early.
    """

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            # Re-expose the wrapped block's mlp so outer code can still reach it.
            self.mlp = module.mlp
            # Cache the wrapped forward's signature to name positional args.
            self.signature = inspect.signature(module.forward)

        def forward(self, *args, **kwargs):
            names = list(self.signature.parameters)
            # Fold every extra positional argument into kwargs under the
            # parameter name it would have bound to.
            for idx in range(1, len(args)):
                kwargs[names[idx]] = args[idx]
            first_block_input['data'].append(args[0])
            if 'output_router_logits' in kwargs:
                # Calibration never requests router logits; drop the flag.
                assert kwargs['output_router_logits'] is False
                kwargs.pop('output_router_logits')
            first_block_input['kwargs'].append(kwargs)
            # Sentinel abort: the caller catches this to stop the forward pass.
            raise ValueError

    return Catcher

0 commit comments

Comments (0)