Commit ea66503

Add support for XFormers in SD3
1 parent a899e42 commit ea66503
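
With this change, the standard xFormers switch on a diffusers pipeline also covers SD3's joint attention blocks (see the first hunk below). A minimal usage sketch; the model id, dtype, and prompt are illustrative assumptions, not part of this commit:

    import torch
    from diffusers import StableDiffusion3Pipeline

    # Load an SD3 checkpoint (assumed model id) in half precision and move it to GPU.
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        torch_dtype=torch.float16,
    ).to("cuda")

    # After this commit, the generic toggle routes SD3's JointAttnProcessor2_0
    # modules to the new XFormersJointAttnProcessor.
    pipe.enable_xformers_memory_efficient_attention()

    image = pipe("an astronaut riding a horse", num_inference_steps=28).images[0]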

1 file changed: +92 -0 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 92 additions & 0 deletions
@@ -277,6 +277,13 @@ def set_use_memory_efficient_attention_xformers(
                 LoRAAttnAddedKVProcessor,
             ),
         )
+        is_joint_processor = hasattr(self, "processor") and isinstance(
+            self.processor,
+            (
+                JointAttnProcessor2_0,
+                XFormersJointAttnProcessor,
+            ),
+        )
 
         if use_memory_efficient_attention_xformers:
             if is_added_kv_processor and (is_lora or is_custom_diffusion):
@@ -338,6 +345,8 @@ def set_use_memory_efficient_attention_xformers(
                     "Memory efficient attention with `xformers` might currently not work correctly if an attention mask is required for the attention operation."
                 )
                 processor = XFormersAttnAddedKVProcessor(attention_op=attention_op)
+            elif is_joint_processor:
+                processor = XFormersJointAttnProcessor(attention_op=attention_op)
             else:
                 processor = XFormersAttnProcessor(attention_op=attention_op)
         else:
@@ -1238,6 +1247,89 @@ def __call__(
         return hidden_states, encoder_hidden_states
 
 
+class XFormersJointAttnProcessor:
+    r"""
+    Processor for implementing memory efficient attention using xFormers.
+
+    Args:
+        attention_op (`Callable`, *optional*, defaults to `None`):
+            The base
+            [operator](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.AttentionOpBase) to
+            use as the attention operator. It is recommended to set to `None`, and allow xFormers to choose the best
+            operator.
+    """
+
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        query = attn.head_to_batch_dim(query).contiguous()
+        key = attn.head_to_batch_dim(key).contiguous()
+        value = attn.head_to_batch_dim(value).contiguous()
+
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op, scale=attn.scale
+        )
+        hidden_states = hidden_states.to(query.dtype)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
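
The new processor can also be attached directly, for example to pin a specific xFormers operator through `attention_op`. A hedged sketch, assuming the SD3 transformer exposes the usual `set_attn_processor` helper and that the checkpoint id below is available:

    import torch
    from diffusers import SD3Transformer2DModel
    from diffusers.models.attention_processor import XFormersJointAttnProcessor

    # Assumed checkpoint and subfolder; any SD3 transformer weights would do.
    transformer = SD3Transformer2DModel.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        subfolder="transformer",
        torch_dtype=torch.float16,
    )

    # Swap every attention processor for the xFormers joint variant.
    # attention_op=None lets xFormers pick the best operator, as the docstring above recommends.
    transformer.set_attn_processor(XFormersJointAttnProcessor(attention_op=None))

Assigning the processor this way bypasses the dispatch added to `set_use_memory_efficient_attention_xformers` in the first hunk, so it is mainly useful for experiments with specific operators.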
