diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py
index faacc431c386..945ceb57e769 100644
--- a/src/diffusers/models/attention_processor.py
+++ b/src/diffusers/models/attention_processor.py
@@ -358,6 +358,14 @@ def set_use_memory_efficient_attention_xformers(
             self.processor,
             (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
+
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -420,6 +428,8 @@ def set_use_memory_efficient_attention_xformers(
                     processor.to(
                         device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                     )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -1685,6 +1695,91 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class AllegroAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py
index b9e12a11fafa..2531381dc7c8 100644
--- a/tests/models/transformers/test_models_transformer_sd3.py
+++ b/tests/models/transformers/test_models_transformer_sd3.py
@@ -18,6 +18,7 @@
 import torch
 
 from diffusers import SD3Transformer2DModel
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     torch_device,
@@ -80,6 +81,20 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+        ), "xformers is not enabled"
+
     @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
     def test_set_attn_processor_for_determinism(self):
         pass
@@ -140,6 +155,20 @@ def prepare_init_args_and_inputs_for_common(self):
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
 
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"
+        ), "xformers is not enabled"
+
     @unittest.skip("SD3Transformer2DModel uses a dedicated attention processor. This test doesn't apply")
     def test_set_attn_processor_for_determinism(self):
         pass
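
Usage note: `XFormersJointAttnProcessor` is not meant to be set by hand; it is swapped in by the existing `enable_xformers_memory_efficient_attention()` entry point, which is what the added tests assert. A minimal sketch of that flow, assuming a CUDA device with `xformers` installed; the checkpoint id and subfolder are illustrative only:

import torch

from diffusers import SD3Transformer2DModel

# Load only the SD3 transformer (checkpoint id and subfolder are assumptions for this sketch).
model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    torch_dtype=torch.float16,
).to("cuda")

# Replaces the default JointAttnProcessor2_0 with XFormersJointAttnProcessor on every
# attention module, mirroring the assertion in the new tests.
model.enable_xformers_memory_efficient_attention()

assert model.transformer_blocks[0].attn.processor.__class__.__name__ == "XFormersJointAttnProcessor"

Because `set_use_memory_efficient_attention_xformers` now recognizes `JointAttnProcessor2_0`, pipelines that wrap this transformer pick up the new processor through the same call with no further changes.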