Skip to content

Commit 23f780f

Browse files
committed
update Sana for DC-AE's recent commit;
1 parent c67753e commit 23f780f

File tree

8 files changed

+47
-69
lines changed

8 files changed

+47
-69
lines changed

scripts/convert_sana_pag_to_diffusers.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from accelerate import init_empty_weights
1010
from diffusers import (
1111
DCAE,
12-
DCAE_HF,
13-
FlowDPMSolverMultistepScheduler,
12+
DPMSolverMultistepScheduler,
1413
FlowMatchEulerDiscreteScheduler,
1514
SanaPAGPipeline,
1615
SanaTransformer2DModel,
@@ -186,27 +185,10 @@ def main(args):
186185
else:
187186
print(colored(f"Saving the whole SanaPAGPipeline containing {args.model_type}", "green", attrs=["bold"]))
188187
# VAE
189-
dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f32c32-sana-1.0")
190-
dc_ae_state_dict = dc_ae.state_dict()
191-
dc_ae = DCAE(
192-
in_channels=3,
193-
latent_channels=32,
194-
encoder_width_list=[128, 256, 512, 512, 1024, 1024],
195-
encoder_depth_list=[2, 2, 2, 3, 3, 3],
196-
encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
197-
encoder_norm="rms2d",
198-
encoder_act="silu",
199-
downsample_block_type="Conv",
200-
decoder_width_list=[128, 256, 512, 512, 1024, 1024],
201-
decoder_depth_list=[3, 3, 3, 3, 3, 3],
202-
decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
203-
decoder_norm="rms2d",
204-
decoder_act="silu",
205-
upsample_block_type="InterpolateConv",
206-
scaling_factor=0.41407,
207-
)
208-
dc_ae.load_state_dict(dc_ae_state_dict, strict=True)
209-
dc_ae.to(torch.float32).to(device)
188+
dc_ae = DCAE.from_pretrained(
189+
"Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers",
190+
torch_dtype=torch.float32,
191+
).to(device)
210192

211193
# Text Encoder
212194
text_encoder_model_path = "google/gemma-2-2b-it"
@@ -220,7 +202,11 @@ def main(args):
220202

221203
# Scheduler
222204
if args.scheduler_type == "flow-dpm_solver":
223-
scheduler = FlowDPMSolverMultistepScheduler(flow_shift=flow_shift)
205+
scheduler = DPMSolverMultistepScheduler(
206+
flow_shift=flow_shift,
207+
use_flow_sigmas=True,
208+
prediction_type="flow_prediction",
209+
)
224210
elif args.scheduler_type == "flow-euler":
225211
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
226212
else:

scripts/convert_sana_to_diffusers.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from accelerate import init_empty_weights
1010
from diffusers import (
1111
DCAE,
12-
DCAE_HF,
13-
FlowDPMSolverMultistepScheduler,
12+
DPMSolverMultistepScheduler,
1413
FlowMatchEulerDiscreteScheduler,
1514
SanaPipeline,
1615
SanaTransformer2DModel,
@@ -186,27 +185,10 @@ def main(args):
186185
else:
187186
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
188187
# VAE
189-
dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f32c32-sana-1.0")
190-
dc_ae_state_dict = dc_ae.state_dict()
191-
dc_ae = DCAE(
192-
in_channels=3,
193-
latent_channels=32,
194-
encoder_width_list=[128, 256, 512, 512, 1024, 1024],
195-
encoder_depth_list=[2, 2, 2, 3, 3, 3],
196-
encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
197-
encoder_norm="rms2d",
198-
encoder_act="silu",
199-
downsample_block_type="Conv",
200-
decoder_width_list=[128, 256, 512, 512, 1024, 1024],
201-
decoder_depth_list=[3, 3, 3, 3, 3, 3],
202-
decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
203-
decoder_norm="rms2d",
204-
decoder_act="silu",
205-
upsample_block_type="InterpolateConv",
206-
scaling_factor=0.41407,
207-
)
208-
dc_ae.load_state_dict(dc_ae_state_dict, strict=True)
209-
dc_ae.to(torch.float32).to(device)
188+
dc_ae = DCAE.from_pretrained(
189+
"Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers",
190+
torch_dtype=torch.float32,
191+
).to(device)
210192

211193
# Text Encoder
212194
text_encoder_model_path = "google/gemma-2-2b-it"
@@ -220,7 +202,11 @@ def main(args):
220202

221203
# Scheduler
222204
if args.scheduler_type == "flow-dpm_solver":
223-
scheduler = FlowDPMSolverMultistepScheduler(flow_shift=flow_shift)
205+
scheduler = DPMSolverMultistepScheduler(
206+
flow_shift=flow_shift,
207+
use_flow_sigmas=True,
208+
prediction_type="flow_prediction",
209+
)
224210
elif args.scheduler_type == "flow-euler":
225211
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
226212
else:

src/diffusers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,6 @@
130130
"UVit2DModel",
131131
"VQModel",
132132
"DCAE",
133-
"DCAE_HF",
134133
]
135134
)
136135
_import_structure["optimization"] = [
@@ -575,7 +574,6 @@
575574
else:
576575
from .models import (
577576
DCAE,
578-
DCAE_HF,
579577
AllegroTransformer3DModel,
580578
AsymmetricAutoencoderKL,
581579
AuraFlowTransformer2DModel,

src/diffusers/models/__init__.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
_import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
3737
_import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
3838
_import_structure["autoencoders.vq_model"] = ["VQModel"]
39-
_import_structure["autoencoders.dc_ae"] = ["DCAE", "DCAE_HF"]
39+
_import_structure["autoencoders.autoencoder_dc"] = ["DCAE"]
4040
_import_structure["controlnets.controlnet"] = ["ControlNetModel"]
4141
_import_structure["controlnets.controlnet_flux"] = ["FluxControlNetModel", "FluxMultiControlNetModel"]
4242
_import_structure["controlnets.controlnet_hunyuan"] = [
@@ -90,7 +90,6 @@
9090
from .adapter import MultiAdapter, T2IAdapter
9191
from .autoencoders import (
9292
DCAE,
93-
DCAE_HF,
9493
AsymmetricAutoencoderKL,
9594
AutoencoderKL,
9695
AutoencoderKLAllegro,

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,5 +7,5 @@
77
from .autoencoder_oobleck import AutoencoderOobleck
88
from .autoencoder_tiny import AutoencoderTiny
99
from .consistency_decoder_vae import ConsistencyDecoderVAE
10-
from .dc_ae import DCAE, DCAE_HF
10+
from .autoencoder_dc import DCAE
1111
from .vq_model import VQModel

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def encode(self, x: torch.Tensor) -> torch.Tensor:
630630
def decode(self, x: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
631631
x = self.decoder(x)
632632
if not return_dict:
633-
return x
633+
return (x, )
634634
else:
635635
return DecoderOutput(sample=x)
636636

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from transformers import AutoModelForCausalLM, AutoTokenizer
2323

2424
from ...image_processor import PixArtImageProcessor
25-
from ...models import DCAE_HF, SanaTransformer2DModel
25+
from ...models import DCAE, SanaTransformer2DModel
2626
from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
2727
from ...schedulers import FlowDPMSolverMultistepScheduler
2828
from ...utils import (
@@ -162,7 +162,7 @@ def __init__(
162162
self,
163163
tokenizer: AutoTokenizer,
164164
text_encoder: AutoModelForCausalLM,
165-
vae: DCAE_HF,
165+
vae: DCAE,
166166
transformer: SanaTransformer2DModel,
167167
scheduler: FlowDPMSolverMultistepScheduler,
168168
pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block
@@ -840,22 +840,27 @@ def __call__(
840840
noise_pred = noise_pred
841841

842842
# compute previous image: x_t -> x_t-1
843+
latents_dtype = latents.dtype
843844
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
844845

846+
if latents.dtype != latents_dtype:
847+
if torch.backends.mps.is_available():
848+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
849+
latents = latents.to(latents_dtype)
845850
# call the callback, if provided
846851
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
847852
progress_bar.update()
848853
if callback is not None and i % callback_steps == 0:
849854
step_idx = i // getattr(self.scheduler, "order", 1)
850855
callback(step_idx, t, latents)
851-
# set to None for next
852856

853-
if not output_type == "latent":
854-
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor)
857+
if output_type == "latent":
858+
image = latents
859+
else:
860+
latents = latents.to(self.vae.dtype)
861+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
855862
if use_resolution_binning:
856863
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
857-
else:
858-
image = latents
859864

860865
if not output_type == "latent":
861866
image = self.image_processor.postprocess(image, output_type=output_type)

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from transformers import AutoModelForCausalLM, AutoTokenizer
2323

2424
from ...image_processor import PixArtImageProcessor
25-
from ...models import DCAE_HF, SanaTransformer2DModel
25+
from ...models import DCAE, SanaTransformer2DModel
2626
from ...schedulers import FlowDPMSolverMultistepScheduler
2727
from ...utils import (
2828
BACKENDS_MAPPING,
@@ -157,7 +157,7 @@ def __init__(
157157
self,
158158
tokenizer: AutoTokenizer,
159159
text_encoder: AutoModelForCausalLM,
160-
vae: DCAE_HF,
160+
vae: DCAE,
161161
transformer: SanaTransformer2DModel,
162162
scheduler: FlowDPMSolverMultistepScheduler,
163163
):
@@ -793,23 +793,27 @@ def __call__(
793793
noise_pred = noise_pred
794794

795795
# compute previous image: x_t -> x_t-1
796+
latents_dtype = latents.dtype
796797
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
797798

799+
if latents.dtype != latents_dtype:
800+
if torch.backends.mps.is_available():
801+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
802+
latents = latents.to(latents_dtype)
798803
# call the callback, if provided
799804
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
800805
progress_bar.update()
801806
if callback is not None and i % callback_steps == 0:
802807
step_idx = i // getattr(self.scheduler, "order", 1)
803808
callback(step_idx, t, latents)
804809

805-
if not output_type == "latent":
806-
# image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
807-
# Temporary for DCAE_HF(the not ready version)
808-
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor)
810+
if output_type == "latent":
811+
image = latents
812+
else:
813+
latents = latents.to(self.vae.dtype)
814+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
809815
if use_resolution_binning:
810816
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
811-
else:
812-
image = latents
813817

814818
if not output_type == "latent":
815819
image = self.image_processor.postprocess(image, output_type=output_type)

0 commit comments

Comments
 (0)