1414
1515from  collections  import  OrderedDict 
1616from  itertools  import  combinations 
17- from  typing  import  List , Optional , Union 
17+ from  typing  import  List , Optional , Union ,  Dict ,  Any 
1818
1919import  torch 
20+ import  time 
21+ from  dataclasses  import  dataclass 
2022
2123from  ..utils  import  (
2224    is_accelerate_available ,
2325    logging ,
2426)
27+ from  ..models .modeling_utils  import  ModelMixin 
2528
2629
2730if  is_accelerate_available ():
@@ -95,9 +98,6 @@ def pre_forward(self, module, *args, **kwargs):
9598            if  self .other_hooks  is  not None :
9699                hooks_to_offload  =  [hook  for  hook  in  self .other_hooks  if  hook .model .device  ==  self .execution_device ]
97100                # offload all other hooks 
98-                 import  time 
99- 
100-                 # YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later) 
101101                start_time  =  time .perf_counter ()
102102                if  self .offload_strategy  is  not None :
103103                    hooks_to_offload  =  self .offload_strategy (
@@ -231,17 +231,27 @@ def search_best_candidate(module_sizes, min_memory_offload):
231231class  ComponentsManager :
232232    def  __init__ (self ):
233233        self .components  =  OrderedDict ()
234+         self .added_time  =  OrderedDict ()  # Store when components were added 
234235        self .model_hooks  =  None 
235236        self ._auto_offload_enabled  =  False 
236237
237238    def  add (self , name , component ):
238-         if  name  not  in self .components :
239-             self .components [name ] =  component 
240-             if  self ._auto_offload_enabled :
241-                 self .enable_auto_cpu_offload (self ._auto_offload_device )
239+         if  name  in  self .components :
240+             logger .warning (f"Overriding existing component '{ name }  )
241+         self .components [name ] =  component 
242+         self .added_time [name ] =  time .time ()
243+         
244+         if  self ._auto_offload_enabled :
245+             self .enable_auto_cpu_offload (self ._auto_offload_device )
242246
243247    def  remove (self , name ):
248+         if  name  not  in self .components :
249+             logger .warning (f"Component '{ name }  )
250+             return 
251+             
244252        self .components .pop (name )
253+         self .added_time .pop (name )
254+         
245255        if  self ._auto_offload_enabled :
246256            self .enable_auto_cpu_offload (self ._auto_offload_device )
247257
@@ -294,6 +304,61 @@ def disable_auto_cpu_offload(self):
294304        self .model_hooks  =  None 
295305        self ._auto_offload_enabled  =  False 
296306
307+     def  get_model_info (self , name : str ) ->  Optional [Dict [str , Any ]]:
308+         """Get comprehensive information about a model component. 
309+          
310+         Args: 
311+             name: Name of the component to get info for 
312+              
313+         Returns: 
314+             Dictionary containing model metadata including: 
315+             - model_id: Name of the model 
316+             - class_name: Class name of the model 
317+             - device: Device the model is on 
318+             - dtype: Data type of the model 
319+             - size_gb: Size of the model in GB 
320+             - added_time: Timestamp when model was added 
321+             - active_adapters: List of active adapters (if applicable) 
322+             - attn_proc: List of attention processor types (if applicable) 
323+             Returns None if component is not a torch.nn.Module 
324+         """ 
325+         if  name  not  in self .components :
326+             raise  ValueError (f"Component '{ name }  )
327+ 
328+         component  =  self .components [name ]
329+         
330+         # Only process torch.nn.Module components 
331+         if  not  isinstance (component , torch .nn .Module ):
332+             return  None 
333+ 
334+         info  =  {
335+             "model_id" : name ,
336+             "class_name" : component .__class__ .__name__ ,
337+             "device" : str (getattr (component , "device" , "N/A" )),
338+             "dtype" : str (component .dtype ) if  hasattr (component , "dtype" ) else  None ,
339+             "added_time" : self .added_time [name ],
340+             "size_gb" : get_memory_footprint (component ) /  (1024 ** 3 ),
341+             "active_adapters" : None ,  # Default to None 
342+         }
343+ 
344+         # Get active adapters if applicable 
345+         if  isinstance (component , ModelMixin ):
346+             from  peft .tuners .tuners_utils  import  BaseTunerLayer 
347+             for  module  in  component .modules ():
348+                 if  isinstance (module , BaseTunerLayer ):
349+                     info ["active_adapters" ] =  module .active_adapters 
350+                     break 
351+ 
352+         # Get attention processors if applicable 
353+         if  hasattr (component , "attn_processors" ):
354+             processors  =  component .attn_processors 
355+             # Get unique processor types 
356+             processor_types  =  list (set (str (v .__class__ .__name__ ) for  v  in  processors .values ()))
357+             if  processor_types :
358+                 info ["attn_proc" ] =  processor_types 
359+ 
360+         return  info 
361+ 
297362    def  __repr__ (self ):
298363        col_widths  =  {
299364            "id" : max (15 , max (len (id ) for  id  in  self .components .keys ())),
@@ -323,14 +388,12 @@ def __repr__(self):
323388
324389            # Model entries 
325390            for  name , component  in  models .items ():
326-                 device  =  component .device 
327-                 dtype  =  component .dtype 
328-                 size_bytes  =  get_memory_footprint (component )
329-                 size_gb  =  size_bytes  /  (1024 ** 3 )
330- 
331-                 output  +=  f"{ name :<{col_widths ['id' ]}} { component .__class__ .__name__ :<{col_widths ['class' ]}}  
391+                 info  =  self .get_model_info (name )
392+                 output  +=  f"{ name :<{col_widths ['id' ]}} { info ['class_name' ]:<{col_widths ['class' ]}}  
332393                output  +=  (
333-                     f"{ str (device ):<{col_widths ['device' ]}} { str (dtype ):<{col_widths ['dtype' ]}} { size_gb :.2f} \n " 
394+                     f"{ info ['device' ]:<{col_widths ['device' ]}}  
395+                     f"{ info ['dtype' ]:<{col_widths ['dtype' ]}}  
396+                     f"{ info ['size_gb' ]:.2f} \n " 
334397                )
335398            output  +=  dash_line 
336399
@@ -348,6 +411,18 @@ def __repr__(self):
348411                output  +=  f"{ name :<{col_widths ['id' ]}} { component .__class__ .__name__ :<{col_widths ['class' ]}} \n " 
349412            output  +=  dash_line 
350413
414+         # Add additional component info 
415+         output  +=  "\n Additional Component Info:\n "  +  "="  *  50  +  "\n " 
416+         for  name  in  self .components :
417+             info  =  self .get_model_info (name )
418+             if  info  is  not None  and  (info .get ("active_adapters" ) is  not None  or  info .get ("attn_proc" )):
419+                 output  +=  f"\n { name } \n " 
420+                 if  info .get ("active_adapters" ) is  not None :
421+                     output  +=  f"  Active Adapters: { info ['active_adapters' ]} \n " 
422+                 if  info .get ("attn_proc" ):
423+                     output  +=  f"  Attention Processors: { info ['attn_proc' ]} \n " 
424+                 output  +=  f"  Added Time: { time .strftime ('%Y-%m-%d %H:%M:%S' , time .localtime (info ['added_time' ]))} \n " 
425+         
351426        return  output 
352427
353428    def  add_from_pretrained (self , pretrained_model_name_or_path , ** kwargs ):
0 commit comments