@@ -358,6 +358,14 @@ def set_use_memory_efficient_attention_xformers(
             self.processor,
             (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
+
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -420,6 +428,8 @@ def set_use_memory_efficient_attention_xformers(
                     processor.to(
                         device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                     )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
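
With this dispatch in place, enabling xFormers on a pipeline whose attention blocks use JointAttnProcessor2_0 swaps in XFormersJointAttnProcessor instead of raising or falling back to the plain XFormersAttnProcessor. A minimal usage sketch, not part of this diff (the Stable Diffusion 3 checkpoint id is illustrative; assumes a CUDA-enabled xformers build is installed):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Walks every Attention module and calls set_use_memory_efficient_attention_xformers();
# joint attention processors are now routed to XFormersJointAttnProcessor.
pipe.enable_xformers_memory_efficient_attention()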
@@ -1685,6 +1695,91 @@ def __call__(
         return hidden_states, encoder_hidden_states


+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set this to `None` and allow xFormers to choose the
+            best operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class AllegroAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
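
The new processor can also be attached directly rather than through the pipeline-level toggle. A minimal sketch, reusing `pipe` from the snippet above and the model-level `set_attn_processor` helper (the specific flash-attention op in the comment is illustrative; leaving `attention_op=None` matches the docstring's recommendation to let xFormers pick the kernel):

from diffusers.models.attention_processor import XFormersJointAttnProcessor

# Assign the joint xFormers processor to every attention block in the SD3 transformer.
# A specific kernel could be pinned instead, e.g.
# attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp.
pipe.transformer.set_attn_processor(XFormersJointAttnProcessor(attention_op=None))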