@@ -142,7 +142,7 @@ def custom_offload_with_hook(
142142    user_hook .attach ()
143143    return  user_hook 
144144
145- 
145+ # this is the class that user can customize to implement their own offload strategy 
146146class  AutoOffloadStrategy :
147147    """ 
148148    Offload strategy that should be used with `CustomOffloadHook` to automatically offload models to the CPU based on 
@@ -213,7 +213,101 @@ def search_best_candidate(module_sizes, min_memory_offload):
213213        return  hooks_to_offload 
214214
215215
216+ # utils for display component info in a readable format 
217+ # TODO: move to a different file 
218+ def  summarize_dict_by_value_and_parts (d : Dict [str , Any ]) ->  Dict [str , Any ]:
219+     """Summarizes a dictionary by finding common prefixes that share the same value. 
220+ 
221+     For a dictionary with dot-separated keys like: { 
222+         'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 
223+         'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 
224+         'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], 
225+     } 
226+ 
227+     Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { 
228+         'down_blocks': [0.6], 'up_blocks': [0.3] 
229+     } 
230+     """ 
231+     # First group by values - convert lists to tuples to make them hashable 
232+     value_to_keys  =  {}
233+     for  key , value  in  d .items ():
234+         value_tuple  =  tuple (value ) if  isinstance (value , list ) else  value 
235+         if  value_tuple  not  in   value_to_keys :
236+             value_to_keys [value_tuple ] =  []
237+         value_to_keys [value_tuple ].append (key )
238+ 
239+     def  find_common_prefix (keys : List [str ]) ->  str :
240+         """Find the shortest common prefix among a list of dot-separated keys.""" 
241+         if  not  keys :
242+             return  "" 
243+         if  len (keys ) ==  1 :
244+             return  keys [0 ]
245+ 
246+         # Split all keys into parts 
247+         key_parts  =  [k .split ("." ) for  k  in  keys ]
248+ 
249+         # Find how many initial parts are common 
250+         common_length  =  0 
251+         for  parts  in  zip (* key_parts ):
252+             if  len (set (parts )) ==  1 :  # All parts at this position are the same 
253+                 common_length  +=  1 
254+             else :
255+                 break 
256+ 
257+         if  common_length  ==  0 :
258+             return  "" 
259+ 
260+         # Return the common prefix 
261+         return  "." .join (key_parts [0 ][:common_length ])
262+ 
263+     # Create summary by finding common prefixes for each value group 
264+     summary  =  {}
265+     for  value_tuple , keys  in  value_to_keys .items ():
266+         prefix  =  find_common_prefix (keys )
267+         if  prefix :  # Only add if we found a common prefix 
268+             # Convert tuple back to list if it was originally a list 
269+             value  =  list (value_tuple ) if  isinstance (d [keys [0 ]], list ) else  value_tuple 
270+             summary [prefix ] =  value 
271+         else :
272+             summary ["" ] =  value   # Use empty string if no common prefix 
273+ 
274+     return  summary 
275+ 
276+ 
216277class  ComponentsManager :
278+     """ 
279+     A central registry and management system for model components across multiple pipelines. 
280+      
281+     [`ComponentsManager`] provides a unified way to register, track, and reuse model components 
282+     (like UNet, VAE, text encoders, etc.) across different modular pipelines. It includes 
283+     features for duplicate detection, memory management, and component organization. 
284+      
285+     <Tip warning={true}> 
286+ 
287+         This is an experimental feature and is likely to change in the future. 
288+ 
289+     </Tip> 
290+          
291+     Example: 
292+         ```python 
293+         from diffusers import ComponentsManager 
294+          
295+         # Create a components manager 
296+         cm = ComponentsManager() 
297+          
298+         # Add components 
299+         cm.add("unet", unet_model, collection="sdxl") 
300+         cm.add("vae", vae_model, collection="sdxl") 
301+          
302+         # Enable auto offloading 
303+         cm.enable_auto_cpu_offload(device="cuda") 
304+          
305+         # Retrieve components 
306+         unet = cm.get_one(name="unet", collection="sdxl") 
307+         ``` 
308+     """ 
309+ 
310+ 
217311    _available_info_fields  =  [
218312        "model_id" ,
219313        "added_time" ,
@@ -278,7 +372,19 @@ def _lookup_ids(
278372    def  _id_to_name (component_id : str ):
279373        return  "_" .join (component_id .split ("_" )[:- 1 ])
280374
281-     def  add (self , name , component , collection : Optional [str ] =  None ):
375+     def  add (self , name : str , component : Any , collection : Optional [str ] =  None ):
376+         """ 
377+         Add a component to the ComponentsManager. 
378+ 
379+         Args: 
380+             name (str): The name of the component 
381+             component (Any): The component to add 
382+             collection (Optional[str]): The collection to add the component to 
383+ 
384+         Returns: 
385+             str: The unique component ID, which is generated as "{name}_{id(component)}" where  
386+                  id(component) is Python's built-in unique identifier for the object 
387+         """ 
282388        component_id  =  f"{ name }  _{ id (component )}  " 
283389
284390        # check for duplicated components 
@@ -334,6 +440,12 @@ def add(self, name, component, collection: Optional[str] = None):
334440        return  component_id 
335441
336442    def  remove (self , component_id : str  =  None ):
443+         """ 
444+         Remove a component from the ComponentsManager. 
445+ 
446+         Args: 
447+             component_id (str): The ID of the component to remove 
448+         """ 
337449        if  component_id  not  in   self .components :
338450            logger .warning (f"Component '{ component_id }  ' not found in ComponentsManager" )
339451            return 
@@ -545,6 +657,22 @@ def matches_pattern(component_id, pattern, exact_match=False):
545657        return  get_return_dict (matches , return_dict_with_names )
546658
547659    def  enable_auto_cpu_offload (self , device : Union [str , int , torch .device ] =  "cuda" , memory_reserve_margin = "3GB" ):
660+         """ 
661+         Enable automatic CPU offloading for all components. 
662+ 
663+         The algorithm works as follows: 
664+         1. All models start on CPU by default 
665+         2. When a model's forward pass is called, it's moved to the execution device 
666+         3. If there's insufficient memory, other models on the device are moved back to CPU 
667+         4. The system tries to offload the smallest combination of models that frees enough memory 
668+         5. Models stay on the execution device until another model needs memory and forces them off 
669+ 
670+         Args: 
671+             device (Union[str, int, torch.device]): The execution device where models are moved for forward passes 
672+             memory_reserve_margin (str): The memory reserve margin to use, default is 3GB. This is the amount of  
673+                                         memory to keep free on the device to avoid running out of memory during  
674+                                         model execution (e.g., for intermediate activations, gradients, etc.) 
675+         """ 
548676        if  not  is_accelerate_available ():
549677            raise  ImportError ("Make sure to install accelerate to use auto_cpu_offload" )
550678
@@ -574,6 +702,9 @@ def enable_auto_cpu_offload(self, device: Union[str, int, torch.device] = "cuda"
574702        self ._auto_offload_device  =  device 
575703
576704    def  disable_auto_cpu_offload (self ):
705+         """ 
706+         Disable automatic CPU offloading for all components. 
707+         """ 
577708        if  self .model_hooks  is  None :
578709            self ._auto_offload_enabled  =  False 
579710            return 
@@ -595,13 +726,12 @@ def get_model_info(
595726        """Get comprehensive information about a component. 
596727
597728        Args: 
598-             component_id: Name of the component to get info for 
599-             fields:  Optional field (s) to return. Can be a string for single field or list of fields. 
729+             component_id (str) : Name of the component to get info for 
730+             fields ( Optional[Union[str, List[str]]]): Field (s) to return. Can be a string for single field or list of fields. 
600731                   If None, uses the available_info_fields setting. 
601732
602733        Returns: 
603-             Dictionary containing requested component metadata. If fields is specified, returns only those fields. If a 
604-             single field is requested as string, returns just that field's value. 
734+             Dictionary containing requested component metadata. If fields is specified, returns only those fields. Otherwise, returns all fields. 
605735        """ 
606736        if  component_id  not  in   self .components :
607737            raise  ValueError (f"Component '{ component_id }  ' not found in ComponentsManager" )
@@ -808,15 +938,16 @@ def get_one(
808938        load_id : Optional [str ] =  None ,
809939    ) ->  Any :
810940        """ 
811-         Get a single component by either: (1) searching name (pattern matching), collection, or load_id. (2) passing in 
812-         a component_id Raises an error if multiple components match or none are found. support pattern matching for 
813-         name 
941+         Get a single component by either:  
942+         - searching name (pattern matching), collection, or load_id.  
943+         - passing in a component_id 
944+         Raises an error if multiple components match or none are found.  
814945
815946        Args: 
816-             component_id: Optional component ID to get 
817-             name: Component name or pattern 
818-             collection: Optional collection to filter by 
819-             load_id: Optional load_id to filter by 
947+             component_id (Optional[str]) : Optional component ID to get 
948+             name (Optional[str]) : Component name or pattern 
949+             collection (Optional[str]) : Optional collection to filter by 
950+             load_id (Optional[str]) : Optional load_id to filter by 
820951
821952        Returns: 
822953            A single component 
@@ -847,6 +978,13 @@ def get_one(
847978    def  get_ids (self , names : Union [str , List [str ]] =  None , collection : Optional [str ] =  None ):
848979        """ 
849980        Get component IDs by a list of names, optionally filtered by collection. 
981+ 
982+         Args: 
983+             names (Union[str, List[str]]): List of component names 
984+             collection (Optional[str]): Optional collection to filter by 
985+ 
986+         Returns: 
987+             List[str]: List of component IDs 
850988        """ 
851989        ids  =  set ()
852990        if  not  isinstance (names , list ):
@@ -858,6 +996,20 @@ def get_ids(self, names: Union[str, List[str]] = None, collection: Optional[str]
858996    def  get_components_by_ids (self , ids : List [str ], return_dict_with_names : Optional [bool ] =  True ):
859997        """ 
860998        Get components by a list of IDs. 
999+ 
1000+         Args: 
1001+             ids (List[str]):  
1002+                 List of component IDs 
1003+             return_dict_with_names (Optional[bool]): 
1004+                 Whether to return a dictionary with component names as keys: 
1005+ 
1006+         Returns: 
1007+             Dict[str, Any]: Dictionary of components.  
1008+                 - If return_dict_with_names=True, keys are component names. 
1009+                 - If return_dict_with_names=False, keys are component IDs. 
1010+ 
1011+         Raises: 
1012+             ValueError: If duplicate component names are found in the search results when return_dict_with_names=True 
8611013        """ 
8621014        components  =  {id : self .components [id ] for  id  in  ids }
8631015
@@ -877,65 +1029,17 @@ def get_components_by_ids(self, ids: List[str], return_dict_with_names: Optional
8771029    def  get_components_by_names (self , names : List [str ], collection : Optional [str ] =  None ):
8781030        """ 
8791031        Get components by a list of names, optionally filtered by collection. 
880-         """ 
881-         ids  =  self .get_ids (names , collection )
882-         return  self .get_components_by_ids (ids )
883- 
884- 
885- def  summarize_dict_by_value_and_parts (d : Dict [str , Any ]) ->  Dict [str , Any ]:
886-     """Summarizes a dictionary by finding common prefixes that share the same value. 
887- 
888-     For a dictionary with dot-separated keys like: { 
889-         'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6], 
890-         'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6], 
891-         'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3], 
892-     } 
893- 
894-     Returns a dictionary where keys are the shortest common prefixes and values are their shared values: { 
895-         'down_blocks': [0.6], 'up_blocks': [0.3] 
896-     } 
897-     """ 
898-     # First group by values - convert lists to tuples to make them hashable 
899-     value_to_keys  =  {}
900-     for  key , value  in  d .items ():
901-         value_tuple  =  tuple (value ) if  isinstance (value , list ) else  value 
902-         if  value_tuple  not  in   value_to_keys :
903-             value_to_keys [value_tuple ] =  []
904-         value_to_keys [value_tuple ].append (key )
9051032
906-     def  find_common_prefix (keys : List [str ]) ->  str :
907-         """Find the shortest common prefix among a list of dot-separated keys.""" 
908-         if  not  keys :
909-             return  "" 
910-         if  len (keys ) ==  1 :
911-             return  keys [0 ]
912- 
913-         # Split all keys into parts 
914-         key_parts  =  [k .split ("." ) for  k  in  keys ]
915- 
916-         # Find how many initial parts are common 
917-         common_length  =  0 
918-         for  parts  in  zip (* key_parts ):
919-             if  len (set (parts )) ==  1 :  # All parts at this position are the same 
920-                 common_length  +=  1 
921-             else :
922-                 break 
923- 
924-         if  common_length  ==  0 :
925-             return  "" 
1033+         Args: 
1034+             names (List[str]): List of component names 
1035+             collection (Optional[str]): Optional collection to filter by 
9261036
927-         # Return the common prefix 
928-         return   "." . join ( key_parts [ 0 ][: common_length ]) 
1037+         Returns:  
1038+             Dict[str, Any]: Dictionary of components with component names as keys  
9291039
930-     # Create summary by finding common prefixes for each value group 
931-     summary  =  {}
932-     for  value_tuple , keys  in  value_to_keys .items ():
933-         prefix  =  find_common_prefix (keys )
934-         if  prefix :  # Only add if we found a common prefix 
935-             # Convert tuple back to list if it was originally a list 
936-             value  =  list (value_tuple ) if  isinstance (d [keys [0 ]], list ) else  value_tuple 
937-             summary [prefix ] =  value 
938-         else :
939-             summary ["" ] =  value   # Use empty string if no common prefix 
1040+         Raises: 
1041+             ValueError: If duplicate component names are found in the search results 
1042+         """ 
1043+         ids  =  self .get_ids (names , collection )
1044+         return  self .get_components_by_ids (ids )
9401045
941-     return  summary 
0 commit comments