97 changes: 96 additions & 1 deletion src/diffusers/models/transformers/sana_transformer.py
@@ -14,6 +14,7 @@

from typing import Any, Dict, Optional, Tuple, Union

import math
import torch
import torch.nn.functional as F
from torch import nn
@@ -184,6 +185,91 @@ def __call__(

return hidden_states


class SanaAttnProcessor3_0:
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""

    def __init__(self):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError("SanaAttnProcessor3_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")

    @staticmethod
    def scaled_dot_product_attention(
        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
    ) -> torch.Tensor:
        # Non-fused reference implementation of F.scaled_dot_product_attention; keeping the computation in
        # plain matmul/softmax ops preserves compatibility with forward-mode autodiff (torch.func.jvp).
        B, H, L, S = *query.size()[:-1], key.size(-2)
        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)

        if is_causal:
            assert attn_mask is None
            causal_mask = torch.ones(L, S, dtype=torch.bool, device=query.device).tril(diagonal=0)
            attn_bias.masked_fill_(causal_mask.logical_not(), float("-inf"))

        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias += attn_mask

        attn_weight = query @ key.transpose(-2, -1) * scale_factor
        attn_weight += attn_bias
        attn_weight = torch.softmax(attn_weight, dim=-1)
        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
        return attn_weight @ value

def __call__(
self,
attn: Attention,
hidden_states: torch.Tensor,
encoder_hidden_states: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)

if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])

query = attn.to_q(hidden_states)

if encoder_hidden_states is None:
encoder_hidden_states = hidden_states

key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)

if attn.norm_q is not None:
query = attn.norm_q(query)
if attn.norm_k is not None:
key = attn.norm_k(key)

inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        # the output of the attention op has shape (batch, num_heads, seq_len, head_dim)
        # attn.scale could be passed through the `scale` argument above if a non-default scale is ever needed
hidden_states = self.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)

hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)

# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)

hidden_states = hidden_states / attn.rescale_output_factor

return hidden_states
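
As a quick sanity check (illustrative only, not part of the diff), the vanilla static method above should match the fused kernel numerically; this assumes `SanaAttnProcessor3_0` is importable from this module and uses arbitrary toy shapes:

    import torch
    import torch.nn.functional as F

    q, k, v = (torch.randn(2, 8, 16, 32) for _ in range(3))  # (batch, heads, seq_len, head_dim)

    ref = F.scaled_dot_product_attention(q, k, v)
    out = SanaAttnProcessor3_0.scaled_dot_product_attention(q, k, v)

    # The two paths should agree up to floating-point tolerance.
    assert torch.allclose(ref, out, atol=1e-4)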


class SanaTransformerBlock(nn.Module):
r"""
@@ -205,6 +291,7 @@ def __init__(
attention_out_bias: bool = True,
mlp_ratio: float = 2.5,
qk_norm: Optional[str] = None,
cross_attention_type: str = "flash",
) -> None:
super().__init__()

@@ -223,6 +310,12 @@
)

# 2. Cross Attention
if cross_attention_type == "flash":
cross_attention_processor = SanaAttnProcessor2_0()
elif cross_attention_type == "vanilla":
cross_attention_processor = SanaAttnProcessor3_0()
Member: Can't we modify the SanaAttnProcessor2_0() class to handle the changes of SanaAttnProcessor3_0?

Contributor: If we merge 2_0 and 3_0, we then need a flag to check when to use the function here:

    hidden_states = self.scaled_dot_product_attention(

which would be similar to cross_attention_type: str = "flash". @sayakpaul
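
For illustration only (not part of this PR, and the class name is hypothetical), a minimal sketch of what such a merged processor gated by a single flag could look like, assuming the module-level imports and the SanaAttnProcessor3_0 class from this file:

    class SanaAttnProcessorMerged:
        """Hypothetical merged processor that picks the fused or vanilla attention path via a flag."""

        def __init__(self, use_flash: bool = True):
            self.use_flash = use_flash

        def _attention(self, query, key, value, attn_mask=None):
            if self.use_flash:
                # Fused kernel path, as in SanaAttnProcessor2_0.
                return F.scaled_dot_product_attention(
                    query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
                )
            # Vanilla matmul/softmax path, as in SanaAttnProcessor3_0; compatible with torch.func.jvp.
            return SanaAttnProcessor3_0.scaled_dot_product_attention(
                query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False
            )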

else:
raise ValueError(f"Cross attention type {cross_attention_type} is not defined.")
if cross_attention_dim is not None:
self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
self.attn2 = Attention(
Expand All @@ -235,7 +328,7 @@ def __init__(
dropout=dropout,
bias=True,
out_bias=attention_out_bias,
processor=SanaAttnProcessor2_0(),
processor=cross_attention_processor,
)

# 3. Feed-forward
@@ -360,6 +453,7 @@ def __init__(
guidance_embeds_scale: float = 0.1,
qk_norm: Optional[str] = None,
timestep_scale: float = 1.0,
cross_attention_type: str = "flash",
Member: This goes a bit against our design.

lawrence-cj (Contributor), May 7, 2025: Then can we just separate it into two classes and let you help with a better implementation?

lawrence-cj (Contributor), May 7, 2025: Actually, the only difference is that F.scaled_dot_product_attention is not supported by torch.JVP. Therefore, during training we need to replace it with the vanilla attention implementation. Any good idea how to merge these two? @sayakpaul
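
To make that constraint concrete, here is a small self-contained sketch (not from the PR) contrasting forward-mode autodiff through a plain matmul/softmax attention with the fused F.scaled_dot_product_attention; whether the fused call raises depends on the PyTorch version, and the shapes are arbitrary toy values:

    import math

    import torch
    import torch.nn.functional as F

    def vanilla_sdpa(query, key, value):
        # Plain attention: every op here has a forward-mode AD rule.
        scale = 1 / math.sqrt(query.size(-1))
        attn = torch.softmax(query @ key.transpose(-2, -1) * scale, dim=-1)
        return attn @ value

    q, k, v = (torch.randn(1, 2, 4, 8) for _ in range(3))
    tangent = torch.ones_like(q)

    # Forward-mode AD through the vanilla implementation works.
    out, jvp_out = torch.func.jvp(lambda x: vanilla_sdpa(x, k, v), (q,), (tangent,))

    # The fused kernel may lack a forward-mode rule, which is the motivation for the vanilla processor.
    try:
        torch.func.jvp(lambda x: F.scaled_dot_product_attention(x, k, v), (q,), (tangent,))
    except Exception as err:
        print("jvp through F.scaled_dot_product_attention failed:", err)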

Member: Ah, I see. If that is the case, I think we should go through the attention processor mechanism, wherein we use something like set_attn_processor and use the vanilla attention processor class.

If this is only needed for training, I think we should have the following methods added to the model class:

We can then just include the vanilla attention processor implementation in the training utility and do something like

    model = SanaTransformer2DModel(...)
    model.set_attn_processor(SanaVanillaAttnProcessor())

WDYT? @DN6 any suggestions?
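
As a rough usage sketch of that suggestion: the processor name, its module, and the availability of set_attn_processor / set_default_attn_processor on the model are assumptions following the comment above, not confirmed API:

    from diffusers import SanaTransformer2DModel

    # Hypothetical: the vanilla processor would live in a training utility, per the suggestion above.
    from training_utils import SanaVanillaAttnProcessor  # assumed module and class name

    model = SanaTransformer2DModel.from_pretrained(
        "path/to/sana-checkpoint", subfolder="transformer"  # placeholder path
    )

    # Swap every attention processor for the JVP-friendly vanilla implementation before training.
    model.set_attn_processor(SanaVanillaAttnProcessor())

    # ... JVP-based training loop ...

    # Restore the default fused-SDPA processors for inference, assuming the model exposes this helper.
    model.set_default_attn_processor()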

Contributor: Oh, this is cool and nifty IMO, thanks. I'll change the code.

) -> None:
super().__init__()

Expand Down Expand Up @@ -402,6 +496,7 @@ def __init__(
norm_eps=norm_eps,
mlp_ratio=mlp_ratio,
qk_norm=qk_norm,
cross_attention_type=cross_attention_type,
)
for _ in range(num_layers)
]