Commit 8b3cbb1

split out xla flash attention from base AttnProcessor
1 parent 10b6ba1 commit 8b3cbb1

File tree

1 file changed

+105 -3 lines changed

src/diffusers/models/attention_processor.py

Lines changed: 105 additions & 3 deletions
@@ -276,10 +276,16 @@ def __init__(
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
         # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
+        # If torch_xla is available, we use the pallas flash attention kernel to improve performance.
         if processor is None:
-            processor = (
-                AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
-            )
+            if hasattr(F, "scaled_dot_product_attention") and self.scale_qk:
+                if is_torch_xla_available():
+                    processor = XLAFlashAttnProcessor2_0()
+                else:
+                    processor = AttnProcessor2_0()
+            else:
+                processor = AttnProcessor()
+
         self.set_processor(processor)

     def set_use_npu_flash_attention(self, use_npu_flash_attention: bool) -> None:
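
For context (not part of this diff): with the XLA path split into its own processor, it can also be selected explicitly instead of relying on the default selection above. A minimal sketch, assuming the usual diffusers set_attn_processor override hook and an illustrative checkpoint name:

from diffusers import DiffusionPipeline
from diffusers.models.attention_processor import AttnProcessor2_0, XLAFlashAttnProcessor2_0
from diffusers.utils import is_torch_xla_available

# Illustrative checkpoint; any UNet-based pipeline exposes set_attn_processor on its UNet.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-base")

# Mirror the default selection from the hunk above, or force a specific processor per device.
processor = XLAFlashAttnProcessor2_0() if is_torch_xla_available() else AttnProcessor2_0()
pipe.unet.set_attn_processor(processor)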
@@ -2644,6 +2650,102 @@ def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
             raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")

+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        temb: Optional[torch.Tensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.Tensor:
+        if len(args) > 0 or kwargs.get("scale", None) is not None:
+            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+            deprecate("scale", "1.0.0", deprecation_message)
+
+        residual = hidden_states
+        if attn.spatial_norm is not None:
+            hidden_states = attn.spatial_norm(hidden_states, temb)
+
+        input_ndim = hidden_states.ndim
+
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        if attn.group_norm is not None:
+            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+        elif attn.norm_cross:
+            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
+class XLAFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using torch_xla).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+        if not is_torch_xla_available():
+            raise ImportError("XLAFlashAttnProcessor2_0 requires the torch_xla package.")
+
     def __call__(
         self,
         attn: Attention,
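
The body of XLAFlashAttnProcessor2_0.__call__ is truncated by the diff context above. As a rough sketch of the technique the class wraps, not the commit's exact code: the XLA path swaps the F.scaled_dot_product_attention call for torch_xla's pallas flash attention kernel and falls back to SDPA when the kernel cannot be used. The import path, the fallback condition, and the explicit query scaling below are assumptions.

import torch.nn.functional as F

# Assumed location of the pallas flash attention kernel in torch_xla's experimental API.
from torch_xla.experimental.custom_kernel import flash_attention


def xla_flash_attention_sketch(query, key, value, attention_mask=None):
    # query/key/value are shaped (batch, heads, seq_len, head_dim), as prepared in __call__ above.
    if attention_mask is None:
        # Assume the kernel applies no softmax scaling itself, so fold the usual
        # 1/sqrt(head_dim) factor into the query to match SDPA's default behavior.
        query = query * (query.shape[-1] ** -0.5)
        return flash_attention(query, key, value, causal=False)
    # With a mask (or off-XLA), fall back to native scaled dot-product attention.
    return F.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )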
