@@ -358,6 +358,14 @@ def set_use_memory_efficient_attention_xformers(
             self.processor,
             (IPAdapterAttnProcessor, IPAdapterAttnProcessor2_0, IPAdapterXFormersAttnProcessor),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
+
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and is_custom_diffusion:
                 raise NotImplementedError(
@@ -420,6 +428,8 @@ def set_use_memory_efficient_attention_xformers(
                     processor.to(
                         device=self.processor.to_k_ip[0].weight.device, dtype=self.processor.to_k_ip[0].weight.dtype
                     )
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -1685,6 +1695,91 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = attn.head_to_batch_dim(encoder_hidden_states_query_proj).contiguous()
+            encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj).contiguous()
+            encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj).contiguous()
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+            key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+            value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        if encoder_hidden_states is not None:
+            # Split the attention outputs.
+            hidden_states, encoder_hidden_states = (
+                hidden_states[:, : residual.shape[1]],
+                hidden_states[:, residual.shape[1] :],
+            )
+            if not attn.context_pre_only:
+                encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if encoder_hidden_states is not None:
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class AllegroAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). This is
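For context, a minimal usage sketch of the new dispatch path, not part of the diff itself. It assumes a Stable Diffusion 3 style pipeline whose transformer attention uses JointAttnProcessor2_0; the checkpoint id and prompt are illustrative. With the change to set_use_memory_efficient_attention_xformers above, enabling xFormers routes joint-attention layers to XFormersJointAttnProcessor instead of falling through to the plain XFormersAttnProcessor:

import torch
from diffusers import StableDiffusion3Pipeline

# Illustrative checkpoint; any pipeline whose attention layers carry a joint
# attention processor would exercise the new is_joint_processor branch.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# Dispatches through set_use_memory_efficient_attention_xformers on every
# Attention module; joint-attention layers now get XFormersJointAttnProcessor.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a corgi wearing a space suit", num_inference_steps=28).images[0]

The processor can also be attached to a single attention module directly with attn.set_processor(XFormersJointAttnProcessor(attention_op=None)), which is what the dispatch code does after selecting a processor.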