diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index ba5be4b6a..8b7b2ad64 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -375,7 +375,9 @@ def block_forward(self, block, input_data=None):
                     self.input['kwargs'][i][key] = \
                         self.input['kwargs'][i][key].to(device=next(block.parameters()).device)
             with torch.no_grad():
-                out = block(input_data[i], **self.input['kwargs'][i])[0]
+                out = block(input_data[i], **self.input['kwargs'][i])
+                if isinstance(out, tuple):
+                    out = out[0]
                 output.append(out)
         return output
diff --git a/llmc/models/internvl2.py b/llmc/models/internvl2.py
index cad8aa905..b02a6621e 100644
--- a/llmc/models/internvl2.py
+++ b/llmc/models/internvl2.py
@@ -214,3 +214,43 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
             **generation_config
         }
         return inputs
+
+    def find_blocks(self, modality='language'):
+        if modality == 'language':
+            self.blocks = self.model.model.layers
+        elif modality == 'vision':
+            self.blocks = self.vision_model.encoder.layers
+
+    def get_vision_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {'attn.qkv': block.attn.qkv},
+                'prev_op': [block.norm1],
+                'input': ['attn.qkv'],
+                'inspect': block.attn,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {'attn.proj': block.attn.proj},
+                'prev_op': [block.attn.qkv],
+                'input': ['attn.proj'],
+                'inspect': block.attn.proj,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {'mlp.fc1': block.mlp.fc1},
+                'prev_op': [block.norm2],
+                'input': ['mlp.fc1'],
+                'inspect': block.mlp.fc1,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.fc2': block.mlp.fc2},
+                'prev_op': [block.mlp.fc1],
+                'input': ['mlp.fc2'],
+                'inspect': block.mlp.fc2,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]
diff --git a/llmc/models/qwen2vl.py b/llmc/models/qwen2vl.py
index 8daa2ac9b..1f5cd231f 100644
--- a/llmc/models/qwen2vl.py
+++ b/llmc/models/qwen2vl.py
@@ -1,3 +1,6 @@
+import inspect
+
+import torch.nn as nn
 from loguru import logger
 from transformers import AutoConfig, AutoProcessor
 
@@ -111,3 +114,68 @@ def batch_process(self, img_qas, calib_or_eval='eval'):
             return_tensors='pt',
         ).to(next(self.vlm_model.parameters()).dtype)
         return inputs
+
+    def find_blocks(self, modality='language'):
+        if modality == 'language':
+            self.blocks = self.model.model.layers
+        elif modality == 'vision':
+            self.blocks = self.vision_model.blocks
+
+    def get_vision_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'attn.qkv': block.attn.qkv,
+                },
+                'prev_op': [block.norm1],
+                'input': ['attn.qkv'],
+                'inspect': block.attn,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'attn.proj': block.attn.proj},
+                'prev_op': [block.attn.qkv],
+                'input': ['attn.proj'],
+                'inspect': block.attn.proj,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {'mlp.fc1': block.mlp.fc1},
+                'prev_op': [block.norm2],
+                'input': ['mlp.fc1'],
+                'inspect': block.mlp.fc1,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.fc2': block.mlp.fc2},
+                'prev_op': [block.mlp.fc1],
+                'input': ['mlp.fc2'],
+                'inspect': block.mlp.fc2,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]
+
+    def get_vision_catcher(self, first_block_input):
+
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+                self.mlp = self.module.mlp
+                self.signature = inspect.signature(module.forward)
+
+            def forward(self, *args, **kwargs):
+                params = list(self.signature.parameters.keys())
+                for i, arg in enumerate(args):
+                    if i > 0:
+                        kwargs[params[i]] = arg
+                first_block_input['data'].append(args[0])
+                if 'output_router_logits' in kwargs:
+                    assert kwargs['output_router_logits'] is False
+                    kwargs.pop('output_router_logits')
+                first_block_input['kwargs'].append(kwargs)
+                raise ValueError
+
+        return Catcher
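
Reviewer sketch (not part of the diff): the snippet below is a self-contained illustration of the Catcher pattern returned by get_vision_catcher. DummyVisionBlock, its forward signature, and the tensor shapes are assumptions standing in for a real Qwen2-VL vision block, and the output_router_logits handling is omitted for brevity. It shows how positional arguments after the first are remapped to keyword arguments via inspect.signature, how the first block's input and kwargs are recorded, and why the forward pass is deliberately aborted with ValueError once capture is complete.

    # Self-contained sketch; DummyVisionBlock and its signature are assumptions.
    import inspect

    import torch
    import torch.nn as nn


    class DummyVisionBlock(nn.Module):
        def __init__(self):
            super().__init__()
            self.mlp = nn.Linear(4, 4)

        def forward(self, hidden_states, cu_seqlens=None, rotary_pos_emb=None):
            return self.mlp(hidden_states)


    first_block_input = {'data': [], 'kwargs': []}


    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.mlp = self.module.mlp
            self.signature = inspect.signature(module.forward)

        def forward(self, *args, **kwargs):
            # Remap extra positional args onto their parameter names so the
            # captured inputs can later be replayed as keyword arguments.
            params = list(self.signature.parameters.keys())
            for i, arg in enumerate(args):
                if i > 0:
                    kwargs[params[i]] = arg
            first_block_input['data'].append(args[0])
            first_block_input['kwargs'].append(kwargs)
            # Abort the forward pass once the first block's inputs are captured.
            raise ValueError


    block = Catcher(DummyVisionBlock())
    try:
        block(torch.randn(2, 4), torch.tensor([0, 2]))
    except ValueError:
        pass

    print(list(first_block_input['kwargs'][0]))  # ['cu_seqlens']

In the real calibration path the Catcher presumably wraps self.blocks[0], the model runs on the calibration batches, the caller swallows the ValueError, and the captured data and kwargs are replayed through block_forward; the tuple-unwrapping change in the first hunk then covers vision blocks that return a bare tensor instead of a tuple.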