Add cross attention type for Sana-Sprint training in diffusers. #11514
File changed: src/diffusers/models/transformers/sana_transformer.py

@@ -14,6 +14,7 @@
 from typing import Any, Dict, Optional, Tuple, Union

+import math
 import torch
 import torch.nn.functional as F
 from torch import nn

@@ -184,6 +185,91 @@ def __call__(
         return hidden_states


+class SanaAttnProcessor3_0:
+    r"""
+    Processor for implementing vanilla (unfused) scaled dot-product attention with plain PyTorch ops.
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor3_0 requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    @staticmethod
+    def scaled_dot_product_attention(
+        query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None
+    ) -> torch.Tensor:
+        B, H, L, S = *query.size()[:-1], key.size(-2)
+        scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
+        attn_bias = torch.zeros(B, H, L, S, dtype=query.dtype, device=query.device)
+
+        if attn_mask is not None:
+            if attn_mask.dtype == torch.bool:
+                attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
+            else:
+                attn_bias += attn_mask
+        attn_weight = query @ key.transpose(-2, -1) * scale_factor
+        attn_weight += attn_bias
+        attn_weight = torch.softmax(attn_weight, dim=-1)
+        attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
+        return attn_weight @ value
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = self.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""

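The static scaled_dot_product_attention above follows PyTorch's reference formulation of F.scaled_dot_product_attention, so with dropout disabled the manual path should match the fused kernel up to floating-point tolerance. A minimal sanity-check sketch (not part of the PR; shapes are arbitrary and the tolerance may need loosening depending on backend):

import torch
import torch.nn.functional as F

# arbitrary (batch, heads, seq_len, head_dim) tensors; key/value use a different sequence length
query = torch.randn(2, 8, 16, 64)
key = torch.randn(2, 8, 32, 64)
value = torch.randn(2, 8, 32, 64)

manual = SanaAttnProcessor3_0.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
fused = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0)
assert torch.allclose(manual, fused, atol=1e-5)
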
@@ -205,6 +291,7 @@ def __init__(
         attention_out_bias: bool = True,
         mlp_ratio: float = 2.5,
         qk_norm: Optional[str] = None,
+        cross_attention_type: str = "flash",
     ) -> None:
         super().__init__()

@@ -223,6 +310,12 @@ def __init__(
         )

         # 2. Cross Attention
+        if cross_attention_type == "flash":
+            cross_attention_processor = SanaAttnProcessor2_0()
+        elif cross_attention_type == "vanilla":
+            cross_attention_processor = SanaAttnProcessor3_0()
+        else:
+            raise ValueError(f"Cross attention type {cross_attention_type} is not defined.")
         if cross_attention_dim is not None:
             self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             self.attn2 = Attention(

@@ -235,7 +328,7 @@
                 dropout=dropout,
                 bias=True,
                 out_bias=attention_out_bias,
-                processor=SanaAttnProcessor2_0(),
+                processor=cross_attention_processor,
             )

         # 3. Feed-forward

@@ -360,6 +453,7 @@ def __init__(
         guidance_embeds_scale: float = 0.1,
         qk_norm: Optional[str] = None,
         timestep_scale: float = 1.0,
+        cross_attention_type: str = "flash",

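With the new config argument threaded through, switching the cross-attention processor at construction time would look roughly like the sketch below (not part of the PR; it assumes the remaining constructor arguments keep their defaults):

from diffusers import SanaTransformer2DModel

# "flash" (the default) keeps SanaAttnProcessor2_0 for attn2,
# "vanilla" selects SanaAttnProcessor3_0 instead
model = SanaTransformer2DModel(cross_attention_type="vanilla")
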
(Referring to the model's existing def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]) and def attn_processors(self) -> Dict[str, AttentionProcessor] methods.)

We can then just include the vanilla attention processor implementation in the training utility and do something like:

model = SanaTransformer2DModel(...)
model.set_attn_processor(SanaVanillaAttnProcessor())

WDYT? @DN6 any suggestion?
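A hedged sketch of how that could look in the training utility, using the dict form of set_attn_processor so only the cross-attention (attn2) processors are swapped; SanaVanillaAttnProcessor stands for the hypothetical copy of the processor that would live in the training script:

model = SanaTransformer2DModel()  # construction/loading details omitted

# attn_processors returns a dict keyed by module path, e.g. "transformer_blocks.0.attn2.processor";
# replace only the cross-attention entries and keep everything else as-is
processors = {
    name: SanaVanillaAttnProcessor() if "attn2" in name else processor
    for name, processor in model.attn_processors.items()
}
model.set_attn_processor(processors)
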
Oh this is cool and nifty IMO, thanks. I'll change the code.
Can't we modify the SanaAttnProcessor2_0 class to handle the changes of SanaAttnProcessor3_0?
If we merge 2_0 and 3_0, we then need a variable to check which function to use here:

diffusers/src/diffusers/models/transformers/sana_transformer.py, line 257 in 9cb050b

which would be similar to cross_attention_type: str = "flash". @sayakpaul
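If the two were merged, that check could be a constructor flag on the single processor that picks between the fused kernel and the manual path at the call site referenced above. A rough sketch of that dispatch point (illustrative only; names and placement are not final):

import torch
import torch.nn.functional as F

def dispatch_attention(query, key, value, attention_mask=None, use_flash: bool = True) -> torch.Tensor:
    # use_flash=True reproduces the current SanaAttnProcessor2_0 behavior (fused PyTorch 2.0 kernel);
    # use_flash=False falls back to the manual implementation from SanaAttnProcessor3_0 above
    if use_flash:
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
    return SanaAttnProcessor3_0.scaled_dot_product_attention(
        query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
    )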