Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 24 additions & 2 deletions src/diffusers/modular_pipelines/components_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
id(component) is Python's built-in unique identifier for the object
"""
component_id = f"{name}_{id(component)}"
is_new_component = True

# check for duplicated components
for comp_id, comp in self.components.items():
Expand All @@ -394,6 +395,7 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
if comp_name == name:
logger.warning(f"ComponentsManager: component '{name}' already exists as '{comp_id}'")
component_id = comp_id
is_new_component = False
break
else:
logger.warning(
Expand Down Expand Up @@ -426,19 +428,39 @@ def add(self, name: str, component: Any, collection: Optional[str] = None):
logger.warning(
f"ComponentsManager: removing existing {name} from collection '{collection}': {comp_id}"
)
self.remove(comp_id)
# remove existing component from this collection (if it is not in any other collection, will be removed from ComponentsManager)
self.remove_from_collection(comp_id, collection)

self.collections[collection].add(component_id)
logger.info(
f"ComponentsManager: added component '{name}' in collection '{collection}': {component_id}"
)
else:
logger.info(f"ComponentsManager: added component '{name}' as '{component_id}'")

if self._auto_offload_enabled:
if self._auto_offload_enabled and is_new_component:
self.enable_auto_cpu_offload(self._auto_offload_device)

return component_id

def remove_from_collection(self, component_id: str, collection: str):
"""
Remove a component from a collection.
"""
if collection not in self.collections:
logger.warning(f"Collection '{collection}' not found in ComponentsManager")
return
if component_id not in self.collections[collection]:
logger.warning(f"Component '{component_id}' not found in collection '{collection}'")
return
# remove from the collection
self.collections[collection].remove(component_id)
# check if this component is in any other collection
comp_colls = [coll for coll, comps in self.collections.items() if component_id in comps]
if not comp_colls: # only if no other collection contains this component, remove it
logger.warning(f"ComponentsManager: removing component '{component_id}' from ComponentsManager")
self.remove(component_id)

def remove(self, component_id: str = None):
"""
Remove a component from the ComponentsManager.
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/modular_pipelines/modular_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,8 @@ def load_id(self) -> str:
Unique identifier for this spec's pretrained load, composed of repo|subfolder|variant|revision (no empty
segments).
"""
if self.default_creation_method == "from_config":
return "null"
parts = [getattr(self, k) for k in self.loading_fields()]
parts = ["null" if p is None else p for p in parts]
return "|".join(p for p in parts if p)
Expand Down
Loading