Skip to content

Commit abee1ee

Browse files
committed
update Sana for DC-AE's recent commit;
1 parent d3fd40a commit abee1ee

File tree

6 files changed

+45
-66
lines changed

6 files changed

+45
-66
lines changed

scripts/convert_sana_pag_to_diffusers.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from accelerate import init_empty_weights
1010
from diffusers import (
1111
DCAE,
12-
DCAE_HF,
13-
FlowDPMSolverMultistepScheduler,
12+
DPMSolverMultistepScheduler,
1413
FlowMatchEulerDiscreteScheduler,
1514
SanaPAGPipeline,
1615
SanaTransformer2DModel,
@@ -186,27 +185,10 @@ def main(args):
186185
else:
187186
print(colored(f"Saving the whole SanaPAGPipeline containing {args.model_type}", "green", attrs=["bold"]))
188187
# VAE
189-
dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f32c32-sana-1.0")
190-
dc_ae_state_dict = dc_ae.state_dict()
191-
dc_ae = DCAE(
192-
in_channels=3,
193-
latent_channels=32,
194-
encoder_width_list=[128, 256, 512, 512, 1024, 1024],
195-
encoder_depth_list=[2, 2, 2, 3, 3, 3],
196-
encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
197-
encoder_norm="rms2d",
198-
encoder_act="silu",
199-
downsample_block_type="Conv",
200-
decoder_width_list=[128, 256, 512, 512, 1024, 1024],
201-
decoder_depth_list=[3, 3, 3, 3, 3, 3],
202-
decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
203-
decoder_norm="rms2d",
204-
decoder_act="silu",
205-
upsample_block_type="InterpolateConv",
206-
scaling_factor=0.41407,
207-
)
208-
dc_ae.load_state_dict(dc_ae_state_dict, strict=True)
209-
dc_ae.to(torch.float32).to(device)
188+
dc_ae = DCAE.from_pretrained(
189+
"Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers",
190+
torch_dtype=torch.float32,
191+
).to(device)
210192

211193
# Text Encoder
212194
text_encoder_model_path = "google/gemma-2-2b-it"
@@ -220,7 +202,11 @@ def main(args):
220202

221203
# Scheduler
222204
if args.scheduler_type == "flow-dpm_solver":
223-
scheduler = FlowDPMSolverMultistepScheduler(flow_shift=flow_shift)
205+
scheduler = DPMSolverMultistepScheduler(
206+
flow_shift=flow_shift,
207+
use_flow_sigmas=True,
208+
prediction_type="flow_prediction",
209+
)
224210
elif args.scheduler_type == "flow-euler":
225211
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
226212
else:

scripts/convert_sana_to_diffusers.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@
99
from accelerate import init_empty_weights
1010
from diffusers import (
1111
DCAE,
12-
DCAE_HF,
13-
FlowDPMSolverMultistepScheduler,
12+
DPMSolverMultistepScheduler,
1413
FlowMatchEulerDiscreteScheduler,
1514
SanaPipeline,
1615
SanaTransformer2DModel,
@@ -186,27 +185,10 @@ def main(args):
186185
else:
187186
print(colored(f"Saving the whole SanaPipeline containing {args.model_type}", "green", attrs=["bold"]))
188187
# VAE
189-
dc_ae = DCAE_HF.from_pretrained(f"mit-han-lab/dc-ae-f32c32-sana-1.0")
190-
dc_ae_state_dict = dc_ae.state_dict()
191-
dc_ae = DCAE(
192-
in_channels=3,
193-
latent_channels=32,
194-
encoder_width_list=[128, 256, 512, 512, 1024, 1024],
195-
encoder_depth_list=[2, 2, 2, 3, 3, 3],
196-
encoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
197-
encoder_norm="rms2d",
198-
encoder_act="silu",
199-
downsample_block_type="Conv",
200-
decoder_width_list=[128, 256, 512, 512, 1024, 1024],
201-
decoder_depth_list=[3, 3, 3, 3, 3, 3],
202-
decoder_block_type=["ResBlock", "ResBlock", "ResBlock", "EViTS5_GLU", "EViTS5_GLU", "EViTS5_GLU"],
203-
decoder_norm="rms2d",
204-
decoder_act="silu",
205-
upsample_block_type="InterpolateConv",
206-
scaling_factor=0.41407,
207-
)
208-
dc_ae.load_state_dict(dc_ae_state_dict, strict=True)
209-
dc_ae.to(torch.float32).to(device)
188+
dc_ae = DCAE.from_pretrained(
189+
"Efficient-Large-Model/dc_ae_f32c32_sana_1.0_diffusers",
190+
torch_dtype=torch.float32,
191+
).to(device)
210192

211193
# Text Encoder
212194
text_encoder_model_path = "google/gemma-2-2b-it"
@@ -220,7 +202,11 @@ def main(args):
220202

221203
# Scheduler
222204
if args.scheduler_type == "flow-dpm_solver":
223-
scheduler = FlowDPMSolverMultistepScheduler(flow_shift=flow_shift)
205+
scheduler = DPMSolverMultistepScheduler(
206+
flow_shift=flow_shift,
207+
use_flow_sigmas=True,
208+
prediction_type="flow_prediction",
209+
)
224210
elif args.scheduler_type == "flow-euler":
225211
scheduler = FlowMatchEulerDiscreteScheduler(shift=flow_shift)
226212
else:

src/diffusers/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,6 @@
131131
"UVit2DModel",
132132
"VQModel",
133133
"DCAE",
134-
"DCAE_HF",
135134
]
136135
)
137136
_import_structure["optimization"] = [
@@ -577,7 +576,6 @@
577576
else:
578577
from .models import (
579578
DCAE,
580-
DCAE_HF,
581579
AllegroTransformer3DModel,
582580
AsymmetricAutoencoderKL,
583581
AuraFlowTransformer2DModel,

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,5 @@
88
from .autoencoder_oobleck import AutoencoderOobleck
99
from .autoencoder_tiny import AutoencoderTiny
1010
from .consistency_decoder_vae import ConsistencyDecoderVAE
11-
from .dc_ae import DCAE, DCAE_HF
11+
from .autoencoder_dc import DCAE
1212
from .vq_model import VQModel

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from transformers import AutoModelForCausalLM, AutoTokenizer
2323

2424
from ...image_processor import PixArtImageProcessor
25-
from ...models import DCAE_HF, SanaTransformer2DModel
25+
from ...models import DCAE, SanaTransformer2DModel
2626
from ...models.attention_processor import PAGCFGSanaLinearAttnProcessor2_0, PAGIdentitySanaLinearAttnProcessor2_0
2727
from ...schedulers import FlowDPMSolverMultistepScheduler
2828
from ...utils import (
@@ -162,7 +162,7 @@ def __init__(
162162
self,
163163
tokenizer: AutoTokenizer,
164164
text_encoder: AutoModelForCausalLM,
165-
vae: DCAE_HF,
165+
vae: DCAE,
166166
transformer: SanaTransformer2DModel,
167167
scheduler: FlowDPMSolverMultistepScheduler,
168168
pag_applied_layers: Union[str, List[str]] = "blocks.1", # 1st transformer block
@@ -840,22 +840,27 @@ def __call__(
840840
noise_pred = noise_pred
841841

842842
# compute previous image: x_t -> x_t-1
843+
latents_dtype = latents.dtype
843844
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
844845

846+
if latents.dtype != latents_dtype:
847+
if torch.backends.mps.is_available():
848+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
849+
latents = latents.to(latents_dtype)
845850
# call the callback, if provided
846851
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
847852
progress_bar.update()
848853
if callback is not None and i % callback_steps == 0:
849854
step_idx = i // getattr(self.scheduler, "order", 1)
850855
callback(step_idx, t, latents)
851-
# set to None for next
852856

853-
if not output_type == "latent":
854-
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor)
857+
if output_type == "latent":
858+
image = latents
859+
else:
860+
latents = latents.to(self.vae.dtype)
861+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
855862
if use_resolution_binning:
856863
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
857-
else:
858-
image = latents
859864

860865
if not output_type == "latent":
861866
image = self.image_processor.postprocess(image, output_type=output_type)

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from transformers import AutoModelForCausalLM, AutoTokenizer
2323

2424
from ...image_processor import PixArtImageProcessor
25-
from ...models import DCAE_HF, SanaTransformer2DModel
25+
from ...models import DCAE, SanaTransformer2DModel
2626
from ...schedulers import FlowDPMSolverMultistepScheduler
2727
from ...utils import (
2828
BACKENDS_MAPPING,
@@ -157,7 +157,7 @@ def __init__(
157157
self,
158158
tokenizer: AutoTokenizer,
159159
text_encoder: AutoModelForCausalLM,
160-
vae: DCAE_HF,
160+
vae: DCAE,
161161
transformer: SanaTransformer2DModel,
162162
scheduler: FlowDPMSolverMultistepScheduler,
163163
):
@@ -793,23 +793,27 @@ def __call__(
793793
noise_pred = noise_pred
794794

795795
# compute previous image: x_t -> x_t-1
796+
latents_dtype = latents.dtype
796797
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
797798

799+
if latents.dtype != latents_dtype:
800+
if torch.backends.mps.is_available():
801+
# some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
802+
latents = latents.to(latents_dtype)
798803
# call the callback, if provided
799804
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
800805
progress_bar.update()
801806
if callback is not None and i % callback_steps == 0:
802807
step_idx = i // getattr(self.scheduler, "order", 1)
803808
callback(step_idx, t, latents)
804809

805-
if not output_type == "latent":
806-
# image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
807-
# Temporary for DCAE_HF(the not ready version)
808-
image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor)
810+
if output_type == "latent":
811+
image = latents
812+
else:
813+
latents = latents.to(self.vae.dtype)
814+
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
809815
if use_resolution_binning:
810816
image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
811-
else:
812-
image = latents
813817

814818
if not output_type == "latent":
815819
image = self.image_processor.postprocess(image, output_type=output_type)

0 commit comments

Comments
 (0)