
Commit e7c1a59

committed
Update to resolve review conversations.
1 parent 297c0e7 commit e7c1a59

4 files changed: +89 additions, −122 deletions


src/diffusers/models/attention_processor.py

Lines changed: 0 additions & 1 deletion

@@ -5566,5 +5566,4 @@ def __call__(
     LoRAAttnProcessor2_0,
     LoRAXFormersAttnProcessor,
     LoRAAttnAddedKVProcessor,
-    SanaLinearAttnProcessor2_0,
 ]

src/diffusers/models/normalization.py

Lines changed: 38 additions & 0 deletions

@@ -590,3 +590,41 @@ def get_normalization(
     else:
         raise ValueError(f"{norm_type=} is not supported.")
     return norm
+
+
+class RMSNormScaled(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_factor: float = 1.0, bias: bool = False):
+        super().__init__()
+
+        self.eps = eps
+        self.elementwise_affine = elementwise_affine
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        self.weight = None
+        self.bias = None
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
+            if bias:
+                self.bias = nn.Parameter(torch.zeros(dim))
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            # convert into half-precision if necessary
+            if self.weight.dtype in [torch.float16, torch.bfloat16]:
+                hidden_states = hidden_states.to(self.weight.dtype)
+            hidden_states = hidden_states * self.weight
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
+        else:
+            hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
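
For reference, a minimal usage sketch of the RMSNormScaled class added above. This assumes a diffusers checkout that includes this commit (the class is internal and not re-exported from the top-level package), and the dim, eps, and scale_factor values are illustrative rather than taken from a real config:

import torch

from diffusers.models.normalization import RMSNormScaled

# eps matches the value SanaTransformer2DModel passes for its caption norm;
# scale_factor stands in for whatever caption_norm_scale_factor is configured to be.
norm = RMSNormScaled(dim=2240, eps=1e-5, elementwise_affine=True, scale_factor=0.1)

hidden_states = torch.randn(2, 300, 2240)  # (batch, sequence, dim)
out = norm(hidden_states)
print(out.shape)  # torch.Size([2, 300, 2240])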

src/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -11,7 +11,7 @@
 from .lumina_nextdit2d import LuminaNextDiT2DModel
 from .pixart_transformer_2d import PixArtTransformer2DModel
 from .prior_transformer import PriorTransformer
-from .sana_transformer_2d import SanaTransformer2DModel
+from .sana_transformer import SanaTransformer2DModel
 from .stable_audio_transformer import StableAudioDiTModel
 from .t5_film_transformer import T5FilmDecoder
 from .transformer_2d import Transformer2DModel
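
For downstream code, the rename only matters for imports that reference the module by its old file name; the class-level import through the package is unchanged. A small sketch, assuming a checkout at this commit:

# Unchanged: the package re-export still exposes the same class name.
from diffusers.models.transformers import SanaTransformer2DModel

# Needs updating: imports that used the old module path.
# from diffusers.models.transformers.sana_transformer_2d import SanaTransformer2DModel  # old path
from diffusers.models.transformers.sana_transformer import SanaTransformer2DModel  # new path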

src/diffusers/models/transformers/sana_transformer_2d.py renamed to src/diffusers/models/transformers/sana_transformer.py

Lines changed: 50 additions & 120 deletions

@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import numbers
 from typing import Any, Dict, Optional, Union

 import torch
@@ -25,38 +26,18 @@
     AttnProcessor2_0,
     FusedAttnProcessor2_0,
     SanaLinearAttnProcessor2_0,
+    SanaMultiscaleAttnProcessor2_0,
+    SanaMultiscaleLinearAttention,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, SinusoidalPositionalEmbedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormSingle, RMSNorm
+from ..normalization import AdaLayerNormSingle, RMSNormScaled


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def _chunked_feed_forward(ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, HW: tuple=None):
-    # "feed_forward_chunk_size" can be used to save memory
-    if hidden_states.shape[chunk_dim] % chunk_size != 0:
-        raise ValueError(
-            f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
-        )
-
-    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
-    ff_output = torch.cat(
-        [ff(hid_slice, HW) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
-        dim=chunk_dim,
-    )
-    return ff_output
-
-
-@maybe_allow_in_graph
-class RMSNormScaled(RMSNorm):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True, scale_factor: float = 1.0):
-        super().__init__(dim, eps, elementwise_affine)
-        self.weight = nn.Parameter(torch.ones(dim) * scale_factor)
-
-
 # Modified from diffusers.models.autoencoders.autoencoder_dc.GLUMBConv
 @maybe_allow_in_graph
 class SanaGLUMBConv(nn.Module):
@@ -105,22 +86,19 @@ class SanaLinearTransformerBlock(nn.Module):
         dim (`int`): The number of channels in the input and output.
         num_attention_heads (`int`): The number of heads to use for multi-head attention.
         attention_head_dim (`int`): The number of channels in each head.
-        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
-            processing of `context` conditions.
     """

     def __init__(
         self,
-        dim: int,
-        num_attention_heads: int,
-        attention_head_dim: int,
-        dropout=0.0,
-        num_cross_attention_heads: Optional[int] = None,
-        cross_attention_head_dim: Optional[int] = None,
-        cross_attention_dim: Optional[int] = None,
-        activation_fn: tuple = ("silu", "silu", None),
-        num_embeds_ada_norm: Optional[int] = None,
-        attention_bias: bool = False,
+        dim: int = 2240,
+        num_attention_heads: int = 70,
+        attention_head_dim: int = 32,
+        dropout: float = 0.0,
+        num_cross_attention_heads: Optional[int] = 20,
+        cross_attention_head_dim: Optional[int] = 112,
+        cross_attention_dim: Optional[int] = 2240,
+        num_embeds_ada_norm: Optional[int] = 1000,
+        attention_bias: bool = True,
         upcast_attention: bool = False,
         norm_type: str = "ada_norm_single",
         norm_elementwise_affine: bool = False,
@@ -136,7 +114,6 @@ def __init__(
         self.attention_head_dim = attention_head_dim
         self.dropout = dropout
         self.cross_attention_dim = cross_attention_dim
-        self.activation_fn = activation_fn
         self.attention_bias = attention_bias
         self.norm_elementwise_affine = norm_elementwise_affine

@@ -205,8 +182,6 @@ def forward(
         encoder_attention_mask: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
-        class_labels: Optional[torch.LongTensor] = None,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
         HW: Optional[tuple[int]] = None,
     ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
@@ -260,11 +235,7 @@ def forward(
         norm_hidden_states = self.norm2(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

-        if self._chunk_size is not None:
-            # "feed_forward_chunk_size" can be used to save memory
-            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, HW=HW)
-        else:
-            ff_output = self.ff(norm_hidden_states, HW=HW)
+        ff_output = self.ff(norm_hidden_states, HW=HW)

         if self.norm_type == "ada_norm_single":
             ff_output = gate_mlp * ff_output
@@ -301,8 +272,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin):
             The width of the latent images. This parameter is fixed during training.
         patch_size (int, defaults to 1):
             Size of the patches the model processes, relevant for architectures working on non-sequential data.
-        activation_fn (str, optional, defaults to "gelu-approximate"):
-            Activation function to use in feed-forward networks within Transformer blocks.
         num_embeds_ada_norm (int, optional, defaults to 1000):
             Number of embeddings for AdaLayerNorm, fixed during training and affects the maximum denoising steps during
             inference.
@@ -338,11 +307,10 @@ def __init__(
         norm_num_groups: int = 32,
         num_cross_attention_heads: Optional[int] = 20,
         cross_attention_head_dim: Optional[int] = 112,
-        cross_attention_dim: Optional[int] = 1152,
+        cross_attention_dim: Optional[int] = 2240,
         attention_bias: bool = True,
         sample_size: int = 32,
         patch_size: int = 1,
-        activation_fn: tuple = ("silu", "silu", None),
         num_embeds_ada_norm: Optional[int] = 1000,
         upcast_attention: bool = False,
         norm_type: str = "ada_norm_single",
@@ -371,7 +339,7 @@ def __init__(

         # Set some common variables used across the board.
         self.attention_head_dim = attention_head_dim
-        self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim
+        self.inner_dim = num_attention_heads * attention_head_dim
         self.out_channels = in_channels if out_channels is None else out_channels
         if use_additional_conditions is None:
             if sample_size == 128:
@@ -383,63 +351,66 @@ def __init__(
         self.gradient_checkpointing = False

         # 2. Initialize the position embedding and transformer blocks.
-        self.height = self.config.sample_size
-        self.width = self.config.sample_size
+        self.height = sample_size
+        self.width = sample_size
+
+        if use_pe:
+            interpolation_scale = (
+                interpolation_scale
+                if interpolation_scale is not None
+                else max(sample_size // 64, 1)
+            )
+        else:
+            interpolation_scale = None

-        interpolation_scale = (
-            self.config.interpolation_scale
-            if self.config.interpolation_scale is not None
-            else max(self.config.sample_size // 64, 1)
-        )
         self.pos_embed = PatchEmbed(
-            height=self.config.sample_size,
-            width=self.config.sample_size,
-            patch_size=self.config.patch_size,
-            in_channels=self.config.in_channels,
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
             embed_dim=self.inner_dim,
             interpolation_scale=interpolation_scale,
-            pos_embed_type="sincos" if self.config.use_pe else None
+            pos_embed_type="sincos" if use_pe else None
         )

         self.transformer_blocks = nn.ModuleList(
             [
                 SanaLinearTransformerBlock(
                     self.inner_dim,
-                    self.config.num_attention_heads,
-                    self.config.attention_head_dim,
-                    dropout=self.config.dropout,
-                    num_cross_attention_heads=self.config.num_cross_attention_heads,
-                    cross_attention_head_dim=self.config.cross_attention_head_dim,
-                    cross_attention_dim=self.config.cross_attention_dim,
-                    activation_fn=self.config.activation_fn,
-                    num_embeds_ada_norm=self.config.num_embeds_ada_norm,
-                    attention_bias=self.config.attention_bias,
-                    upcast_attention=self.config.upcast_attention,
+                    num_attention_heads,
+                    attention_head_dim,
+                    dropout=dropout,
+                    num_cross_attention_heads=num_cross_attention_heads,
+                    cross_attention_head_dim=cross_attention_head_dim,
+                    cross_attention_dim=cross_attention_dim,
+                    num_embeds_ada_norm=num_embeds_ada_norm,
+                    attention_bias=attention_bias,
+                    upcast_attention=upcast_attention,
                     norm_type=norm_type,
-                    norm_elementwise_affine=self.config.norm_elementwise_affine,
-                    norm_eps=self.config.norm_eps,
-                    use_pe=self.config.use_pe,
-                    expand_ratio=self.config.expand_ratio,
+                    norm_elementwise_affine=norm_elementwise_affine,
+                    norm_eps=norm_eps,
+                    use_pe=use_pe,
+                    expand_ratio=expand_ratio,
                 )
-                for _ in range(self.config.num_layers)
+                for _ in range(num_layers)
             ]
         )

         # 3. Output blocks.
         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.scale_shift_table = nn.Parameter(torch.randn(2, self.inner_dim) / self.inner_dim**0.5)
-        self.proj_out = nn.Linear(self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels)
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels)

         self.adaln_single = AdaLayerNormSingle(
             self.inner_dim, use_additional_conditions=self.use_additional_conditions
         )
         self.caption_projection = None
-        if self.config.caption_channels is not None:
+        if caption_channels is not None:
             self.caption_projection = PixArtAlphaTextProjection(
-                in_features=self.config.caption_channels, hidden_size=self.inner_dim
+                in_features=caption_channels, hidden_size=self.inner_dim
             )
         self.caption_norm = None
-        if self.config.use_caption_norm:
+        if use_caption_norm:
             self.caption_norm = RMSNormScaled(self.inner_dim, eps=1e-5, scale_factor=caption_norm_scale_factor)

     def _set_gradient_checkpointing(self, module, value=False):
@@ -506,46 +477,6 @@ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)

-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -556,7 +487,6 @@ def forward(
         cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-        **kwargs,
     ):
         """
         The [`PixArtTransformer2DModel`] forward method.
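
Most of the churn in this file replaces self.config.<arg> reads inside __init__ with the constructor arguments themselves. Both spellings see the same values because @register_to_config (from ConfigMixin) records the init arguments onto self.config as a side effect; reading the locals simply drops the attribute indirection. A toy sketch of that mechanism under those assumptions, not the actual SanaTransformer2DModel:

import torch.nn as nn

from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models.modeling_utils import ModelMixin


class ToyTransformer(ModelMixin, ConfigMixin):
    config_name = "config.json"  # required by ConfigMixin for (de)serialization

    @register_to_config
    def __init__(self, num_attention_heads: int = 70, attention_head_dim: int = 32):
        super().__init__()
        # The local arguments and self.config hold the same values at this point.
        self.inner_dim = num_attention_heads * attention_head_dim
        self.proj = nn.Linear(self.inner_dim, self.inner_dim)


model = ToyTransformer()
assert model.inner_dim == model.config.num_attention_heads * model.config.attention_head_dim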
