Commit 9d31426

make style
1 parent f698524 commit 9d31426

File tree

12 files changed: +66 -60 lines changed


scripts/convert_sana_pag_to_diffusers.py

Lines changed: 3 additions & 2 deletions
@@ -187,8 +187,9 @@ def main(args):
     try:
         state_dict.pop("y_embedder.y_embedding")
         state_dict.pop("pos_embed")
-    except:
-        pass
+    except KeyError:
+        print("y_embedder.y_embedding or pos_embed not found in the state_dict")
+
     assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

     num_model_params = sum(p.numel() for p in transformer.parameters())
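Why the `except:` change matters: a bare `except:` also swallows `KeyboardInterrupt` and `SystemExit`, while `except KeyError:` ignores only the genuinely expected case of a key missing from the checkpoint. A minimal standalone sketch of the new pattern (the dict contents here are illustrative, not the real checkpoint):

# Minimal sketch of the narrowed handler; `state_dict` stands in for the
# remaining unconverted checkpoint keys.
state_dict = {"y_embedder.y_embedding": 0}

try:
    state_dict.pop("y_embedder.y_embedding")
    state_dict.pop("pos_embed")  # absent here, so this raises KeyError
except KeyError:
    # Only the expected "missing key" case is ignored; KeyboardInterrupt,
    # SystemExit, and real bugs still propagate.
    print("y_embedder.y_embedding or pos_embed not found in the state_dict")

assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"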

scripts/convert_sana_to_diffusers.py

Lines changed: 12 additions & 7 deletions
@@ -38,7 +38,7 @@
 def main(args):
     ckpt_id = ckpt_ids[0]
     cache_dir_path = os.path.expanduser("~/.cache/huggingface/hub")
-
+
     if args.orig_ckpt_path is None:
         snapshot_download(
             repo_id=ckpt_id,
@@ -53,7 +53,7 @@ def main(args):
         )
     else:
         file_path = args.orig_ckpt_path
-
+
     all_state_dict = torch.load(file_path, weights_only=True)
     state_dict = all_state_dict.pop("state_dict")
     converted_state_dict = {}
@@ -98,7 +98,7 @@ def main(args):
         converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
             f"blocks.{depth}.scale_shift_table"
         )
-
+
         # Linear Attention is all you need 🤘
         # Self attention.
         q, k, v = torch.chunk(state_dict.pop(f"blocks.{depth}.attn.qkv.weight"), 3, dim=0)
@@ -182,8 +182,9 @@ def main(args):
     try:
         state_dict.pop("y_embedder.y_embedding")
         state_dict.pop("pos_embed")
-    except:
-        pass
+    except KeyError:
+        print("y_embedder.y_embedding or pos_embed not found in the state_dict")
+
     assert len(state_dict) == 0, f"State dict is not empty, {state_dict.keys()}"

     num_model_params = sum(p.numel() for p in transformer.parameters())
@@ -198,11 +199,15 @@ def main(args):
                 attrs=["bold"],
             )
         )
-        transformer.save_pretrained(os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant)
+        transformer.save_pretrained(
+            os.path.join(args.dump_path, "transformer"), safe_serialization=True, max_shard_size="5GB", variant=variant
+        )
     else:
         print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
         # VAE
-        ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",)
+        ae = AutoencoderDC.from_pretrained(
+            "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
+        )

         # Text Encoder
         text_encoder_model_path = "google/gemma-2-2b-it"
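A note on the reflows above, assuming `make style` runs a black-compatible formatter (the paired `-`/`+` blank lines are trailing-whitespace strips): the long `save_pretrained(...)` call was wrapped because it exceeded the line-length limit, while the short `AutoencoderDC.from_pretrained(...)` call was exploded because of the comma before its closing paren; black treats a "magic trailing comma" as a request to keep one argument per line. An illustrative contrast, not code from the repo:

# Without a trailing comma, a call that fits the line limit stays on one line:
ae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")

# With a trailing comma before the closing paren, the formatter keeps the
# call exploded even though it would fit on one line:
ae = AutoencoderDC.from_pretrained(
    "mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers",
)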

src/diffusers/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -111,8 +111,8 @@
     "MultiAdapter",
     "MultiControlNetModel",
     "PixArtTransformer2DModel",
-    "SanaTransformer2DModel",
     "PriorTransformer",
+    "SanaTransformer2DModel",
     "SD3ControlNetModel",
     "SD3MultiControlNetModel",
     "SD3Transformer2DModel",
@@ -180,12 +180,12 @@
     "DEISMultistepScheduler",
     "DPMSolverMultistepInverseScheduler",
     "DPMSolverMultistepScheduler",
-    "FlowDPMSolverMultistepScheduler",
     "DPMSolverSinglestepScheduler",
     "EDMDPMSolverMultistepScheduler",
     "EDMEulerScheduler",
     "EulerAncestralDiscreteScheduler",
     "EulerDiscreteScheduler",
+    "FlowDPMSolverMultistepScheduler",
     "FlowMatchEulerDiscreteScheduler",
     "FlowMatchHeunDiscreteScheduler",
     "HeunDiscreteScheduler",
@@ -330,8 +330,8 @@
     "PixArtSigmaPAGPipeline",
     "PixArtSigmaPipeline",
     "ReduxImageEncoder",
-    "SanaPipeline",
     "SanaPAGPipeline",
+    "SanaPipeline",
     "SemanticStableDiffusionPipeline",
     "ShapEImg2ImgPipeline",
     "ShapEPipeline",
@@ -345,8 +345,8 @@
     "StableDiffusion3Img2ImgPipeline",
     "StableDiffusion3InpaintPipeline",
     "StableDiffusion3PAGImg2ImgPipeline",
-    "StableDiffusion3PAGPipeline",
     "StableDiffusion3PAGImg2ImgPipeline",
+    "StableDiffusion3PAGPipeline",
     "StableDiffusion3Pipeline",
     "StableDiffusionAdapterPipeline",
     "StableDiffusionAttendAndExcitePipeline",

src/diffusers/models/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -58,8 +58,8 @@
 _import_structure["transformers.latte_transformer_3d"] = ["LatteTransformer3DModel"]
 _import_structure["transformers.lumina_nextdit2d"] = ["LuminaNextDiT2DModel"]
 _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
-_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
 _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
+_import_structure["transformers.sana_transformer"] = ["SanaTransformer2DModel"]
 _import_structure["transformers.stable_audio_transformer"] = ["StableAudioDiTModel"]
 _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
 _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 5 deletions
@@ -5359,12 +5359,9 @@ def __call__(
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        input_ndim = hidden_states.ndim
         original_dtype = hidden_states.dtype

-        batch_size, _, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape

         if encoder_hidden_states is None:
             encoder_hidden_states = hidden_states
@@ -5391,7 +5388,7 @@ def __call__(
         if hidden_states.dtype in [torch.float16, torch.bfloat16]:
             hidden_states = hidden_states.float()
-
+
         hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
         hidden_states = hidden_states.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
         hidden_states = hidden_states.to(original_dtype)
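The context lines in the second hunk show the tail of Sana's linear attention: the accumulated output carries an extra denominator channel that is divided out, after an upcast to float32 because the division is lossy in half precision. An illustrative shape/dtype check (the sizes and `eps` here are made up for the sketch; the processor reads `self.eps`):

import torch

eps = 1e-15  # illustrative stand-in for self.eps
hidden_states = torch.randn(2, 4, 33, dtype=torch.bfloat16)  # last channel is the denominator

original_dtype = hidden_states.dtype
if hidden_states.dtype in [torch.float16, torch.bfloat16]:
    hidden_states = hidden_states.float()  # upcast before dividing, as in the hunk

hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + eps)
hidden_states = hidden_states.to(original_dtype)
print(hidden_states.shape, hidden_states.dtype)  # torch.Size([2, 4, 32]) torch.bfloat16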

src/diffusers/models/autoencoders/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -8,5 +8,4 @@
 from .autoencoder_oobleck import AutoencoderOobleck
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE
-from .autoencoder_dc import AutoencoderDC
 from .vq_model import VQModel

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 17 additions & 14 deletions
@@ -12,8 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from functools import partial
-from typing import Any, Dict, Optional, Union
+from typing import Dict, Optional, Union

 import torch
 from torch import nn
@@ -25,7 +24,6 @@
     Attention,
     AttentionProcessor,
     AttnProcessor2_0,
-    SanaMultiscaleLinearAttention,
     SanaLinearAttnProcessor2_0,
 )
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -135,7 +133,7 @@ def __init__(
             mlp_ratio=mlp_ratio,
         )

-        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim ** 0.5)
+        self.scale_shift_table = nn.Parameter(torch.randn(6, dim) / dim**0.5)

     def forward(
         self,
@@ -152,7 +150,7 @@ def forward(
         shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
             self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
         ).chunk(6, dim=1)
-
+
         # 2. Self Attention
         norm_hidden_states = self.norm1(hidden_states)
         norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
@@ -258,9 +256,7 @@ def __init__(
         )

         # 2. Caption Embedding
-        self.caption_projection = PixArtAlphaTextProjection(
-            in_features=caption_channels, hidden_size=inner_dim
-        )
+        self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
         self.caption_norm = RMSNorm(inner_dim, eps=1e-5)

         # 3. Transformer blocks
@@ -285,7 +281,7 @@ def __init__(

         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
+
         self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)

@@ -401,12 +397,12 @@ def forward(

         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-
+
         encoder_hidden_states = self.caption_norm(encoder_hidden_states)

         # 2. Transformer blocks
         use_reentrant = is_torch_version("<=", "1.11.0")
-
+
         def create_block_forward(block):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 return lambda *inputs: torch.utils.checkpoint.checkpoint(
@@ -430,16 +426,23 @@ def create_block_forward(block):
             self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
         ).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
-
+
         # 4. Modulation
         hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)

         # 5. Unpatchify
-        hidden_states = hidden_states.reshape(batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1)
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_height, post_patch_width, self.config.patch_size, self.config.patch_size, -1
+        )
         hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
         output = hidden_states.reshape(
-            shape=(batch_size, -1, post_patch_height * self.config.patch_size, post_patch_width * self.config.patch_size)
+            shape=(
+                batch_size,
+                -1,
+                post_patch_height * self.config.patch_size,
+                post_patch_width * self.config.patch_size,
+            )
         )

         if not return_dict:
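The unpatchify hunk only re-wraps long lines; a quick standalone shape check (with illustrative sizes) confirms the reshape → permute → reshape round trip it formats:

import torch

batch_size, post_patch_height, post_patch_width = 2, 4, 4
patch_size, out_channels = 2, 8  # illustrative sizes

hidden_states = torch.randn(
    batch_size, post_patch_height * post_patch_width, patch_size * patch_size * out_channels
)
hidden_states = hidden_states.reshape(
    batch_size, post_patch_height, post_patch_width, patch_size, patch_size, -1
)
hidden_states = hidden_states.permute(0, 5, 1, 3, 2, 4)
output = hidden_states.reshape(
    batch_size, -1, post_patch_height * patch_size, post_patch_width * patch_size
)
print(output.shape)  # torch.Size([2, 8, 8, 8]): (batch, channels, height, width)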

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 2 additions & 1 deletion
@@ -176,7 +176,8 @@ def __init__(
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

         self.set_pag_applied_layers(
-            pag_applied_layers, pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0())
+            pag_applied_layers,
+            pag_attn_processors=(PAGCFGSanaLinearAttnProcessor2_0(), PAGIdentitySanaLinearAttnProcessor2_0()),
         )

     # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 18 additions & 8 deletions
@@ -27,7 +27,6 @@
 from ...schedulers import FlowDPMSolverMultistepScheduler
 from ...utils import (
     BACKENDS_MAPPING,
-    deprecate,
     is_bs4_available,
     is_ftfy_available,
     logging,
@@ -59,9 +58,7 @@
         >>> from diffusers import SanaPipeline

         >>> # You can replace the checkpoint id with "Sana_1600M_1024px/Sana_1600M_1024px" too.
-        >>> pipe = SanaPipeline.from_pretrained(
-        ...     "Sana_1600M_1024px/Sana_1600M_1024px", torch_dtype=torch.float16
-        ... )
+        >>> pipe = SanaPipeline.from_pretrained("Sana_1600M_1024px/Sana_1600M_1024px", torch_dtype=torch.float16)
         >>> # Enable memory optimizations.
         >>> # pipe.enable_model_cpu_offload()

@@ -171,7 +168,9 @@ def __init__(
         )

         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.encoder_block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 32
+            2 ** (len(self.vae.config.encoder_block_out_channels) - 1)
+            if hasattr(self, "vae") and self.vae is not None
+            else 32
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

@@ -650,13 +649,13 @@ def __call__(

         Returns:
             [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] or `tuple`:
-                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned, otherwise a `tuple` is
-                returned where the first element is a list with the generated images
+                If `return_dict` is `True`, [`~pipelines.sana.pipeline_output.SanaPipelineOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is a list with the generated images
         """

         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
-
+
         # 1. Check inputs. Raise error if not correct
         if use_resolution_binning:
             if self.transformer.config.sample_size == 64:
@@ -778,6 +777,17 @@ def __call__(
                 if torch.backends.mps.is_available():
                     # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
                     latents = latents.to(latents_dtype)
+
+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+
+                    latents = callback_outputs.pop("latents", latents)
+                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
+                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
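The last hunk is the one substantive change in this file: it wires `callback_on_step_end` into the denoising loop so callers can inspect or replace tensors between steps. A hedged usage sketch (the checkpoint id comes from the docstring example above; the prompt and callback body are illustrative):

import torch
from diffusers import SanaPipeline

def on_step_end(pipe, step, timestep, callback_kwargs):
    # Each name listed in callback_on_step_end_tensor_inputs is passed in here;
    # returning a dict lets you replace any of them for the next step.
    latents = callback_kwargs["latents"]
    print(f"step {step}: latents {tuple(latents.shape)}")
    return {"latents": latents}

pipe = SanaPipeline.from_pretrained("Sana_1600M_1024px/Sana_1600M_1024px", torch_dtype=torch.float16)
pipe.to("cuda")  # or enable CPU offloading instead, as in the docstring example
image = pipe(
    prompt="a tiny astronaut hatching from an egg on the moon",
    callback_on_step_end=on_step_end,
    callback_on_step_end_tensor_inputs=["latents"],
).images[0]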

src/diffusers/schedulers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -52,9 +52,9 @@
 _import_structure["scheduling_deis_multistep"] = ["DEISMultistepScheduler"]
 _import_structure["scheduling_dpm_cogvideox"] = ["CogVideoXDPMScheduler"]
 _import_structure["scheduling_dpmsolver_multistep"] = ["DPMSolverMultistepScheduler"]
+_import_structure["scheduling_dpmsolver_multistep_flow"] = ["FlowDPMSolverMultistepScheduler"]
 _import_structure["scheduling_dpmsolver_multistep_inverse"] = ["DPMSolverMultistepInverseScheduler"]
 _import_structure["scheduling_dpmsolver_singlestep"] = ["DPMSolverSinglestepScheduler"]
-_import_structure["scheduling_dpmsolver_multistep_flow"] = ["FlowDPMSolverMultistepScheduler"]
 _import_structure["scheduling_edm_dpmsolver_multistep"] = ["EDMDPMSolverMultistepScheduler"]
 _import_structure["scheduling_edm_euler"] = ["EDMEulerScheduler"]
 _import_structure["scheduling_euler_ancestral_discrete"] = ["EulerAncestralDiscreteScheduler"]
