diff --git a/llmc/__main__.py b/llmc/__main__.py
index a2d0557a8..cfb7c752f 100644
--- a/llmc/__main__.py
+++ b/llmc/__main__.py
@@ -69,43 +69,50 @@ def main(config):
         for ppl_eval in eval_list:
             ppl = ppl_eval.eval(model)
             logger.info(f'{ppl_eval.dataset} ppl : {ppl}')
-
-    if not config.get('calib', False):
-        blockwise_opt = ALGO_REGISTRY[config.quant.method](
-            model,
-            quant_config=config.quant,
-            input=None,
-            padding_mask=None,
-            config=config
-        )
-        blockwise_opt.run_block_loop()
-        dist.barrier()
-    else:
-        dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
-        calib_data, padding_mask = dataset.get_calib_dataset()
-        padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
-        model.collect_first_block_input(calib_data, padding_mask, padding_side, config.calib.type)
-        del calib_data
-        gc.collect()
-        torch.cuda.empty_cache()
-        if not config.get('sparse', False):
+    for modality in config.quant.get('quant_objects', ['language']):
+        if not config.get('calib', False):
             blockwise_opt = ALGO_REGISTRY[config.quant.method](
                 model,
-                config.quant,
-                model.get_first_block_input(),
-                model.get_padding_mask(),
-                config
+                quant_config=config.quant,
+                input=None,
+                padding_mask=None,
+                config=config,
+                modality=modality,
             )
+            blockwise_opt.run_block_loop()
+            dist.barrier()
         else:
-            blockwise_opt = ALGO_REGISTRY[config.sparse.method](
-                model,
-                config.sparse,
-                model.get_first_block_input(),
-                model.get_padding_mask(),
-                config
-            )
-        blockwise_opt.run_block_loop()
-        dist.barrier()
+            dataset = BaseDataset(tokenizer.get_tokenizer(), config.calib, model.batch_process)
+            calib_data, padding_mask = dataset.get_calib_dataset()
+            padding_side = getattr(tokenizer.get_tokenizer(), 'padding_side', None)
+            model.collect_first_block_input(calib_data,
+                                            padding_mask,
+                                            padding_side,
+                                            config.calib.type,
+                                            modality)
+            del calib_data
+            gc.collect()
+            torch.cuda.empty_cache()
+            if not config.get('sparse', False):
+                blockwise_opt = ALGO_REGISTRY[config.quant.method](
+                    model,
+                    config.quant,
+                    model.get_first_block_input(),
+                    model.get_padding_mask(),
+                    config,
+                    modality
+                )
+            else:
+                blockwise_opt = ALGO_REGISTRY[config.sparse.method](
+                    model,
+                    config.sparse,
+                    model.get_first_block_input(),
+                    model.get_padding_mask(),
+                    config,
+                    modality
+                )
+            blockwise_opt.run_block_loop()
+            dist.barrier()

     if int(os.environ['RANK']) == 0:
         if 'eval' in config and 'transformed' in config.eval.eval_pos:
diff --git a/llmc/compression/blockwise_optimization.py b/llmc/compression/blockwise_optimization.py
index c00e83775..700732f7e 100644
--- a/llmc/compression/blockwise_optimization.py
+++ b/llmc/compression/blockwise_optimization.py
@@ -6,8 +6,10 @@


 class BlockwiseOpt(metaclass=ABCMeta):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
         self.model = model
+        self.modality = modality
+        self.model.find_blocks(modality)
         self.blocks = model.get_blocks()
         self.quant_config = quant_config
         self.sparsity_config = quant_config
diff --git a/llmc/compression/quantization/adadim.py b/llmc/compression/quantization/adadim.py
index 8d481dc61..1b5fe1e10 100644
--- a/llmc/compression/quantization/adadim.py
+++ b/llmc/compression/quantization/adadim.py
@@ -9,8 +9,8 @@

 @ALGO_REGISTRY
 class AdaDim(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, config):
-        super().__init__(model, quant_config, input, config)
+    def __init__(self, model, quant_config, input, config, modality='language'):
+        super().__init__(model, quant_config, input, config, modality)

     def get_layer_out(self, x, layer):
         with torch.no_grad():
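The `__main__.py` and `BlockwiseOpt` hunks above together define the new dispatch: the quant config's `quant_objects` list (defaulting to `['language']`) drives one full blockwise pass per modality, and the optimizer base class re-resolves the model's blocks for whichever modality it was handed. A minimal sketch of that flow, using hypothetical `ToyModel`/`ToyOpt` stand-ins rather than the real llmc classes:

```python
# Illustrative sketch only -- not llmc code. It mirrors the dispatch added in
# __main__.py and BlockwiseOpt.__init__: one blockwise pass per modality, with
# the block list re-resolved for each pass. ToyModel/ToyOpt are hypothetical.
class ToyModel:
    def __init__(self):
        self.language_blocks = ['decoder_layer_0', 'decoder_layer_1']
        self.vision_blocks = ['vision_encoder_layer_0']
        self.blocks = None

    def find_blocks(self, modality='language'):
        # Same contract as BaseModel.find_blocks after this patch:
        # pick the block list that matches the requested modality.
        if modality == 'language':
            self.blocks = self.language_blocks
        elif modality == 'vision':
            self.blocks = self.vision_blocks

    def get_blocks(self):
        return self.blocks


class ToyOpt:
    def __init__(self, model, modality='language'):
        # BlockwiseOpt now stores the modality and calls find_blocks(modality)
        # before reading get_blocks(), so each pass sees the right layers.
        self.modality = modality
        model.find_blocks(modality)
        self.blocks = model.get_blocks()

    def run_block_loop(self):
        for block in self.blocks:
            print(f'[{self.modality}] optimizing {block}')


config = {'quant_objects': ['language', 'vision']}  # hypothetical config entry
model = ToyModel()
for modality in config.get('quant_objects', ['language']):
    ToyOpt(model, modality=modality).run_block_loop()
```

With `quant_objects` listing both modalities, the same algorithm class is constructed twice, once over the language decoder layers and once over the vision encoder layers.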
diff --git a/llmc/compression/quantization/awq.py b/llmc/compression/quantization/awq.py
index db2a9cf62..b5de0b6f7 100644
--- a/llmc/compression/quantization/awq.py
+++ b/llmc/compression/quantization/awq.py
@@ -17,8 +17,8 @@

 @ALGO_REGISTRY
 class Awq(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         special_config = self.quant_config.get('special', {})
         self.trans = special_config.get('trans', True)
         self.trans_version = special_config.get('trans_version', 'v2')
diff --git a/llmc/compression/quantization/base_blockwise_quantization.py b/llmc/compression/quantization/base_blockwise_quantization.py
index 1daef540e..10d97ce07 100644
--- a/llmc/compression/quantization/base_blockwise_quantization.py
+++ b/llmc/compression/quantization/base_blockwise_quantization.py
@@ -27,8 +27,8 @@


 class BaseBlockwiseQuantization(BlockwiseOpt):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.set_quant_config()

     def w_qdq(self, module, wquantizer):
@@ -439,7 +439,8 @@ def run(self, block, input_feat, handles):

     def block_transform(self, block, input_feat, block_kwargs):
         logger.info(f'Start transform the {self.block_idx}-th block')
-        subsets = self.model.get_subsets_in_block(block)
+        subsets = self.model.get_subsets_in_block(block) \
+            if self.modality == 'language' else self.model.get_vision_subsets_in_block(block)

         if self.act_static:
             self.register_non_linear_qparams(block, input_feat)
diff --git a/llmc/compression/quantization/dgq.py b/llmc/compression/quantization/dgq.py
index 4109065d5..9342a5911 100644
--- a/llmc/compression/quantization/dgq.py
+++ b/llmc/compression/quantization/dgq.py
@@ -13,8 +13,8 @@

 @ALGO_REGISTRY
 class DGQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.model_dtype = next(self.model.model.parameters()).dtype

     def w_qdq(self, module, wquantizer):
diff --git a/llmc/compression/quantization/gptq.py b/llmc/compression/quantization/gptq.py
index 9183df78c..3db44f56b 100644
--- a/llmc/compression/quantization/gptq.py
+++ b/llmc/compression/quantization/gptq.py
@@ -17,8 +17,8 @@

 @ALGO_REGISTRY
 class GPTQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.dev = torch.device('cuda')
         self.model_dtype = next(self.model.model.parameters()).dtype
         self.add_quant_config()
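The `block_transform` hunk above routes subset discovery through `get_subsets_in_block` for the language pass and `get_vision_subsets_in_block` for the vision pass. A small sketch of that dispatch in isolation; the two method names come from the hunk, the toy classes are hypothetical:

```python
# Toy illustration of the subset dispatch added to block_transform above; the
# two method names are from the hunk, the classes here are hypothetical.
class ToyVLM:
    def get_subsets_in_block(self, block):
        return [{'layers': {'self_attn.q_proj': f'{block}.language.q_proj'}}]

    def get_vision_subsets_in_block(self, block):
        return [{'layers': {'self_attn.q_proj': f'{block}.vision.q_proj'}}]


class ToyBlockwiseQuant:
    def __init__(self, model, modality='language'):
        self.model = model
        self.modality = modality

    def block_transform(self, block):
        # Mirrors the added conditional: language passes keep the old helper,
        # vision passes use the new per-block vision schema.
        subsets = self.model.get_subsets_in_block(block) \
            if self.modality == 'language' \
            else self.model.get_vision_subsets_in_block(block)
        return subsets


model = ToyVLM()
print(ToyBlockwiseQuant(model, 'language').block_transform('block0'))
print(ToyBlockwiseQuant(model, 'vision').block_transform('block0'))
```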
diff --git a/llmc/compression/quantization/hqq.py b/llmc/compression/quantization/hqq.py
index 0077c401b..1c98d55bd 100644
--- a/llmc/compression/quantization/hqq.py
+++ b/llmc/compression/quantization/hqq.py
@@ -11,8 +11,8 @@

 @ALGO_REGISTRY
 class HQQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

     @torch.no_grad()
diff --git a/llmc/compression/quantization/llmint8.py b/llmc/compression/quantization/llmint8.py
index 116b8843e..8321bd4e2 100644
--- a/llmc/compression/quantization/llmint8.py
+++ b/llmc/compression/quantization/llmint8.py
@@ -9,8 +9,8 @@

 @ALGO_REGISTRY
 class LlmInt8(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

     @torch.no_grad()
diff --git a/llmc/compression/quantization/ntweak.py b/llmc/compression/quantization/ntweak.py
index b758bf2a8..c61a4a281 100644
--- a/llmc/compression/quantization/ntweak.py
+++ b/llmc/compression/quantization/ntweak.py
@@ -19,8 +19,8 @@

 @ALGO_REGISTRY
 class NormTweaking(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()
         model_type = self.config['model']['type']
diff --git a/llmc/compression/quantization/omniq.py b/llmc/compression/quantization/omniq.py
index b2c432b58..045e95093 100644
--- a/llmc/compression/quantization/omniq.py
+++ b/llmc/compression/quantization/omniq.py
@@ -25,8 +25,8 @@

 @ALGO_REGISTRY
 class OmniQuant(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()
         model_type = self.config['model']['type']
diff --git a/llmc/compression/quantization/osplus.py b/llmc/compression/quantization/osplus.py
index 347e62010..e4d1da7e5 100644
--- a/llmc/compression/quantization/osplus.py
+++ b/llmc/compression/quantization/osplus.py
@@ -17,9 +17,9 @@

 @ALGO_REGISTRY
 class OsPlus(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
         torch.set_grad_enabled(False)
-        super().__init__(model, quant_config, input, padding_mask, config)
+        super().__init__(model, quant_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def filter_subset(self, layers_dict, prev_op):
diff --git a/llmc/compression/quantization/quarot.py b/llmc/compression/quantization/quarot.py
index 8e6c1d2df..a3cd02e35 100644
--- a/llmc/compression/quantization/quarot.py
+++ b/llmc/compression/quantization/quarot.py
@@ -16,8 +16,8 @@

 @ALGO_REGISTRY
 class Quarot(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.dev = torch.device('cuda')
         self.add_quant_config()
         self.preprocess()
diff --git a/llmc/compression/quantization/quik.py b/llmc/compression/quantization/quik.py
index 3a1e0441b..57693647b 100644
--- a/llmc/compression/quantization/quik.py
+++ b/llmc/compression/quantization/quik.py
@@ -12,8 +12,8 @@

 @ALGO_REGISTRY
 class QUIK(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()

     def add_quant_config(self):
diff --git a/llmc/compression/quantization/rtn.py b/llmc/compression/quantization/rtn.py
index 35ba840cf..966524c87 100644
--- a/llmc/compression/quantization/rtn.py
+++ b/llmc/compression/quantization/rtn.py
@@ -8,8 +8,8 @@

 @ALGO_REGISTRY
 class RTN(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def block_opt(self, *opt_kwargs):
diff --git a/llmc/compression/quantization/smoothquant.py b/llmc/compression/quantization/smoothquant.py
index 061b65068..de0c4710a 100644
--- a/llmc/compression/quantization/smoothquant.py
+++ b/llmc/compression/quantization/smoothquant.py
@@ -12,8 +12,8 @@

 @ALGO_REGISTRY
 class SmoothQuant(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         special_config = self.quant_config.get('special', {})
         self.alpha = special_config.get('alpha', 0.5)
diff --git a/llmc/compression/quantization/spqr.py b/llmc/compression/quantization/spqr.py
index 51ee90742..3260612f0 100644
--- a/llmc/compression/quantization/spqr.py
+++ b/llmc/compression/quantization/spqr.py
@@ -18,8 +18,8 @@

 @ALGO_REGISTRY
 class SpQR(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         assert (
             self.wquantizer.granularity == 'per_group'
         ), 'SpQR only supports per_group quantization'
diff --git a/llmc/compression/quantization/tesseraq.py b/llmc/compression/quantization/tesseraq.py
index e128373f2..e4be2ab87 100644
--- a/llmc/compression/quantization/tesseraq.py
+++ b/llmc/compression/quantization/tesseraq.py
@@ -23,8 +23,8 @@

 @ALGO_REGISTRY
 class TesseraQ(BaseBlockwiseQuantization):
-    def __init__(self, model, quant_config, input, padding_mask, config):
-        super().__init__(model, quant_config, input, padding_mask, config)
+    def __init__(self, model, quant_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, quant_config, input, padding_mask, config, modality)
         self.add_quant_config()
         self.attention_mask = self.input['kwargs'][0].get('attention_mask')
diff --git a/llmc/compression/sparsification/base_blockwise_sparsification.py b/llmc/compression/sparsification/base_blockwise_sparsification.py
index e62f5f272..1f0dbda89 100644
--- a/llmc/compression/sparsification/base_blockwise_sparsification.py
+++ b/llmc/compression/sparsification/base_blockwise_sparsification.py
@@ -12,8 +12,8 @@


 class BaseBlockwiseSparsification(BlockwiseOpt):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)
         self.set_sparsity_config()

     def block_init(self, block):
diff --git a/llmc/compression/sparsification/magnitude.py b/llmc/compression/sparsification/magnitude.py
index 8f36b295d..af3712c45 100644
--- a/llmc/compression/sparsification/magnitude.py
+++ b/llmc/compression/sparsification/magnitude.py
@@ -8,8 +8,8 @@

 @ALGO_REGISTRY
 class Magnitude(BaseBlockwiseSparsification):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def subset_transform(
diff --git a/llmc/compression/sparsification/shortgpt.py b/llmc/compression/sparsification/shortgpt.py
index c8c8dc410..14f9b4ddf 100644
--- a/llmc/compression/sparsification/shortgpt.py
+++ b/llmc/compression/sparsification/shortgpt.py
@@ -17,8 +17,8 @@

 @ALGO_REGISTRY
 class ShortGPT(BaseBlockwiseSparsification):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)

     def block_opt(self, block):
         block = block.cuda()
diff --git a/llmc/compression/sparsification/wanda.py b/llmc/compression/sparsification/wanda.py
index 951e58dab..c0cfe710c 100644
--- a/llmc/compression/sparsification/wanda.py
+++ b/llmc/compression/sparsification/wanda.py
@@ -9,8 +9,8 @@

 @ALGO_REGISTRY
 class Wanda(BaseBlockwiseSparsification):
-    def __init__(self, model, sparsity_config, input, padding_mask, config):
-        super().__init__(model, sparsity_config, input, padding_mask, config)
+    def __init__(self, model, sparsity_config, input, padding_mask, config, modality='language'):
+        super().__init__(model, sparsity_config, input, padding_mask, config, modality)

     @torch.no_grad()
     def get_row_scale(self, layer, act):
diff --git a/llmc/models/base_model.py b/llmc/models/base_model.py
index 97e5b0fad..1d4e021a2 100644
--- a/llmc/models/base_model.py
+++ b/llmc/models/base_model.py
@@ -1,4 +1,5 @@
 import gc
+import inspect
 from abc import ABCMeta, abstractmethod
 from collections import defaultdict
 from functools import partial
@@ -33,7 +34,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.add_layernorms_class()

     @abstractmethod
-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         pass

     def find_block_name(self):
@@ -87,7 +88,29 @@ def get_attention_rotary_layers(self):
     def batch_process(self):
         raise Exception('batch_process should not be called here.')

-    def get_catcher(self, first_block_input):
+    def get_vision_catcher(self, first_block_input):
+
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+                self.signature = inspect.signature(module.forward)
+
+            def forward(self, *args, **kwargs):
+                params = list(self.signature.parameters.keys())
+                for i, arg in enumerate(args):
+                    if i > 0:
+                        kwargs[params[i]] = arg
+                first_block_input['data'].append(args[0])
+                if 'output_router_logits' in kwargs:
+                    assert kwargs['output_router_logits'] is False
+                    kwargs.pop('output_router_logits')
+                first_block_input['kwargs'].append(kwargs)
+                raise ValueError
+
+        return Catcher
+
+    def get_language_catcher(self, first_block_input):

         class Catcher(nn.Module):
             def __init__(self, module):
@@ -138,10 +161,15 @@ def add_layernorms_class(self):
         logger.info(f'_TRANSFORMERS_LN_TYPES_ : {_TRANSFORMERS_LN_TYPES_}')

     @torch.no_grad()
-    def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'): # noqa
+    def collect_first_block_input(self, calib_data, padding_mask=None,
+                                  padding_side=None, data_type='txt', modality='language'):
         first_block_input = defaultdict(list)

-        Catcher = self.get_catcher(first_block_input)
+        self.find_blocks(modality)
+        if modality == 'language':
+            Catcher = self.get_language_catcher(first_block_input)
+        elif modality == 'vision':
+            Catcher = self.get_vision_catcher(first_block_input)
         self.move_embed_to_device('cuda')
         if data_type == 'img_txt':
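The `get_vision_catcher` hunk is the subtle part of this file: the wrapper maps any positional arguments after the hidden states onto their parameter names via `inspect.signature`, records the first block's input, and then raises `ValueError` so the rest of the forward pass is skipped. A self-contained sketch of that mechanism; the `TinyVisionLayer` module and the tensor shapes are hypothetical, not llmc code:

```python
# Self-contained sketch of the Catcher trick used by get_vision_catcher above.
# TinyVisionLayer and the shapes are hypothetical; the real wrapper also strips
# 'output_router_logits' from the captured kwargs.
import inspect
from collections import defaultdict

import torch
import torch.nn as nn


class TinyVisionLayer(nn.Module):
    def forward(self, hidden_states, attention_mask=None, output_attentions=False):
        return hidden_states


def make_catcher(first_block_input):
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module
            self.signature = inspect.signature(module.forward)

        def forward(self, *args, **kwargs):
            # Map every positional argument after the hidden states onto its
            # parameter name, so the cached kwargs can be replayed later.
            params = list(self.signature.parameters.keys())
            for i, arg in enumerate(args):
                if i > 0:
                    kwargs[params[i]] = arg
            first_block_input['data'].append(args[0])
            first_block_input['kwargs'].append(kwargs)
            # Abort the forward pass: only the first block's input is needed.
            raise ValueError

    return Catcher


first_block_input = defaultdict(list)
wrapped = make_catcher(first_block_input)(TinyVisionLayer())
try:
    wrapped(torch.randn(1, 4, 8), torch.ones(1, 4))
except ValueError:
    pass
print(list(first_block_input['kwargs'][0]))  # ['attention_mask']
```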
diff --git a/llmc/models/bloom.py b/llmc/models/bloom.py
index 16980a87c..2b121dc7d 100644
--- a/llmc/models/bloom.py
+++ b/llmc/models/bloom.py
@@ -8,7 +8,7 @@ class Bloom(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.transformer.h

     def find_embed_layers(self):
diff --git a/llmc/models/deepseekv2.py b/llmc/models/deepseekv2.py
index cc610dd4a..e0a7e0990 100644
--- a/llmc/models/deepseekv2.py
+++ b/llmc/models/deepseekv2.py
@@ -8,7 +8,7 @@ class DeepseekV2(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/falcon.py b/llmc/models/falcon.py
index 90be4f2c2..2e9633fc8 100644
--- a/llmc/models/falcon.py
+++ b/llmc/models/falcon.py
@@ -8,7 +8,7 @@ class Falcon(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.transformer.h

     def find_embed_layers(self):
diff --git a/llmc/models/gemma2.py b/llmc/models/gemma2.py
index 47ffec702..c22f9bd3f 100644
--- a/llmc/models/gemma2.py
+++ b/llmc/models/gemma2.py
@@ -26,7 +26,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
             m.weight = nn.Parameter(w + 1.0)
             m.forward = MethodType(gemma2_rms_norm_forward, m)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/internlm2.py b/llmc/models/internlm2.py
index 39f1b57eb..c3fb9c79e 100644
--- a/llmc/models/internlm2.py
+++ b/llmc/models/internlm2.py
@@ -11,7 +11,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         global _TRANSFORMERS_LN_TYPES_
         _TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)]

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/llama.py b/llmc/models/llama.py
index 35e62bc86..6ad82a522 100644
--- a/llmc/models/llama.py
+++ b/llmc/models/llama.py
@@ -8,7 +8,7 @@ class Llama(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/llava.py b/llmc/models/llava.py
index f203f457a..a4eebe2c0 100644
--- a/llmc/models/llava.py
+++ b/llmc/models/llava.py
@@ -83,3 +83,47 @@ def single_process(self, img_qas):
             return_tensors='pt'
         ).to(next(self.vlm_model.parameters()).dtype) # noqa
         return inputs
+
+    def find_blocks(self, modality='language'):
+        if modality == 'language':
+            self.blocks = self.model.model.layers
+        elif modality == 'vision':
+            self.blocks = self.vision_model.vision_model.encoder.layers
+
+    def get_vision_subsets_in_block(self, block):
+        return [
+            {
+                'layers': {
+                    'self_attn.q_proj': block.self_attn.q_proj,
+                    'self_attn.k_proj': block.self_attn.k_proj,
+                    'self_attn.v_proj': block.self_attn.v_proj,
+                },
+                'prev_op': [block.layer_norm1],
+                'input': ['self_attn.q_proj'],
+                'inspect': block.self_attn,
+                'has_kwargs': True,
+            },
+            {
+                'layers': {'self_attn.out_proj': block.self_attn.out_proj},
+                'prev_op': [block.self_attn.v_proj],
+                'input': ['self_attn.out_proj'],
+                'inspect': block.self_attn.out_proj,
+                'has_kwargs': False,
+            },
+            {
+                'layers': {'mlp.fc1': block.mlp.fc1},
+                'prev_op': [block.layer_norm2],
+                'input': ['mlp.fc1'],
+                'inspect': block.mlp.fc1,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+            {
+                'layers': {'mlp.fc2': block.mlp.fc2},
+                'prev_op': [block.mlp.fc1],
+                'input': ['mlp.fc2'],
+                'inspect': block.mlp.fc2,
+                'has_kwargs': False,
+                'is_mlp': True,
+            },
+        ]
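`get_vision_subsets_in_block` describes each CLIP-style vision encoder layer with the same schema the language path already uses: `layers` names the Linear modules to quantize, `prev_op` is the module whose output feeds them, `input` names where activations are cached, and `inspect` is the module to re-run while collecting them. A toy walk-through of that schema on a hypothetical stand-in block (only two of the four subsets are reproduced):

```python
# Hypothetical walk-through of the subset schema returned by
# get_vision_subsets_in_block; the ToyVision* classes mimic a CLIP-style
# encoder layer and are not llmc code.
import torch.nn as nn


class ToyVisionAttention(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)
        self.out_proj = nn.Linear(dim, dim)


class ToyVisionMLP(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.fc1 = nn.Linear(dim, 4 * dim)
        self.fc2 = nn.Linear(4 * dim, dim)


class ToyVisionBlock(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.self_attn = ToyVisionAttention(dim)
        self.layer_norm1 = nn.LayerNorm(dim)
        self.mlp = ToyVisionMLP(dim)
        self.layer_norm2 = nn.LayerNorm(dim)


def vision_subsets(block):
    # Same metadata shape as the llava hunk above: 'layers' are the Linears to
    # quantize, 'prev_op' feeds them, 'input' names the cached activation,
    # 'inspect' is the module to re-run.
    return [
        {
            'layers': {
                'self_attn.q_proj': block.self_attn.q_proj,
                'self_attn.k_proj': block.self_attn.k_proj,
                'self_attn.v_proj': block.self_attn.v_proj,
            },
            'prev_op': [block.layer_norm1],
            'input': ['self_attn.q_proj'],
            'inspect': block.self_attn,
            'has_kwargs': True,
        },
        {
            'layers': {'mlp.fc1': block.mlp.fc1},
            'prev_op': [block.layer_norm2],
            'input': ['mlp.fc1'],
            'inspect': block.mlp.fc1,
            'has_kwargs': False,
            'is_mlp': True,
        },
    ]


block = ToyVisionBlock()
for subset in vision_subsets(block):
    names = list(subset['layers'])
    prev = type(subset['prev_op'][0]).__name__
    print(f'{names} <- scaled against the output of {prev}')
```

The full list in the hunk also covers `self_attn.out_proj` (anchored on `v_proj`) and `mlp.fc2` (anchored on `fc1`).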
diff --git a/llmc/models/minicpm.py b/llmc/models/minicpm.py
index b8e52d0b7..793417541 100644
--- a/llmc/models/minicpm.py
+++ b/llmc/models/minicpm.py
@@ -11,7 +11,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         global _TRANSFORMERS_LN_TYPES_
         _TRANSFORMERS_LN_TYPES_ += [type(self.model.model.norm)]

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/mistral.py b/llmc/models/mistral.py
index 9e9787961..41a71f95d 100644
--- a/llmc/models/mistral.py
+++ b/llmc/models/mistral.py
@@ -8,7 +8,7 @@ class Mistral(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/mixtral.py b/llmc/models/mixtral.py
index fec0fcdb5..88a3c2b09 100644
--- a/llmc/models/mixtral.py
+++ b/llmc/models/mixtral.py
@@ -8,7 +8,7 @@ class Mixtral(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/opt.py b/llmc/models/opt.py
index 71e2f2114..ac5604e00 100644
--- a/llmc/models/opt.py
+++ b/llmc/models/opt.py
@@ -8,7 +8,7 @@ class Opt(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.decoder.layers

     def find_embed_layers(self):
diff --git a/llmc/models/phi.py b/llmc/models/phi.py
index d55518a5d..5bfef7af9 100644
--- a/llmc/models/phi.py
+++ b/llmc/models/phi.py
@@ -8,7 +8,7 @@ class Phi(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/qwen.py b/llmc/models/qwen.py
index 21d323ae6..48a53bc3c 100644
--- a/llmc/models/qwen.py
+++ b/llmc/models/qwen.py
@@ -8,7 +8,7 @@ class Qwen(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.transformer.h

     def find_embed_layers(self):
diff --git a/llmc/models/qwen2.py b/llmc/models/qwen2.py
index 25840decd..4c669f306 100644
--- a/llmc/models/qwen2.py
+++ b/llmc/models/qwen2.py
@@ -8,7 +8,7 @@ class Qwen2(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/qwen2moe.py b/llmc/models/qwen2moe.py
index 44e848fd3..57dca87ee 100644
--- a/llmc/models/qwen2moe.py
+++ b/llmc/models/qwen2moe.py
@@ -8,7 +8,7 @@ class Qwen2Moe(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/smollm.py b/llmc/models/smollm.py
index b58038d67..aaa562caa 100644
--- a/llmc/models/smollm.py
+++ b/llmc/models/smollm.py
@@ -8,7 +8,7 @@ class SmolLM(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers

     def find_embed_layers(self):
diff --git a/llmc/models/stablelm.py b/llmc/models/stablelm.py
index 6dcc59abe..5d6aa0570 100644
--- a/llmc/models/stablelm.py
+++ b/llmc/models/stablelm.py
@@ -8,7 +8,7 @@ class StableLm(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.model.layers
         self.rotary_emb = self.model.model.rotary_emb
diff --git a/llmc/models/starcoder.py b/llmc/models/starcoder.py
index 0a97f9b63..a8046c728 100644
--- a/llmc/models/starcoder.py
+++ b/llmc/models/starcoder.py
@@ -8,7 +8,7 @@ class Starcoder(BaseModel):
     def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         super().__init__(model_path, torch_dtype, device_map, use_cache)

-    def find_blocks(self):
+    def find_blocks(self, modality='language'):
         self.blocks = self.model.transformer.h

     def find_embed_layers(self):
diff --git a/llmc/models/vit.py b/llmc/models/vit.py
index 476ecd815..330f53f62 100644
--- a/llmc/models/vit.py
+++ b/llmc/models/vit.py
@@ -22,7 +22,7 @@ def build_model(self):
         self.processor = ViTImageProcessor.from_pretrained(self.model_path)
         self.model = ViTForImageClassification.from_pretrained(self.model_path)

-    def find_blocks(self):
+    def find_blocks(self, modality='vision'):
         self.blocks = self.model.vit.encoder.layer

     def find_embed_layers(self):