@@ -304,59 +304,66 @@ def disable_auto_cpu_offload(self):
304304        self .model_hooks  =  None 
305305        self ._auto_offload_enabled  =  False 
306306
307-     def  get_model_info (self , name : str ) ->  Optional [Dict [str , Any ]]:
308-         """Get comprehensive information about a model  component. 
307+     def  get_model_info (self , name : str ,  fields :  Optional [ Union [ str ,  List [ str ]]]  =   None ) ->  Optional [Dict [str , Any ]]:
308+         """Get comprehensive information about a component. 
309309         
310310        Args: 
311311            name: Name of the component to get info for 
312-              
312+             fields: Optional field(s) to return. Can be a string for single field or list of fields. 
313+                    If None, returns all fields. 
314+                     
313315        Returns: 
314-             Dictionary containing model metadata including: 
315-             - model_id: Name of the model 
316-             - class_name: Class name of the model 
317-             - device: Device the model is on 
318-             - dtype: Data type of the model 
319-             - size_gb: Size of the model in GB 
320-             - added_time: Timestamp when model was added 
321-             - active_adapters: List of active adapters (if applicable) 
322-             - attn_proc: List of attention processor types (if applicable) 
323-             Returns None if component is not a torch.nn.Module 
316+             Dictionary containing requested component metadata. 
317+             If fields is specified, returns only those fields. 
318+             If a single field is requested as string, returns just that field's value. 
324319        """ 
325320        if  name  not  in self .components :
326321            raise  ValueError (f"Component '{ name }  )
327322
328323        component  =  self .components [name ]
329324
330-         # Only process torch.nn.Module components 
331-         if  not  isinstance (component , torch .nn .Module ):
332-             return  None 
333- 
325+         # Build complete info dict first 
334326        info  =  {
335327            "model_id" : name ,
336-             "class_name" : component .__class__ .__name__ ,
337-             "device" : str (getattr (component , "device" , "N/A" )),
338-             "dtype" : str (component .dtype ) if  hasattr (component , "dtype" ) else  None ,
339328            "added_time" : self .added_time [name ],
340-             "size_gb" : get_memory_footprint (component ) /  (1024 ** 3 ),
341-             "active_adapters" : None ,  # Default to None 
342329        }
343- 
344-         # Get active adapters if applicable 
345-         if  isinstance (component , ModelMixin ):
346-             from  peft .tuners .tuners_utils  import  BaseTunerLayer 
347-             for  module  in  component .modules ():
348-                 if  isinstance (module , BaseTunerLayer ):
349-                     info ["active_adapters" ] =  module .active_adapters 
350-                     break 
351- 
352-         # Get attention processors if applicable 
353-         if  hasattr (component , "attn_processors" ):
354-             processors  =  component .attn_processors 
355-             # Get unique processor types 
356-             processor_types  =  list (set (str (v .__class__ .__name__ ) for  v  in  processors .values ()))
357-             if  processor_types :
358-                 info ["attn_proc" ] =  processor_types 
359- 
330+         
331+         # Additional info for torch.nn.Module components 
332+         if  isinstance (component , torch .nn .Module ):
333+             info .update ({
334+                 "class_name" : component .__class__ .__name__ ,
335+                 "size_gb" : get_memory_footprint (component ) /  (1024 ** 3 ),
336+                 "adapters" : None ,  # Default to None 
337+             })
338+ 
339+             # Get adapters if applicable 
340+             if  hasattr (component , "peft_config" ):
341+                 info ["adapters" ] =  list (component .peft_config .keys ())
342+ 
343+             # Check for IP-Adapter scales 
344+             if  hasattr (component , "_load_ip_adapter_weights" ) and  hasattr (component , "attn_processors" ):
345+                 processors  =  component .attn_processors 
346+                 # First check if any processor is an IP-Adapter 
347+                 processor_types  =  [v .__class__ .__name__  for  v  in  processors .values ()]
348+                 if  any ("IPAdapter"  in  ptype  for  ptype  in  processor_types ):
349+                     # Then get scales only from IP-Adapter processors 
350+                     scales  =  {
351+                         k : v .scale  
352+                         for  k , v  in  processors .items () 
353+                         if  hasattr (v , "scale" ) and  "IPAdapter"  in  v .__class__ .__name__ 
354+                     }
355+                     if  scales :
356+                         info ["ip_adapter" ] =  summarize_dict_by_value_and_parts (scales )
357+ 
358+         # If fields specified, filter info 
359+         if  fields  is  not None :
360+             if  isinstance (fields , str ):
361+                 # Single field requested, return just that value 
362+                 return  {fields : info .get (fields )}
363+             else :
364+                 # List of fields requested, return dict with just those fields 
365+                 return  {k : v  for  k , v  in  info .items () if  k  in  fields }
366+             
360367        return  info 
361368
362369    def  __repr__ (self ):
@@ -383,18 +390,16 @@ def __repr__(self):
383390            output  +=  "Models:\n "  +  dash_line 
384391            # Column headers 
385392            output  +=  f"{ 'Model ID' :<{col_widths ['id' ]}} { 'Class' :<{col_widths ['class' ]}}  
386-             output  +=  f"{ 'Device' :<{col_widths ['device' ]}} { 'Dtype' :<{col_widths ['dtype' ]}}   \n " 
393+             output  +=  f"{ 'Device' :<{col_widths ['device' ]}} { 'Dtype' :<{col_widths ['dtype' ]}} \n " 
387394            output  +=  dash_line 
388395
389396            # Model entries 
390397            for  name , component  in  models .items ():
391398                info  =  self .get_model_info (name )
399+                 device  =  str (getattr (component , "device" , "N/A" ))
400+                 dtype  =  str (component .dtype ) if  hasattr (component , "dtype" ) else  "N/A" 
392401                output  +=  f"{ name :<{col_widths ['id' ]}} { info ['class_name' ]:<{col_widths ['class' ]}}  
393-                 output  +=  (
394-                     f"{ info ['device' ]:<{col_widths ['device' ]}}  
395-                     f"{ info ['dtype' ]:<{col_widths ['dtype' ]}}  
396-                     f"{ info ['size_gb' ]:.2f} \n " 
397-                 )
402+                 output  +=  f"{ device :<{col_widths ['device' ]}} { dtype :<{col_widths ['dtype' ]}} { info ['size_gb' ]:.2f} \n " 
398403            output  +=  dash_line 
399404
400405        # Other components section 
@@ -415,12 +420,12 @@ def __repr__(self):
415420        output  +=  "\n Additional Component Info:\n "  +  "="  *  50  +  "\n " 
416421        for  name  in  self .components :
417422            info  =  self .get_model_info (name )
418-             if  info  is  not None  and  (info .get ("active_adapters " ) is  not None  or  info .get ("attn_proc " )):
423+             if  info  is  not None  and  (info .get ("adapters " ) is  not None  or  info .get ("ip_adapter " )):
419424                output  +=  f"\n { name } \n " 
420-                 if  info .get ("active_adapters " ) is  not None :
421-                     output  +=  f"  Active  Adapters: { info ['active_adapters ' ]} \n " 
422-                 if  info .get ("attn_proc " ):
423-                     output  +=  f"  Attention Processors:  { info [ 'attn_proc' ] } \n " 
425+                 if  info .get ("adapters " ) is  not None :
426+                     output  +=  f"  Adapters: { info ['adapters ' ]} \n " 
427+                 if  info .get ("ip_adapter " ):
428+                     output  +=  f"  IP-Adapter: Enabled \n " 
424429                output  +=  f"  Added Time: { time .strftime ('%Y-%m-%d %H:%M:%S' , time .localtime (info ['added_time' ]))} \n " 
425430
426431        return  output 
@@ -438,3 +443,64 @@ def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs):
438443                    f"1. remove the existing component with remove('{ name } \n " 
439444                    f"2. Use a different name: add('{ name }  
440445                )
446+ 
447+ def  summarize_dict_by_value_and_parts (d : Dict [str , Any ]) ->  Dict [str , Any ]:
448+     """Summarizes a dictionary by finding common prefixes that share the same value. 
449+      
450+     For a dictionary with dot-separated keys like: 
451+     { 
452+         'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 
453+         'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 
454+         'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], 
455+     } 
456+      
457+     Returns a dictionary where keys are the shortest common prefixes and values are their shared values: 
458+     { 
459+         'down_blocks': [0.6], 
460+         'up_blocks': [0.3] 
461+     } 
462+     """ 
463+     # First group by values - convert lists to tuples to make them hashable 
464+     value_to_keys  =  {}
465+     for  key , value  in  d .items ():
466+         value_tuple  =  tuple (value ) if  isinstance (value , list ) else  value 
467+         if  value_tuple  not  in value_to_keys :
468+             value_to_keys [value_tuple ] =  []
469+         value_to_keys [value_tuple ].append (key )
470+     
471+     def  find_common_prefix (keys : List [str ]) ->  str :
472+         """Find the shortest common prefix among a list of dot-separated keys.""" 
473+         if  not  keys :
474+             return  "" 
475+         if  len (keys ) ==  1 :
476+             return  keys [0 ]
477+             
478+         # Split all keys into parts 
479+         key_parts  =  [k .split ('.' ) for  k  in  keys ]
480+         
481+         # Find how many initial parts are common 
482+         common_length  =  0 
483+         for  parts  in  zip (* key_parts ):
484+             if  len (set (parts )) ==  1 :  # All parts at this position are the same 
485+                 common_length  +=  1 
486+             else :
487+                 break 
488+                 
489+         if  common_length  ==  0 :
490+             return  "" 
491+             
492+         # Return the common prefix 
493+         return  '.' .join (key_parts [0 ][:common_length ])
494+ 
495+     # Create summary by finding common prefixes for each value group 
496+     summary  =  {}
497+     for  value_tuple , keys  in  value_to_keys .items ():
498+         prefix  =  find_common_prefix (keys )
499+         if  prefix :  # Only add if we found a common prefix 
500+             # Convert tuple back to list if it was originally a list 
501+             value  =  list (value_tuple ) if  isinstance (d [keys [0 ]], list ) else  value_tuple 
502+             summary [prefix ] =  value 
503+         else :
504+             summary ["" ] =  value   # Use empty string if no common prefix 
505+             
506+     return  summary 
0 commit comments