Skip to content

Commit 4bed3e3

Browse files
committed
up up
1 parent 00a3bc9 commit 4bed3e3

File tree

2 files changed

+93
-15
lines changed

2 files changed

+93
-15
lines changed

src/diffusers/loaders/ip_adapter.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -586,6 +586,9 @@ def unload_ip_adapter(self):
586586
"""
587587

588588
# remove hidden encoder
589+
if self.unet is None:
590+
return
591+
589592
self.unet.encoder_hid_proj = None
590593
self.unet.config.encoder_hid_dim_type = None
591594

src/diffusers/pipelines/components_manager.py

Lines changed: 90 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,17 @@
1414

1515
from collections import OrderedDict
1616
from itertools import combinations
17-
from typing import List, Optional, Union
17+
from typing import List, Optional, Union, Dict, Any
1818

1919
import torch
20+
import time
21+
from dataclasses import dataclass
2022

2123
from ..utils import (
2224
is_accelerate_available,
2325
logging,
2426
)
27+
from ..models.modeling_utils import ModelMixin
2528

2629

2730
if is_accelerate_available():
@@ -95,9 +98,6 @@ def pre_forward(self, module, *args, **kwargs):
9598
if self.other_hooks is not None:
9699
hooks_to_offload = [hook for hook in self.other_hooks if hook.model.device == self.execution_device]
97100
# offload all other hooks
98-
import time
99-
100-
# YiYi Notes: only logging time for now to monitor the overhead of offloading strategy (remove later)
101101
start_time = time.perf_counter()
102102
if self.offload_strategy is not None:
103103
hooks_to_offload = self.offload_strategy(
@@ -231,17 +231,27 @@ def search_best_candidate(module_sizes, min_memory_offload):
231231
class ComponentsManager:
232232
def __init__(self):
233233
self.components = OrderedDict()
234+
self.added_time = OrderedDict() # Store when components were added
234235
self.model_hooks = None
235236
self._auto_offload_enabled = False
236237

237238
def add(self, name, component):
238-
if name not in self.components:
239-
self.components[name] = component
240-
if self._auto_offload_enabled:
241-
self.enable_auto_cpu_offload(self._auto_offload_device)
239+
if name in self.components:
240+
logger.warning(f"Overriding existing component '{name}' in ComponentsManager")
241+
self.components[name] = component
242+
self.added_time[name] = time.time()
243+
244+
if self._auto_offload_enabled:
245+
self.enable_auto_cpu_offload(self._auto_offload_device)
242246

243247
def remove(self, name):
248+
if name not in self.components:
249+
logger.warning(f"Component '{name}' not found in ComponentsManager")
250+
return
251+
244252
self.components.pop(name)
253+
self.added_time.pop(name)
254+
245255
if self._auto_offload_enabled:
246256
self.enable_auto_cpu_offload(self._auto_offload_device)
247257

@@ -294,6 +304,61 @@ def disable_auto_cpu_offload(self):
294304
self.model_hooks = None
295305
self._auto_offload_enabled = False
296306

307+
def get_model_info(self, name: str) -> Optional[Dict[str, Any]]:
308+
"""Get comprehensive information about a model component.
309+
310+
Args:
311+
name: Name of the component to get info for
312+
313+
Returns:
314+
Dictionary containing model metadata including:
315+
- model_id: Name of the model
316+
- class_name: Class name of the model
317+
- device: Device the model is on
318+
- dtype: Data type of the model
319+
- size_gb: Size of the model in GB
320+
- added_time: Timestamp when model was added
321+
- active_adapters: List of active adapters (if applicable)
322+
- attn_proc: List of attention processor types (if applicable)
323+
Returns None if component is not a torch.nn.Module
324+
"""
325+
if name not in self.components:
326+
raise ValueError(f"Component '{name}' not found in ComponentsManager")
327+
328+
component = self.components[name]
329+
330+
# Only process torch.nn.Module components
331+
if not isinstance(component, torch.nn.Module):
332+
return None
333+
334+
info = {
335+
"model_id": name,
336+
"class_name": component.__class__.__name__,
337+
"device": str(getattr(component, "device", "N/A")),
338+
"dtype": str(component.dtype) if hasattr(component, "dtype") else None,
339+
"added_time": self.added_time[name],
340+
"size_gb": get_memory_footprint(component) / (1024**3),
341+
"active_adapters": None, # Default to None
342+
}
343+
344+
# Get active adapters if applicable
345+
if isinstance(component, ModelMixin):
346+
from peft.tuners.tuners_utils import BaseTunerLayer
347+
for module in component.modules():
348+
if isinstance(module, BaseTunerLayer):
349+
info["active_adapters"] = module.active_adapters
350+
break
351+
352+
# Get attention processors if applicable
353+
if hasattr(component, "attn_processors"):
354+
processors = component.attn_processors
355+
# Get unique processor types
356+
processor_types = list(set(str(v.__class__.__name__) for v in processors.values()))
357+
if processor_types:
358+
info["attn_proc"] = processor_types
359+
360+
return info
361+
297362
def __repr__(self):
298363
col_widths = {
299364
"id": max(15, max(len(id) for id in self.components.keys())),
@@ -323,14 +388,12 @@ def __repr__(self):
323388

324389
# Model entries
325390
for name, component in models.items():
326-
device = component.device
327-
dtype = component.dtype
328-
size_bytes = get_memory_footprint(component)
329-
size_gb = size_bytes / (1024**3)
330-
331-
output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}} | "
391+
info = self.get_model_info(name)
392+
output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
332393
output += (
333-
f"{str(device):<{col_widths['device']}} | {str(dtype):<{col_widths['dtype']}} | {size_gb:.2f}\n"
394+
f"{info['device']:<{col_widths['device']}} | "
395+
f"{info['dtype']:<{col_widths['dtype']}} | "
396+
f"{info['size_gb']:.2f}\n"
334397
)
335398
output += dash_line
336399

@@ -348,6 +411,18 @@ def __repr__(self):
348411
output += f"{name:<{col_widths['id']}} | {component.__class__.__name__:<{col_widths['class']}}\n"
349412
output += dash_line
350413

414+
# Add additional component info
415+
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
416+
for name in self.components:
417+
info = self.get_model_info(name)
418+
if info is not None and (info.get("active_adapters") is not None or info.get("attn_proc")):
419+
output += f"\n{name}:\n"
420+
if info.get("active_adapters") is not None:
421+
output += f" Active Adapters: {info['active_adapters']}\n"
422+
if info.get("attn_proc"):
423+
output += f" Attention Processors: {info['attn_proc']}\n"
424+
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
425+
351426
return output
352427

353428
def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs):

0 commit comments

Comments
 (0)