diff --git a/src/diffusers/modular_pipelines/components_manager.py b/src/diffusers/modular_pipelines/components_manager.py index 08e6d80fefd2..f48a227e2edb 100644 --- a/src/diffusers/modular_pipelines/components_manager.py +++ b/src/diffusers/modular_pipelines/components_manager.py @@ -386,6 +386,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): id(component) is Python's built-in unique identifier for the object """ component_id = f"{name}_{id(component)}" + is_new_component = True # check for duplicated components for comp_id, comp in self.components.items(): @@ -394,6 +395,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): if comp_name == name: logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'") component_id = comp_id + is_new_component = False break else: logger.warning( @@ -426,7 +428,9 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): logger.warning( f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}" ) - self.remove(comp_id) + # remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager) + self.remove_from_collection(comp_id, collection) + self.collections[collection].add(component_id) logger.info( f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}" @@ -434,11 +438,29 @@ def add(self, name: str, component: Any, collection: Optional[str] = None): else: logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'") - if self._auto_offload_enabled: + if self._auto_offload_enabled and is_new_component: self.enable_auto_cpu_offload(self._auto_offload_device) return component_id + def remove_from_collection(self, component_id: str, collection: str): + """ + Remove a component from a collection. + """ + if collection not in self.collections: + logger.warning(f"Collection '{collection}' not found in ComponentsManager") + return + if component_id not in self.collections[collection]: + logger.warning(f"Component '{component_id}' not found in collection '{collection}'") + return + # remove from the collection + self.collections[collection].remove(component_id) + # check if this component is in any other collection + comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps] + if not comp_colls: # only if no other collection contains this component, remove it + logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager") + self.remove(component_id) + def remove(self, component_id: str = None): """ Remove a component from the ComponentsManager. diff --git a/src/diffusers/modular_pipelines/modular_pipeline_utils.py b/src/diffusers/modular_pipelines/modular_pipeline_utils.py index b63925df26ff..f2fc015e948f 100644 --- a/src/diffusers/modular_pipelines/modular_pipeline_utils.py +++ b/src/diffusers/modular_pipelines/modular_pipeline_utils.py @@ -185,6 +185,8 @@ def load_id(self) -> str: Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty segments). """ + if self.default_creation_method == "from_config": + return "null" parts = [getattr(self, k) for k in self.loading_fields()] parts = ["null" if p is None else p for p in parts] return "|".join(p for p in parts if p)