@@ -28,6 +28,7 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
         self.build_model()
         self.model.eval()
         self.find_blocks()
+        self.find_encoder_blocks()
         self.find_embed_layers()
         self.find_block_name()
         self.add_layernorms_class()
@@ -36,14 +37,20 @@ def __init__(self, model_path, torch_dtype, device_map=None, use_cache=False):
     def find_blocks(self):
         pass
 
+    def find_encoder_blocks(self):
+        pass
+
+    def get_encoder_catcher(self, first_block_input):
+        pass
+
     def find_block_name(self):
         pass
 
     def get_model(self):
         return self.model
 
-    def get_blocks(self):
-        return self.blocks
+    def get_blocks(self, modality='language'):
+        return self.blocks if modality == 'language' else self.encoder_blocks
 
     @abstractmethod
     def find_embed_layers(self):
@@ -186,6 +193,43 @@ def collect_first_block_input(self, calib_data, padding_mask=None, padding_side=
         self.blocks[0] = self.blocks[0].cpu()
         self.move_embed_to_device('cpu')
 
+    @torch.no_grad()
+    def collect_first_encoder_block_input(self, calib_data, padding_mask=None, padding_side=None, data_type='txt'):  # noqa
+        first_block_input = defaultdict(list)
+
+        Catcher = self.get_encoder_catcher(first_block_input)
+
+        self.move_embed_to_device('cuda')
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.to('cuda')
+            self.projector = self.projector.to('cuda')
+        self.encoder_blocks[0] = self.encoder_blocks[0].cuda()
+        self.encoder_blocks[0] = Catcher(self.encoder_blocks[0])
+
+        for data in calib_data:
+            if isinstance(data, BatchFeature):
+                data = data.to(next(self.model.parameters()).device)
+            else:
+                data = {
+                    k: (v.to(next(self.model.parameters()).device) if torch.is_tensor(v) else v)
+                    for k, v in data.items()
+                }
+            try:
+                if data_type in ['txt', 'img']:
+                    self.model(**data)
+                elif data_type == 'img_txt':
+                    self.vlm_model.generate(**data, max_new_tokens=128, do_sample=False)
+            except ValueError:
+                pass
+        self.first_block_input = first_block_input
+        self.padding_mask = None
+        if data_type == 'img_txt':
+            self.vision_model = self.vision_model.cpu()
+            self.projector = self.projector.cpu()
+        self.encoder_blocks[0] = self.encoder_blocks[0].module
+        self.encoder_blocks[0] = self.encoder_blocks[0].cpu()
+        self.move_embed_to_device('cpu')
+
     def get_one_pad_setting(self, padding_side, length):
         if padding_side == 'left':
             return [0, length]
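
Note for reviewers: collect_first_encoder_block_input mirrors the existing collect_first_block_input pattern. The object returned by get_encoder_catcher wraps the first encoder block, records the inputs that reach it, and raises ValueError to cut the forward pass short, which is why that exception is silently swallowed above. The base class leaves find_encoder_blocks and get_encoder_catcher as no-ops, so a multimodal subclass has to supply them. A minimal sketch of what that could look like; the subclass name, the vision_model.encoder.layers path, and the dictionary keys are illustrative assumptions, not code from this commit:

import torch.nn as nn

class MyVLM(BaseModel):  # BaseModel = the class modified in this diff; subclass name is hypothetical
    def find_encoder_blocks(self):
        # placeholder attribute path: point this at the vision tower's transformer layers
        self.encoder_blocks = self.vision_model.encoder.layers

    def get_encoder_catcher(self, first_block_input):
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module  # kept so `.module` can unwrap the block afterwards

            def forward(self, hidden_states, *args, **kwargs):
                # record whatever reaches the first encoder block, then abort the forward pass
                first_block_input['data'].append(hidden_states)
                first_block_input['kwargs'].append(kwargs)
                raise ValueError  # swallowed by collect_first_encoder_block_input

        return Catcher
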
@@ -280,17 +324,25 @@ def set_mix_bits_params_dict(self, block_idx, name, params_dict):
                 params_mix_dict['a_qdq'] = None
         return params_mix_dict
 
-    def replace_module_all(self, module, params_dict, keep_device=False):
-        for block_idx in range(len(self.blocks)):
-            logger.info(f'Replace block index: {block_idx}/{len(self.blocks)}')
-            block = self.blocks[block_idx]
+    def replace_modality_module_all(self, module, blocks, params_dict, keep_device=False):
+        for block_idx in range(len(blocks)):
+            logger.info(f'Replace block index: {block_idx}/{len(blocks)}')
+            block = blocks[block_idx]
             if keep_device:
                 self.replace_module_block(module, block, block_idx, params_dict)
             else:
                 block = block.cuda()
                 self.replace_module_block(module, block, block_idx, params_dict)
                 block = block.cpu()
 
+    def replace_module_all(self, module, params_dict, keep_device=False):
+        if hasattr(self, 'encoder_blocks'):
+            logger.info('start replace vision blocks')
+            self.replace_modality_module_all(module, self.encoder_blocks, params_dict, keep_device)
+
+        logger.info('start replace language blocks')
+        self.replace_modality_module_all(module, self.blocks, params_dict, keep_device)
+
         gc.collect()
         torch.cuda.empty_cache()
         logger.info(f'The Replaced model: {self.model}')
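
A rough usage sketch of the new modality-aware surface. Only the method names and the modality keyword come from this diff; MyVLM (as sketched above), the model path, calib_data, FakeQuantLinear, and the contents of params_dict are placeholders for whatever the caller already uses:

import torch

model = MyVLM('/path/to/vlm', torch.bfloat16)

# cache the inputs of the first language block and of the first vision encoder block
model.collect_first_block_input(calib_data)
model.collect_first_encoder_block_input(calib_data, data_type='img_txt')

language_blocks = model.get_blocks()                 # default modality='language'
vision_blocks = model.get_blocks(modality='vision')  # any non-'language' value returns encoder_blocks

# replace_module_all now walks encoder_blocks first (when present), then the language blocks
model.replace_module_all(FakeQuantLinear, params_dict)
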