Merged
6 changes: 4 additions & 2 deletions src/diffusers/loaders/ip_adapter.py
@@ -20,6 +20,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from safetensors import safe_open


from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_state_dict
from ..utils import (
USE_PEFT_BACKEND,
@@ -43,6 +44,7 @@
AttnProcessor2_0,
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor
)

logger = logging.get_logger(__name__)
@@ -284,7 +286,7 @@ def set_ip_adapter_scale(self, scale):
scale_configs = _maybe_expand_lora_scales(unet, scale, default_scale=0.0)

for attn_name, attn_processor in unet.attn_processors.items():
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
if isinstance(attn_processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor)):
if len(scale_configs) != len(attn_processor.scale):
raise ValueError(
f"Cannot assign {len(scale_configs)} scale_configs to "
@@ -342,7 +344,7 @@ def unload_ip_adapter(self):
)
attn_procs[name] = (
attn_processor_class
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0))
if isinstance(value, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor))
else value.__class__()
)
self.unet.set_attn_processor(attn_procs)
24 changes: 17 additions & 7 deletions src/diffusers/loaders/unet.py
@@ -23,6 +23,8 @@
from huggingface_hub.utils import validate_hf_hub_args
from torch import nn

from diffusers.models.attention_processor import AttnProcessor, AttnProcessor2_0

from ..models.embeddings import (
ImageProjection,
IPAdapterFaceIDImageProjection,
@@ -765,6 +767,7 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
from ..models.attention_processor import (
IPAdapterAttnProcessor,
IPAdapterAttnProcessor2_0,
IPAdapterXFormersAttnProcessor
)

if low_cpu_mem_usage:
@@ -802,13 +805,20 @@ def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=F
hidden_size = self.config.block_out_channels[block_id]

if cross_attention_dim is None or "motion_modules" in name:
attn_processor_class = self.attn_processors[name].__class__
attn_procs[name] = attn_processor_class()

else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
if ('XFormers' not in str(self.attn_processors[name].__class__)):
attn_processor_class = (
Collaborator

Can you explain this code change here?

  • Previously, if I understand the code correctly, we kept the original attention processor for motion modules (we did not change it to the IP adapter attention processor).
  • Now, we change to the default attention processor when it is not XFormers?

Contributor Author

Hi, how are you? Help me check whether my reasoning is correct. This condition is true when `cross_attention_dim` is None or when `motion_modules` is present in the name; I only considered the scenario where `cross_attention_dim` is None. Is there any attention class specific to models with motion modules other than `AttnProcessor` and `AttnProcessor2_0`? If there were, I would simply keep the `motion_modules` case in an `elif`. This part of the code comes from a first solution I implemented some time ago, before I had added the processor replacement to the `set_use_memory_efficient_attention_xformers` method of the `Attention` class. At the time, while testing several adapters and combined adapters, I was probably hitting a situation that forced me to add this xformers check here. However, now that you mention it, I commented this part out and ran some more tests, and it seems the modification is no longer necessary since `set_use_memory_efficient_attention_xformers` has been implemented. At least so far, I have not run into any errors when loading.

Contributor Author

If you agree, I will commit an update that removes this check and restores the original code.

Collaborator

For this, can you provide a code example that would fail without this change?

Contributor Author

> For this, can you provide a code example that would fail without this change?

Yes, see the code in my PR. It contains commented lines `#pipe.enable_xformers_memory_efficient_attention()`; you can remove the `#` to run it either before or after loading. I put the two lines, one before and one after loading the model.
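For context, here is a minimal sketch of the two call orders under discussion (the base model and IP-Adapter checkpoint names are illustrative examples, not taken from this PR):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Scenario 1: enable xformers BEFORE loading the IP adapter.
# All processors are already XFormersAttnProcessor, so the loader has to
# install IPAdapterXFormersAttnProcessor for the IP adapter layers.
pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")

# Scenario 2: enable xformers AFTER loading the IP adapter (run separately).
# The layers already hold IPAdapterAttnProcessor(2_0), so
# set_use_memory_efficient_attention_xformers must swap them to
# IPAdapterXFormersAttnProcessor instead of plain XFormersAttnProcessor.
# pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter_sd15.bin")
# pipe.enable_xformers_memory_efficient_attention()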

Contributor Author

I will commit my latest code with some fixes for the quality checks.

AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
)
attn_procs[name] = attn_processor_class()
else:
attn_procs[name] = self.attn_processors[name]
else:
Collaborator

Can you explain the change here?

Contributor Author

This `else` is a chicken-and-egg situation. When `pipe.enable_xformers_memory_efficient_attention()` is not called, the loader by default sets `IPAdapterAttnProcessor2_0` or `IPAdapterAttnProcessor` for the IP adapter. When you call `pipe.enable_xformers_memory_efficient_attention()` before loading the IP adapter, all attention processors are already `XFormersAttnProcessor`, so when the IP adapter modules are loaded after that call it is necessary to check whether the active mechanism is xformers and apply the new `IPAdapterXFormersAttnProcessor` class. However, when you call `pipe.enable_xformers_memory_efficient_attention()` after loading the IP adapter modules, those modules have already been set to `IPAdapterAttnProcessor2_0` or `IPAdapterAttnProcessor`, and the `set_use_memory_efficient_attention_xformers` method of the `Attention` class previously only knew how to set everything to `XFormersAttnProcessor`, which caused the error reported in the open issue. With the implementation I added to that class, the method now also knows how to detect `IPAdapterAttnProcessor2_0` or `IPAdapterAttnProcessor` in the modules and replace them correctly with the new class. But it can only do that because one of those processors was set when the module was loaded. So these checks are necessary on both sides, depending on whether `pipe.enable_xformers_memory_efficient_attention()` is called before or after loading the modules.

Collaborator

Can you confirm that you added this change in order to be able to handle this?

pipe.enable_xformers_memory_efficient_attention()
pipe.load_ip_adapter()

Contributor Author

Yes, this is to handle that order (and vice versa). The code I provided in this PR simulates both scenarios.

if ('XFormers' in str(self.attn_processors[name].__class__)):
attn_processor_class = (IPAdapterXFormersAttnProcessor)
else:
attn_processor_class = (
IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
)
num_image_text_embeds = []
for state_dict in state_dicts:
if "proj.weight" in state_dict["image_proj"]:
Expand Down
248 changes: 246 additions & 2 deletions src/diffusers/models/attention_processor.py
@@ -368,8 +368,21 @@ def set_use_memory_efficient_attention_xformers(
"Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
)
processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
else:
processor = XFormersAttnProcessor(attention_op=attention_op)
else:
processor = self.processor
if isinstance(self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0)):
Collaborator

Let's add an `is_ip_adapter` flag, similar to `is_custom_diffusion` etc.:

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor, (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0),
)

Contributor Author (@elismasilva, Nov 8, 2024)

Yes, it is perfectly possible, but it will have to be like the snippet below, so that modules that have already been switched to the xformers attention class are not replaced again with the plain `XFormersAttnProcessor` class in the final `else` during the method recursion.

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor,
    (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)

Collaborator

Yes, the code you show here is OK! We just want to keep a consistent style, that's all. :)
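Putting the flag together with the replacement logic in this diff, the branch might look roughly like the sketch below (a sketch only, assuming the surrounding set_use_memory_efficient_attention_xformers method already defines attention_op; this is not the final merged code):

is_ip_adapter = hasattr(self, "processor") and isinstance(
    self.processor,
    (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
)

if is_ip_adapter:
    processor = IPAdapterXFormersAttnProcessor(
        hidden_size=self.processor.hidden_size,
        cross_attention_dim=self.processor.cross_attention_dim,
        scale=self.processor.scale,
        attention_op=attention_op,
    )
    # carry over the already-loaded to_k_ip / to_v_ip weights
    processor.load_state_dict(self.processor.state_dict())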

processor = IPAdapterXFormersAttnProcessor(
    hidden_size=self.processor.hidden_size,
    cross_attention_dim=self.processor.cross_attention_dim,
    scale=self.processor.scale,
    attention_op=attention_op,
)
processor.load_state_dict(self.processor.state_dict())
Collaborator

Why do we need to call `load_state_dict` again here?

Contributor Author (@elismasilva, Nov 8, 2024)

Well, I could not pin down exactly why, but if I do not reload the state_dict here after assigning the new class, the IP adapter effect is not applied to the final image. I do not know whether it is because `pipe.enable_xformers_memory_efficient_attention()` was called after the IP adapter weights had already been loaded, so it is as if the adapter were not being used at all. I saw that you do some manipulations while loading the IP adapter weights, but I do not think it makes sense to replicate that logic here, and I am not sure that is the reason anyway. See one final image without the state dict and another with it below. I noticed that something similar is done for custom diffusion, so for practicality I did the same. If you have a better solution, I would like to try it.

Without load_state_dict:
[image: result_1_diff]

With load_state_dict:
[image: result_1_diff]

Collaborator

Oh yeah, it comes with weights:

self.to_k_ip = nn.ModuleList(

(I had forgotten about that, sorry! lol)
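The reason the state dict copy matters is that the IP adapter processors are nn.Modules that own the trained to_k_ip / to_v_ip projections, so a freshly constructed xformers processor starts with random weights until they are copied over. A standalone illustration of the idea (not the diffusers code path; the class names here are made up for the example):

import torch
import torch.nn as nn

class OldProcessor(nn.Module):
    # stands in for IPAdapterAttnProcessor2_0, which already holds trained weights
    def __init__(self, dim=8):
        super().__init__()
        self.to_k_ip = nn.ModuleList([nn.Linear(dim, dim, bias=False)])

class NewProcessor(nn.Module):
    # stands in for IPAdapterXFormersAttnProcessor, freshly constructed
    def __init__(self, dim=8):
        super().__init__()
        self.to_k_ip = nn.ModuleList([nn.Linear(dim, dim, bias=False)])

old = OldProcessor()
new = NewProcessor()

# Different random initializations: the "trained" weights are not in the new processor yet.
print(torch.allclose(old.to_k_ip[0].weight, new.to_k_ip[0].weight))  # False

# Copy the projections across, as the PR does with processor.load_state_dict(...)
new.load_state_dict(old.state_dict())
print(torch.allclose(old.to_k_ip[0].weight, new.to_k_ip[0].weight))  # True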

if len(self.processor._modules) > 0:
module_list = self.processor._modules[[m for m in self.processor._modules][0]]
if len(module_list) > 0:
processor.to(device=module_list[0].weight.device, dtype=module_list[0].weight.dtype)
elif isinstance(self.processor, (AttnProcessor, AttnProcessor2_0)):
processor = XFormersAttnProcessor(attention_op=attention_op)

else:
if is_custom_diffusion:
attn_processor_class = (
@@ -4541,7 +4554,238 @@ def __call__(

return hidden_states

class IPAdapterXFormersAttnProcessor(torch.nn.Module):
r"""
Attention processor for IP-Adapter using xFormers.

Args:
hidden_size (`int`):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
num_tokens (`int`, `Tuple[int]` or `List[int]`, defaults to `(4,)`):
The context length of the image features.
scale (`float` or `List[float]`, defaults to 1.0):
the weight scale of image prompt.
attention_op (`Callable`, *optional*, defaults to `None`):
The base
[operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to use
as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best operator.
"""

def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, attention_op: Optional[Callable] = None):
super().__init__()

if not hasattr(F, "scaled_dot_product_attention"):
Contributor

This check can be removed because this class uses `xformers.ops.memory_efficient_attention` instead of `torch.nn.functional.scaled_dot_product_attention`.

raise ImportError(
f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
)

self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.attention_op = attention_op

if not isinstance(num_tokens, (tuple, list)):
num_tokens = [num_tokens]
self.num_tokens = num_tokens

if not isinstance(scale, list):
scale = [scale] * len(num_tokens)
if len(scale) != len(num_tokens):
raise ValueError("`scale` should be a list of integers with the same length as `num_tokens`.")
self.scale = scale

self.to_k_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)
self.to_v_ip = nn.ModuleList(
[nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False) for _ in range(len(num_tokens))]
)

def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
# TODO attention_mask
query = query.contiguous()
key = key.contiguous()
value = value.contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask, op=self.attention_op)
return hidden_states

def __call__(
self,
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: Optional[torch.FloatTensor] = None,
attention_mask: Optional[torch.FloatTensor] = None,
temb: Optional[torch.FloatTensor] = None,
scale: float = 1.0,
ip_adapter_masks: Optional[torch.FloatTensor] = None,
):
residual = hidden_states

# separate ip_hidden_states from encoder_hidden_states
if encoder_hidden_states is not None:
if isinstance(encoder_hidden_states, tuple):
encoder_hidden_states, ip_hidden_states = encoder_hidden_states
else:
deprecation_message = (
"You have passed a tensor as `encoder_hidden_states`. This is deprecated and will be removed in a future release."
" Please make sure to update your script to pass `encoder_hidden_states` as a tuple to suppress this warning."
)
deprecate("encoder_hidden_states not a tuple", "1.0.0", deprecation_message, standard_warn=False)
end_pos = encoder_hidden_states.shape[1] - self.num_tokens[0]
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, :end_pos, :],
[encoder_hidden_states[:, end_pos:, :]],
)

if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)

input_ndim = hidden_states.ndim

if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# expand our mask's singleton query_tokens dimension:
# [batch*heads, 1, key_tokens] ->
# [batch*heads, query_tokens, key_tokens]
# so that it can be added as a bias onto the attention scores that xformers computes:
# [batch*heads, query_tokens, key_tokens]
# we do this explicitly because xformers doesn't broadcast the singleton dimension for us.
_, query_tokens, _ = hidden_states.shape
attention_mask = attention_mask.expand(-1, query_tokens, -1)

if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)

hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)

hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)

if ip_hidden_states:
if ip_adapter_masks is not None:
if not isinstance(ip_adapter_masks, List):
# for backward compatibility, we accept `ip_adapter_mask` as a tensor of shape [num_ip_adapter, 1, height, width]
ip_adapter_masks = list(ip_adapter_masks.unsqueeze(1))
if not (len(ip_adapter_masks) == len(self.scale) == len(ip_hidden_states)):
raise ValueError(
f"Length of ip_adapter_masks array ({len(ip_adapter_masks)}) must match "
f"length of self.scale array ({len(self.scale)}) and number of ip_hidden_states "
f"({len(ip_hidden_states)})"
)
else:
for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
if mask is None:
continue
if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
raise ValueError(
"Each element of the ip_adapter_masks array should be a tensor with shape "
"[1, num_images_for_ip_adapter, height, width]."
" Please use `IPAdapterMaskProcessor` to preprocess your mask"
)
if mask.shape[1] != ip_state.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of ip images ({ip_state.shape[1]}) at index {index}"
)
if isinstance(scale, list) and not len(scale) == mask.shape[1]:
raise ValueError(
f"Number of masks ({mask.shape[1]}) does not match "
f"number of scales ({len(scale)}) at index {index}"
)
else:
ip_adapter_masks = [None] * len(self.scale)

# for ip-adapter
ip_index = 0
for current_ip_hidden_states, scale, to_k_ip, to_v_ip, mask in zip(
ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip, ip_adapter_masks
):
skip = False
if isinstance(scale, list):
if all(s == 0 for s in scale):
skip = True
elif scale == 0:
skip = True
if not skip:
if mask is not None:
mask = mask.to(torch.float16)
if not isinstance(scale, list):
scale = [scale] * mask.shape[1]

current_num_images = mask.shape[1]
for i in range(current_num_images):
ip_key = to_k_ip(current_ip_hidden_states[:, i, :, :])
ip_value = to_v_ip(current_ip_hidden_states[:, i, :, :])

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

_current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)

_current_ip_hidden_states = _current_ip_hidden_states.to(query.dtype)
_current_ip_hidden_states = attn.batch_to_head_dim(_current_ip_hidden_states)

mask_downsample = IPAdapterMaskProcessor.downsample(
mask[:, i, :, :],
batch_size,
_current_ip_hidden_states.shape[1],
_current_ip_hidden_states.shape[2],
)

mask_downsample = mask_downsample.to(dtype=query.dtype, device=query.device)
hidden_states = hidden_states + scale[i] * (_current_ip_hidden_states * mask_downsample)
else:
ip_key = to_k_ip(current_ip_hidden_states)
ip_value = to_v_ip(current_ip_hidden_states)

ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)

current_ip_hidden_states = self._memory_efficient_attention_xformers(query, ip_key, ip_value, None)

current_ip_hidden_states = current_ip_hidden_states.to(query.dtype)
current_ip_hidden_states = attn.batch_to_head_dim(current_ip_hidden_states)

hidden_states = hidden_states + scale * current_ip_hidden_states
ip_index += 1

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

if attn.residual_connection:
hidden_states = hidden_states + residual

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
class PAGIdentitySelfAttnProcessor2_0:
r"""
Processor for implementing PAG using scaled dot-product attention (enabled by default if you're using PyTorch 2.0).