Skip to content

Commit c7020df

Browse files
committed
add model_info
1 parent 4bed3e3 commit c7020df

File tree

2 files changed

+531
-50
lines changed

2 files changed

+531
-50
lines changed

src/diffusers/pipelines/components_manager.py

Lines changed: 116 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -304,59 +304,66 @@ def disable_auto_cpu_offload(self):
304304
self.model_hooks = None
305305
self._auto_offload_enabled = False
306306

307-
def get_model_info(self, name: str) -> Optional[Dict[str, Any]]:
308-
"""Get comprehensive information about a model component.
307+
def get_model_info(self, name: str, fields: Optional[Union[str, List[str]]] = None) -> Optional[Dict[str, Any]]:
308+
"""Get comprehensive information about a component.
309309
310310
Args:
311311
name: Name of the component to get info for
312-
312+
fields: Optional field(s) to return. Can be a string for a single field or a list of fields.
313+
If None, returns all fields.
314+
313315
Returns:
314-
Dictionary containing model metadata including:
315-
- model_id: Name of the model
316-
- class_name: Class name of the model
317-
- device: Device the model is on
318-
- dtype: Data type of the model
319-
- size_gb: Size of the model in GB
320-
- added_time: Timestamp when model was added
321-
- active_adapters: List of active adapters (if applicable)
322-
- attn_proc: List of attention processor types (if applicable)
323-
Returns None if component is not a torch.nn.Module
316+
Dictionary containing requested component metadata.
317+
If fields is specified, returns only those fields.
318+
If a single field is requested as a string, returns a dict containing only that field.
324319
"""
325320
if name not in self.components:
326321
raise ValueError(f"Component '{name}' not found in ComponentsManager")
327322

328323
component = self.components[name]
329324

330-
# Only process torch.nn.Module components
331-
if not isinstance(component, torch.nn.Module):
332-
return None
333-
325+
# Build complete info dict first
334326
info = {
335327
"model_id": name,
336-
"class_name": component.__class__.__name__,
337-
"device": str(getattr(component, "device", "N/A")),
338-
"dtype": str(component.dtype) if hasattr(component, "dtype") else None,
339328
"added_time": self.added_time[name],
340-
"size_gb": get_memory_footprint(component) / (1024**3),
341-
"active_adapters": None, # Default to None
342329
}
343-
344-
# Get active adapters if applicable
345-
if isinstance(component, ModelMixin):
346-
from peft.tuners.tuners_utils import BaseTunerLayer
347-
for module in component.modules():
348-
if isinstance(module, BaseTunerLayer):
349-
info["active_adapters"] = module.active_adapters
350-
break
351-
352-
# Get attention processors if applicable
353-
if hasattr(component, "attn_processors"):
354-
processors = component.attn_processors
355-
# Get unique processor types
356-
processor_types = list(set(str(v.__class__.__name__) for v in processors.values()))
357-
if processor_types:
358-
info["attn_proc"] = processor_types
359-
330+
331+
# Additional info for torch.nn.Module components
332+
if isinstance(component, torch.nn.Module):
333+
info.update({
334+
"class_name": component.__class__.__name__,
335+
"size_gb": get_memory_footprint(component) / (1024**3),
336+
"adapters": None, # Default to None
337+
})
338+
339+
# Get adapters if applicable
340+
if hasattr(component, "peft_config"):
341+
info["adapters"] = list(component.peft_config.keys())
342+
343+
# Check for IP-Adapter scales
344+
if hasattr(component, "_load_ip_adapter_weights") and hasattr(component, "attn_processors"):
345+
processors = component.attn_processors
346+
# First check if any processor is an IP-Adapter
347+
processor_types = [v.__class__.__name__ for v in processors.values()]
348+
if any("IPAdapter" in ptype for ptype in processor_types):
349+
# Then get scales only from IP-Adapter processors
350+
scales = {
351+
k: v.scale
352+
for k, v in processors.items()
353+
if hasattr(v, "scale") and "IPAdapter" in v.__class__.__name__
354+
}
355+
if scales:
356+
info["ip_adapter"] = summarize_dict_by_value_and_parts(scales)
357+
358+
# If fields specified, filter info
359+
if fields is not None:
360+
if isinstance(fields, str):
361+
# Single field requested as a string: return a one-entry dict for that field
362+
return {fields: info.get(fields)}
363+
else:
364+
# List of fields requested, return dict with just those fields
365+
return {k: v for k, v in info.items() if k in fields}
366+
360367
return info
361368

362369
def __repr__(self):
@@ -383,18 +390,16 @@ def __repr__(self):
383390
output += "Models:\n" + dash_line
384391
# Column headers
385392
output += f"{'Model ID':<{col_widths['id']}} | {'Class':<{col_widths['class']}} | "
386-
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB) \n"
393+
output += f"{'Device':<{col_widths['device']}} | {'Dtype':<{col_widths['dtype']}} | Size (GB)\n"
387394
output += dash_line
388395

389396
# Model entries
390397
for name, component in models.items():
391398
info = self.get_model_info(name)
399+
device = str(getattr(component, "device", "N/A"))
400+
dtype = str(component.dtype) if hasattr(component, "dtype") else "N/A"
392401
output += f"{name:<{col_widths['id']}} | {info['class_name']:<{col_widths['class']}} | "
393-
output += (
394-
f"{info['device']:<{col_widths['device']}} | "
395-
f"{info['dtype']:<{col_widths['dtype']}} | "
396-
f"{info['size_gb']:.2f}\n"
397-
)
402+
output += f"{device:<{col_widths['device']}} | {dtype:<{col_widths['dtype']}} | {info['size_gb']:.2f}\n"
398403
output += dash_line
399404

400405
# Other components section
@@ -415,12 +420,12 @@ def __repr__(self):
415420
output += "\nAdditional Component Info:\n" + "=" * 50 + "\n"
416421
for name in self.components:
417422
info = self.get_model_info(name)
418-
if info is not None and (info.get("active_adapters") is not None or info.get("attn_proc")):
423+
if info is not None and (info.get("adapters") is not None or info.get("ip_adapter")):
419424
output += f"\n{name}:\n"
420-
if info.get("active_adapters") is not None:
421-
output += f" Active Adapters: {info['active_adapters']}\n"
422-
if info.get("attn_proc"):
423-
output += f" Attention Processors: {info['attn_proc']}\n"
425+
if info.get("adapters") is not None:
426+
output += f" Adapters: {info['adapters']}\n"
427+
if info.get("ip_adapter"):
428+
output += f" IP-Adapter: Enabled\n"
424429
output += f" Added Time: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(info['added_time']))}\n"
425430

426431
return output
@@ -438,3 +443,64 @@ def add_from_pretrained(self, pretrained_model_name_or_path, **kwargs):
438443
f"1. remove the existing component with remove('{name}')\n"
439444
f"2. Use a different name: add('{name}_2', component)"
440445
)
446+
447+
def summarize_dict_by_value_and_parts(d: Dict[str, Any]) -> Dict[str, Any]:
448+
"""Summarizes a dictionary by finding common prefixes that share the same value.
449+
450+
For a dictionary with dot-separated keys like:
451+
{
452+
'down_blocks.1.attentions.1.transformer_blocks.0.attn2.processor': [0.6],
453+
'down_blocks.1.attentions.1.transformer_blocks.1.attn2.processor': [0.6],
454+
'up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor': [0.3],
455+
}
456+
457+
Returns a dictionary where keys are the shortest common prefixes and values are their shared values:
458+
{
459+
'down_blocks': [0.6],
460+
'up_blocks': [0.3]
461+
}
462+
"""
463+
# First group by values - convert lists to tuples to make them hashable
464+
value_to_keys = {}
465+
for key, value in d.items():
466+
value_tuple = tuple(value) if isinstance(value, list) else value
467+
if value_tuple not in value_to_keys:
468+
value_to_keys[value_tuple] = []
469+
value_to_keys[value_tuple].append(key)
470+
471+
def find_common_prefix(keys: List[str]) -> str:
472+
"""Find the shortest common prefix among a list of dot-separated keys."""
473+
if not keys:
474+
return ""
475+
if len(keys) == 1:
476+
return keys[0]
477+
478+
# Split all keys into parts
479+
key_parts = [k.split('.') for k in keys]
480+
481+
# Find how many initial parts are common
482+
common_length = 0
483+
for parts in zip(*key_parts):
484+
if len(set(parts)) == 1: # All parts at this position are the same
485+
common_length += 1
486+
else:
487+
break
488+
489+
if common_length == 0:
490+
return ""
491+
492+
# Return the common prefix
493+
return '.'.join(key_parts[0][:common_length])
494+
495+
# Create summary by finding common prefixes for each value group
496+
summary = {}
497+
for value_tuple, keys in value_to_keys.items():
498+
prefix = find_common_prefix(keys)
499+
if prefix: # Only add if we found a common prefix
500+
# Convert tuple back to list if it was originally a list
501+
value = list(value_tuple) if isinstance(d[keys[0]], list) else value_tuple
502+
summary[prefix] = value
503+
else:
504+
summary[""] = value # Use empty string if no common prefix — NOTE(review): `value` here is stale (last bound in the earlier grouping loop or a prior iteration), not this group's value_tuple; verify this is intended
505+
506+
return summary

0 commit comments

Comments
 (0)