3 changes: 2 additions & 1 deletion src/diffusers/hooks/faster_cache.py
@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..models.modeling_outputs import Transformer2DModelOutput
 from ..utils import logging
@@ -567,7 +568,7 @@ def high_frequency_weight_callback(module: torch.nn.Module) -> float:
     _apply_faster_cache_on_denoiser(module, config)
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
             continue
         if any(re.search(identifier, name) is not None for identifier in _TRANSFORMER_BLOCK_IDENTIFIERS):
             _apply_faster_cache_on_attention_class(name, submodule, config)
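With this change, any module that subclasses `AttentionModuleMixin` is matched by the scan above, not only the hard-coded `_ATTENTION_CLASSES`. A minimal sketch of what now qualifies (the class name below is illustrative and not part of this PR; the import path is the one added by the diff, and it is assumed the mixin needs no extra setup):

```python
import torch

from diffusers.models.attention import AttentionModuleMixin


class MyCustomAttention(torch.nn.Module, AttentionModuleMixin):
    """Toy attention block; subclassing AttentionModuleMixin is what the new isinstance check keys on."""

    def forward(self, hidden_states):
        return hidden_states


# Mirrors the updated check in apply_faster_cache: such submodules are no longer skipped.
print(issubclass(MyCustomAttention, AttentionModuleMixin))  # True
```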
3 changes: 2 additions & 1 deletion src/diffusers/hooks/pyramid_attention_broadcast.py
@@ -18,6 +18,7 @@
 
 import torch
 
+from ..models.attention import AttentionModuleMixin
 from ..models.attention_processor import Attention, MochiAttention
 from ..utils import logging
 from .hooks import HookRegistry, ModelHook
@@ -227,7 +228,7 @@ def apply_pyramid_attention_broadcast(module: torch.nn.Module, config: PyramidAt
         config.spatial_attention_block_skip_range = 2
 
     for name, submodule in module.named_modules():
-        if not isinstance(submodule, _ATTENTION_CLASSES):
+        if not isinstance(submodule, (*_ATTENTION_CLASSES, AttentionModuleMixin)):
Review comment from the PR author: Marking as a TODO for myself. We will no longer need _ATTENTION_CLASSES once everything is AttentionModuleMixin.

             # PAB has been implemented specific to Diffusers' Attention classes. However, this does not mean that PAB
             # cannot be applied to this layer. For custom layers, users can extend this functionality and implement
             # their own PAB logic similar to `_apply_pyramid_attention_broadcast_on_attention_class`.
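For reference, a hedged usage sketch of `apply_pyramid_attention_broadcast`, the function this hunk modifies. The model repo id and skip values are illustrative, and the imports assume the public `diffusers.hooks` API; none of this is part of the PR:

```python
import torch

from diffusers import CogVideoXPipeline
from diffusers.hooks import PyramidAttentionBroadcastConfig, apply_pyramid_attention_broadcast

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Recompute spatial attention only every 2nd step inside the given timestep window;
# cached attention outputs are broadcast (reused) in between.
config = PyramidAttentionBroadcastConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(100, 800),
    current_timestep_callback=lambda: pipe.current_timestep,
)
apply_pyramid_attention_broadcast(pipe.transformer, config)
```

After this change, a custom attention layer that subclasses AttentionModuleMixin would be picked up by the same call.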
9 changes: 5 additions & 4 deletions src/diffusers/loaders/ip_adapter.py
@@ -40,8 +40,6 @@
 from ..models.attention_processor import (
     AttnProcessor,
     AttnProcessor2_0,
-    FluxAttnProcessor2_0,
-    FluxIPAdapterJointAttnProcessor2_0,
     IPAdapterAttnProcessor,
     IPAdapterAttnProcessor2_0,
     IPAdapterXFormersAttnProcessor,
@@ -867,6 +865,9 @@ def unload_ip_adapter(self):
         >>> ...
         ```
         """
+        # TODO: once the 1.0.0 deprecations are in, we can move the imports to top-level
+        from ..models.transformers.transformer_flux import FluxAttnProcessor, FluxIPAdapterAttnProcessor
+
         # remove CLIP image encoder
         if hasattr(self, "image_encoder") and getattr(self, "image_encoder", None) is not None:
             self.image_encoder = None
@@ -886,9 +887,9 @@ def unload_ip_adapter(self):
         # restore original Transformer attention processors layers
         attn_procs = {}
         for name, value in self.transformer.attn_processors.items():
-            attn_processor_class = FluxAttnProcessor2_0()
+            attn_processor_class = FluxAttnProcessor()
             attn_procs[name] = (
-                attn_processor_class if isinstance(value, (FluxIPAdapterJointAttnProcessor2_0)) else value.__class__()
+                attn_processor_class if isinstance(value, FluxIPAdapterAttnProcessor) else value.__class__()
             )
         self.transformer.set_attn_processor(attn_procs)
 
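The renamed processors above are what unload_ip_adapter swaps back in. A hedged end-to-end sketch (the repo ids are illustrative, and the load_ip_adapter arguments assume the usual Flux IP-Adapter loading API rather than anything shown in this diff):

```python
import torch

from diffusers import FluxPipeline
from diffusers.models.transformers.transformer_flux import FluxIPAdapterAttnProcessor

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_ip_adapter(
    "XLabs-AI/flux-ip-adapter",
    weight_name="ip_adapter.safetensors",
    image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
)

# While loaded, IP-Adapter processors are installed on the transformer.
assert any(isinstance(p, FluxIPAdapterAttnProcessor) for p in pipe.transformer.attn_processors.values())

pipe.unload_ip_adapter()

# After unloading, every IP-Adapter processor has been replaced by a plain FluxAttnProcessor
# (or the layer's original processor class), per the loop above.
assert not any(isinstance(p, FluxIPAdapterAttnProcessor) for p in pipe.transformer.attn_processors.values())
```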
6 changes: 2 additions & 4 deletions src/diffusers/loaders/transformer_flux.py
@@ -87,9 +87,7 @@ def _convert_ip_adapter_image_proj_to_diffusers(self, state_dict, low_cpu_mem_us
         return image_projection
 
     def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
-        from ..models.attention_processor import (
-            FluxIPAdapterJointAttnProcessor2_0,
-        )
+        from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
 
         if low_cpu_mem_usage:
             if is_accelerate_available():
@@ -121,7 +119,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_
             else:
                 cross_attention_dim = self.config.joint_attention_dim
                 hidden_size = self.inner_dim
-                attn_processor_class = FluxIPAdapterJointAttnProcessor2_0
+                attn_processor_class = FluxIPAdapterAttnProcessor
                 num_image_text_embeds = []
                 for state_dict in state_dicts:
                     if "proj.weight" in state_dict["image_proj"]: