Commit 378f7ed

add xla flux attention class
1 parent 83b55ba commit 378f7ed

3 files changed: +113 −17 lines


examples/research_projects/pytorch_xla/inference/flux/flux_inference.py

Lines changed: 1 addition & 0 deletions
@@ -33,6 +33,7 @@ def _main(index, args, text_pipe, ckpt_id):
     logger.info(f'loading flux from {ckpt_id}')
     flux_pipe = FluxPipeline.from_pretrained(ckpt_id, text_encoder=None, tokenizer=None,
                                              text_encoder_2=None, tokenizer_2=None, torch_dtype=torch.bfloat16).to(device0)
+    flux_pipe.transformer.enable_xla_flash_attention(partition_spec=("data", None, None, None), is_flux=True)
 
     prompt = 'photograph of an electronics chip in the shape of a race car with trillium written on its side'
     width = args.width
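
For readers who want to reproduce the new call outside the example script, a minimal usage sketch follows. It assumes a TPU runtime with torch_xla installed; the checkpoint id is a placeholder, and the text encoders are dropped to mirror the example script, which runs them in a separate pipeline.

# Minimal sketch, not part of the commit; assumes a TPU runtime with torch_xla.
import torch
import torch_xla.core.xla_model as xm

from diffusers import FluxPipeline

device = xm.xla_device()

# Placeholder checkpoint id; the example script takes ckpt_id as an argument.
flux_pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder=None,
    tokenizer=None,
    text_encoder_2=None,
    tokenizer_2=None,
    torch_dtype=torch.bfloat16,
).to(device)

# is_flux=True routes the transformer's attention modules to the new
# XLAFluxFlashAttnProcessor2_0; partition_spec is only consulted under SPMD.
flux_pipe.transformer.enable_xla_flash_attention(
    partition_spec=("data", None, None, None), is_flux=True
)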

src/diffusers/models/attention_processor.py

Lines changed: 108 additions & 13 deletions
@@ -297,7 +297,7 @@ def __init__(
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None, **kwargs
     ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +316,10 @@ def set_use_xla_flash_attention(
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
             else:
-                processor = XLAFlashAttnProcessor2_0(partition_spec)
+                if len(kwargs) > 0 and kwargs.get("is_flux", None):
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
@@ -2318,11 +2321,7 @@ def __call__(
             query = apply_rotary_emb(query, image_rotary_emb)
             key = apply_rotary_emb(key, image_rotary_emb)
 
-        if XLA_AVAILABLE:
-            query /= math.sqrt(head_dim)
-            hidden_states = flash_attention(query, key, value, causal=False)
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -2523,12 +2522,8 @@ def __call__(
 
         query = apply_rotary_emb(query, image_rotary_emb)
         key = apply_rotary_emb(key, image_rotary_emb)
-
-        if XLA_AVAILABLE:
-            query /= math.sqrt(head_dim)
-            hidden_states = flash_attention(query, key, value)
-        else:
-            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
 
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
@@ -3430,6 +3425,106 @@ def __call__(
         return hidden_states
 
 
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query /= math.sqrt(head_dim)
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
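
A note on the `query /= math.sqrt(head_dim)` line in the new processor: the pallas `flash_attention` call is made without an explicit softmax scale, so the 1/sqrt(head_dim) factor that `F.scaled_dot_product_attention` applies internally is folded into the query beforehand. A small CPU-only sketch of that equivalence (names and shapes are illustrative, not from the commit):

import math
import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 1, 2, 8, 16
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)

# Reference: PyTorch applies the 1/sqrt(head_dim) scale inside the call.
ref = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False)

# XLA path: fold the scale into the query, then run unscaled attention.
# flash_attention(q, k, v, causal=False) is emulated with an explicit softmax
# so this sketch runs without a TPU.
q_scaled = q / math.sqrt(head_dim)
out = torch.softmax(q_scaled @ k.transpose(-2, -1), dim=-1) @ v

print(torch.allclose(ref, out, atol=1e-5))  # expected: True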

src/diffusers/models/modeling_utils.py

Lines changed: 4 additions & 4 deletions
@@ -227,14 +227,14 @@ def disable_npu_flash_attention(self) -> None:
         self.set_use_npu_flash_attention(False)
 
     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None, **kwargs
     ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_xla_flash_attention method
         # gets the message
         def fn_recursive_set_flash_attention(module: torch.nn.Module):
             if hasattr(module, "set_use_xla_flash_attention"):
-                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec, **kwargs)
 
             for child in module.children():
                 fn_recursive_set_flash_attention(child)
@@ -243,11 +243,11 @@ def fn_recursive_set_flash_attention(module: torch.nn.Module):
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_flash_attention(module)
 
-    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None, **kwargs):
         r"""
         Enable the flash attention pallals kernel for torch_xla.
         """
-        self.set_use_xla_flash_attention(True, partition_spec)
+        self.set_use_xla_flash_attention(True, partition_spec, **kwargs)
 
     def disable_xla_flash_attention(self):
         r"""
