@@ -277,6 +277,13 @@ def set_use_memory_efficient_attention_xformers(
                 LoRAAttnAddedKVProcessor,
             ),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
 
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and (is_lora or is_custom_diffusion):
@@ -338,6 +345,8 @@ def set_use_memory_efficient_attention_xformers(
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -1238,6 +1247,89 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
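
For context, here is a minimal usage sketch (not part of the diff) showing how the new dispatch is reached from user code. It assumes xformers is installed and uses `StableDiffusion3Pipeline` with an example checkpoint id as a stand-in for any model whose joint-attention blocks currently use `JointAttnProcessor2_0`; the explicit `set_attn_processor` route in the trailing comment assumes the transformer exposes that helper.

```python
# Minimal sketch, assuming xformers is installed and the checkpoint id below is only an example.
import torch

from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # example checkpoint id; swap in your own
    torch_dtype=torch.float16,
).to("cuda")

# Funnels into Attention.set_use_memory_efficient_attention_xformers (patched above), so
# joint-attention blocks receive XFormersJointAttnProcessor instead of JointAttnProcessor2_0.
pipe.enable_xformers_memory_efficient_attention()

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]

# Alternatively (assuming the transformer exposes set_attn_processor), the processor
# can be assigned explicitly:
# from diffusers.models.attention_processor import XFormersJointAttnProcessor
# pipe.transformer.set_attn_processor(XFormersJointAttnProcessor())
```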