
Commit c952370

first commit
1 parent 6b9a333 commit c952370

8 files changed (+1364 additions, -9 deletions)

src/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -276,6 +276,7 @@
             "UnCLIPScheduler",
             "UniPCMultistepScheduler",
             "VQDiffusionScheduler",
+            "SCMScheduler",
         ]
     )
     _import_structure["training_utils"] = ["EMAModel"]
@@ -421,6 +422,7 @@
             "ReduxImageEncoder",
             "SanaPAGPipeline",
             "SanaPipeline",
+            "SanaSCMPipeline",
             "SemanticStableDiffusionPipeline",
             "ShapEImg2ImgPipeline",
             "ShapEPipeline",
@@ -839,6 +841,7 @@
         UnCLIPScheduler,
         UniPCMultistepScheduler,
         VQDiffusionScheduler,
+        SCMScheduler,
     )
     from .training_utils import EMAModel
@@ -965,6 +968,7 @@
         ReduxImageEncoder,
         SanaPAGPipeline,
         SanaPipeline,
+        SanaSCMPipeline,
         SemanticStableDiffusionPipeline,
         ShapEImg2ImgPipeline,
         ShapEPipeline,
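
These four hunks only register the new public names in diffusers' lazy-import table and in the matching TYPE_CHECKING imports; the modules that actually define SCMScheduler and SanaSCMPipeline are presumably among the files of this commit not shown in this view. A simplified, self-contained sketch of the registration pattern being extended here (this is not diffusers' real _LazyModule machinery, and the dict lists only illustrative names):

import importlib

# names are declared up front; the defining submodule is imported only on first access
_import_structure = {
    "schedulers": ["UniPCMultistepScheduler", "SCMScheduler"],  # SCMScheduler newly registered
    "pipelines": ["SanaPipeline", "SanaSCMPipeline"],           # SanaSCMPipeline newly registered
}


def lazy_get(package: str, name: str):
    """Resolve `name` by importing the submodule that declares it."""
    for submodule, exported in _import_structure.items():
        if name in exported:
            return getattr(importlib.import_module(f"{package}.{submodule}"), name)
    raise AttributeError(f"module {package!r} has no attribute {name!r}")

Both halves of the edit matter: the string table drives runtime lazy loading, while the imports under TYPE_CHECKING/DIFFUSERS_SLOW_IMPORT keep static analysis and eager imports in sync with it.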

src/diffusers/models/attention_processor.py

Lines changed: 5 additions & 0 deletions
@@ -6020,6 +6020,11 @@ def __call__(
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
         key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
         value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))
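
This guard mirrors the one inside the new SanaAttnProcessor2_0 added below: when the Attention module was built with qk_norm set, the query/key projections are normalized while still flat, before any reshaping into heads. A minimal, self-contained sketch of that ordering, shown with standard scaled-dot-product attention rather than the linear-attention layout of the surrounding processor (shapes and the RMSNorm stand-ins are illustrative; torch.nn.RMSNorm needs PyTorch 2.4+):

import torch
import torch.nn as nn
import torch.nn.functional as F

batch, seq, heads, head_dim = 2, 16, 8, 64
inner_dim = heads * head_dim

to_q, to_k, to_v = (nn.Linear(inner_dim, inner_dim) for _ in range(3))
# stand-ins for attn.norm_q / attn.norm_k, normalizing across all heads at once
norm_q, norm_k = nn.RMSNorm(inner_dim), nn.RMSNorm(inner_dim)

hidden_states = torch.randn(batch, seq, inner_dim)
query, key, value = to_q(hidden_states), to_k(hidden_states), to_v(hidden_states)

# normalize the flat (batch, seq, inner_dim) projections first, as in the diff,
# then split into heads and attend
query, key = norm_q(query), norm_k(key)
query = query.view(batch, seq, heads, head_dim).transpose(1, 2)
key = key.view(batch, seq, heads, head_dim).transpose(1, 2)
value = value.view(batch, seq, heads, head_dim).transpose(1, 2)
out = F.scaled_dot_product_attention(query, key, value)  # (batch, heads, seq, head_dim)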

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 121 additions & 6 deletions
@@ -30,7 +30,9 @@
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..embeddings import TimestepEmbedding, Timesteps
 
+import torch.nn.functional as F
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
@@ -96,6 +98,102 @@ def forward(
         return hidden_states
 
 
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+    """
+    For Sana.
+
+    Reference:
+    """
+
+    def __init__(self, embedding_dim):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        guidance_proj = self.guidance_condition_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+        conditioning = timesteps_emb + guidance_emb
+
+        return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
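
SanaCombinedTimestepGuidanceEmbeddings treats the guidance scale as a second timestep-like scalar: both pass through a sinusoidal Timesteps projection and a TimestepEmbedding MLP, the two embeddings are summed, and a SiLU + Linear head expands the sum into the six AdaLN modulation chunks the blocks consume. A hedged usage sketch, assuming this branch is installed so the import resolves (the embedding dim and scalar values are illustrative):

import torch
# only resolvable with this commit; the class is defined in sana_transformer.py above
from diffusers.models.transformers.sana_transformer import SanaCombinedTimestepGuidanceEmbeddings

embed = SanaCombinedTimestepGuidanceEmbeddings(embedding_dim=1152)

timestep = torch.tensor([999.0, 250.0])  # one diffusion timestep per batch item
guidance = torch.tensor([4.5, 4.5])      # guidance scale, embedded like a second timestep

cond, embedded_cond = embed(timestep, guidance=guidance, hidden_dtype=torch.float32)
print(cond.shape)           # (2, 6 * 1152): shift/scale/gate chunks for the AdaLN layers
print(embedded_cond.shape)  # (2, 1152): summed timestep + guidance embedding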
@@ -115,6 +213,7 @@ def __init__(
         norm_eps: float = 1e-6,
         attention_out_bias: bool = True,
         mlp_ratio: float = 2.5,
+        qk_norm: Optional[str] = None,
     ) -> None:
         super().__init__()
 
@@ -124,6 +223,8 @@
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
+            kv_heads=num_attention_heads if qk_norm is not None else None,
+            qk_norm=qk_norm,
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
@@ -135,13 +236,15 @@
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         self.attn2 = Attention(
             query_dim=dim,
+            qk_norm=qk_norm,
+            kv_heads=num_cross_attention_heads if qk_norm is not None else None,
             cross_attention_dim=cross_attention_dim,
             heads=num_cross_attention_heads,
             dim_head=cross_attention_head_dim,
             dropout=dropout,
             bias=True,
             out_bias=attention_out_bias,
-            processor=AttnProcessor2_0(),
+            processor=SanaAttnProcessor2_0(),
         )
 
         # 3. Feed-forward
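
Forwarding qk_norm and kv_heads into Attention, together with the processor swap from AttnProcessor2_0 to SanaAttnProcessor2_0, is what makes the norm_q/norm_k guards above actually fire. A small hedged wiring sketch outside the block (dimensions are illustrative; SanaAttnProcessor2_0 exists only on this branch, and "rms_norm_across_heads" is an assumption about which qk_norm variant pairs with the flat, pre-reshape normalization these processors apply):

import torch
from diffusers.models.attention_processor import Attention
from diffusers.models.transformers.sana_transformer import SanaAttnProcessor2_0  # from this commit

attn = Attention(
    query_dim=128,
    heads=4,
    dim_head=32,
    kv_heads=4,                       # set only when qk_norm is used, mirroring the diff
    qk_norm="rms_norm_across_heads",  # assumed variant; it normalizes the flat Q/K projections
    cross_attention_dim=None,
    processor=SanaAttnProcessor2_0(),
)

hidden_states = torch.randn(1, 16, 128)
out = attn(hidden_states)  # self-attention with normalized query/key
print(out.shape)           # (1, 16, 128)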
@@ -258,6 +361,8 @@ def __init__(
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
         interpolation_scale: Optional[int] = None,
+        guidance_embeds: bool = False,
+        qk_norm: Optional[str] = None,
     ) -> None:
         super().__init__()
 
@@ -276,7 +381,10 @@
         )
 
         # 2. Additional condition embeddings
-        self.time_embed = AdaLayerNormSingle(inner_dim)
+        if guidance_embeds:
+            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+        else:
+            self.time_embed = AdaLayerNormSingle(inner_dim)
 
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
         self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -296,6 +404,7 @@ def __init__(
                 norm_elementwise_affine=norm_elementwise_affine,
                 norm_eps=norm_eps,
                 mlp_ratio=mlp_ratio,
+                qk_norm=qk_norm,
             )
             for _ in range(num_layers)
         ]
@@ -372,7 +481,8 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        timestep: torch.LongTensor,
+        timestep: torch.Tensor,
+        guidance: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -423,9 +533,14 @@ def forward(
 
         hidden_states = self.patch_embed(hidden_states)
 
-        timestep, embedded_timestep = self.time_embed(
-            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-        )
+        if guidance is not None:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )
 
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
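
The else-branch keeps the original AdaLayerNormSingle call untouched, so existing Sana checkpoints behave exactly as before; only when a guidance tensor is passed does forward route through the new combined embedder, which takes guidance= instead of batch_size=. For reference, the preserved path in isolation, runnable against released diffusers (the embedding dim and timesteps are illustrative):

import torch
from diffusers.models.normalization import AdaLayerNormSingle

inner_dim = 64
time_embed = AdaLayerNormSingle(inner_dim)

timestep = torch.tensor([999.0, 250.0])
# the exact call the else-branch preserves; the guidance branch instead calls
# SanaCombinedTimestepGuidanceEmbeddings(timestep, guidance=..., hidden_dtype=...)
conditioning, embedded_timestep = time_embed(timestep, batch_size=2, hidden_dtype=torch.float32)
print(conditioning.shape)       # (2, 6 * 64): same modulation layout either way
print(embedded_timestep.shape)  # (2, 64)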

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -280,7 +280,7 @@
     _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
     _import_structure["pia"] = ["PIAPipeline"]
     _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
-    _import_structure["sana"] = ["SanaPipeline"]
+    _import_structure["sana"] = ["SanaPipeline", "SanaSCMPipeline"]
     _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
     _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
     _import_structure["stable_audio"] = [
@@ -651,7 +651,7 @@
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
-        from .sana import SanaPipeline
+        from .sana import SanaPipeline, SanaSCMPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
         from .stable_audio import StableAudioPipeline, StableAudioProjectionModel

src/diffusers/pipelines/sana/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -23,6 +23,7 @@
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_sana"] = ["SanaPipeline"]
+    _import_structure["pipeline_sana_scm"] = ["SanaSCMPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -33,6 +34,7 @@
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_sana import SanaPipeline
+        from .pipeline_sana_scm import SanaSCMPipeline
 else:
     import sys
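
With the SCM pipeline registered here (its implementation, pipeline_sana_scm.py, is one of the commit's files not shown in this view), it would be loaded like any other DiffusionPipeline. A hypothetical usage sketch: the repo id is a placeholder, and the call signature is an assumption based on SanaPipeline's interface rather than on the unreviewed pipeline file:

import torch
from diffusers import SanaSCMPipeline  # only resolvable once this commit is installed

# "<placeholder/sana-scm-checkpoint>" stands in for whatever checkpoint this pipeline targets
pipe = SanaSCMPipeline.from_pretrained("<placeholder/sana-scm-checkpoint>", torch_dtype=torch.bfloat16)
pipe.to("cuda")

image = pipe(
    prompt="a watercolor lighthouse at dawn",
    num_inference_steps=2,  # illustrative: few-step sampling is the usual point of a distilled setup
).images[0]
image.save("sana_scm.png")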
