Commit de925be

update

1 parent 2dda910 commit de925be

4 files changed: 161 additions, 38 deletions


scripts/convert_cosmos_to_diffusers.py

Lines changed: 119 additions & 27 deletions
@@ -1,11 +1,13 @@
 import argparse
+import pathlib
 from typing import Any, Dict
 
 import torch
 from accelerate import init_empty_weights
+from huggingface_hub import snapshot_download
 from transformers import T5EncoderModel, T5TokenizerFast
 
-from diffusers import CosmosTransformer3DModel, EDMEulerScheduler
+from diffusers import AutoencoderKLCosmos, CosmosTransformer3DModel, EDMEulerScheduler
 
 
 def remove_keys_(key: str, state_dict: Dict[str, Any]):
@@ -63,10 +65,81 @@ def rename_transformer_blocks_(key: str, state_dict: Dict[str, Any]):
 }
 
 VAE_KEYS_RENAME_DICT = {
-    "conv3d": "conv",
+    "down.0": "down_blocks.0",
+    "down.1": "down_blocks.1",
+    "down.2": "down_blocks.2",
+    "up.0": "up_blocks.2",
+    "up.1": "up_blocks.1",
+    "up.2": "up_blocks.0",
+    ".block.": ".resnets.",
+    "downsample": "downsamplers.0",
+    "upsample": "upsamplers.0",
+    "mid.block_1": "mid_block.resnets.0",
+    "mid.attn_1.0": "mid_block.attentions.0",
+    "mid.attn_1.1": "mid_block.temp_attentions.0",
+    "mid.block_2": "mid_block.resnets.1",
+    ".q.conv3d": ".to_q",
+    ".k.conv3d": ".to_k",
+    ".v.conv3d": ".to_v",
+    ".proj_out.conv3d": ".to_out.0",
+    ".0.conv3d": ".conv_s",
+    ".1.conv3d": ".conv_t",
+    "conv1.conv3d": "conv1",
+    "conv2.conv3d": "conv2",
+    "conv3.conv3d": "conv3",
+    "nin_shortcut.conv3d": "conv_shortcut",
+    "quant_conv.conv3d": "quant_conv",
+    "post_quant_conv.conv3d": "post_quant_conv",
 }
 
-VAE_SPECIAL_KEYS_REMAP = {}
+VAE_SPECIAL_KEYS_REMAP = {
+    "wavelets": remove_keys_,
+    "_arange": remove_keys_,
+    "patch_size_buffer": remove_keys_,
+}
+
+VAE_CONFIGS = {
+    "CV8x8x8-0.1": {
+        "name": "nvidia/Cosmos-0.1-Tokenizer-CV8x8x8",
+        "diffusers_config": {
+            "in_channels": 3,
+            "out_channels": 3,
+            "latent_channels": 16,
+            "encoder_block_out_channels": (128, 256, 512, 512),
+            "decode_block_out_channels": (256, 512, 512, 512),
+            "attention_resolutions": (32,),
+            "resolution": 1024,
+            "num_layers": 2,
+            "patch_size": 4,
+            "patch_type": "haar",
+            "scaling_factor": 1.0,
+            "spatial_compression_ratio": 8,
+            "temporal_compression_ratio": 8,
+            "latents_mean": None,
+            "latents_std": None,
+        },
+    },
+    "CV8x8x8-1.0": {
+        "name": "nvidia/Cosmos-1.0-Tokenizer-CV8x8x8",
+        "diffusers_config": {
+            "in_channels": 3,
+            "out_channels": 3,
+            "latent_channels": 16,
+            "encoder_block_out_channels": (128, 256, 512, 512),
+            "decode_block_out_channels": (256, 512, 512, 512),
+            "attention_resolutions": (32,),
+            "resolution": 1024,
+            "num_layers": 2,
+            "patch_size": 4,
+            "patch_type": "haar",
+            "scaling_factor": 1.0,
+            "spatial_compression_ratio": 8,
+            "temporal_compression_ratio": 8,
+            "latents_mean": None,
+            "latents_std": None,
+        },
+    },
+}
 
 
 def get_state_dict(saved_dict: Dict[str, Any]) -> Dict[str, Any]:
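
As a quick illustration (not part of the commit), the rename table is applied by plain substring replacement, so a TorchScript checkpoint key maps onto the diffusers layout as sketched below; the sample key is hypothetical:

    # Hypothetical original key from the TorchScript checkpoint.
    key = "down.0.block.1.conv1.conv3d.weight"
    for old, new in VAE_KEYS_RENAME_DICT.items():
        key = key.replace(old, new)
    # Result: "down_blocks.0.resnets.1.conv1.weight"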
@@ -105,36 +178,53 @@ def convert_transformer(ckpt_path: str):
     return transformer
 
 
-# def convert_vae(ckpt_path: str):
-#     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=True))
+def convert_vae(vae_type: str):
+    model_name = VAE_CONFIGS[vae_type]["name"]
+    snapshot_directory = snapshot_download(model_name, repo_type="model")
+    directory = pathlib.Path(snapshot_directory)
 
-#     with init_empty_weights():
-#         vae = AutoencoderKLHunyuanVideo()
+    autoencoder_file = directory / "autoencoder.jit"
+    mean_std_file = directory / "mean_std.pt"
 
-#     for key in list(original_state_dict.keys()):
-#         new_key = key[:]
-#         for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
-#             new_key = new_key.replace(replace_key, rename_key)
-#         update_state_dict_(original_state_dict, key, new_key)
+    original_state_dict = torch.jit.load(autoencoder_file.as_posix()).state_dict()
+    if mean_std_file.exists():
+        mean_std = torch.load(mean_std_file, map_location="cpu", weights_only=True)
+    else:
+        mean_std = (None, None)
 
-#     for key in list(original_state_dict.keys()):
-#         for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
-#             if special_key not in key:
-#                 continue
-#             handler_fn_inplace(key, original_state_dict)
+    config = VAE_CONFIGS[vae_type]["diffusers_config"]
+    config.update(
+        {
+            "latents_mean": mean_std[0],
+            "latents_std": mean_std[1],
+        }
+    )
+    vae = AutoencoderKLCosmos(**config)
 
-#     vae.load_state_dict(original_state_dict, strict=True, assign=True)
-#     return vae
+    for key in list(original_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        update_state_dict_(original_state_dict, key, new_key)
+
+    for key in list(original_state_dict.keys()):
+        for special_key, handler_fn_inplace in VAE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, original_state_dict)
+
+    vae.load_state_dict(original_state_dict, strict=True, assign=True)
+    return vae
 
 
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "--transformer_ckpt_path", type=str, default=None, help="Path to original transformer checkpoint"
     )
-    parser.add_argument("--vae_ckpt_path", type=str, default=None, help="Path to original VAE checkpoint")
-    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path to original T5 checkpoint")
-    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path to original T5 tokenizer")
+    parser.add_argument("--vae_type", type=str, default=None, choices=list(VAE_CONFIGS.keys()), help="Type of VAE")
+    parser.add_argument("--text_encoder_path", type=str, default=None, help="Path or HF id to original T5 checkpoint")
+    parser.add_argument("--tokenizer_path", type=str, default=None, help="Path or HF id to original T5 tokenizer")
     parser.add_argument("--save_pipeline", action="store_true")
     parser.add_argument("--output_path", type=str, required=True, help="Path where converted model should be saved")
     parser.add_argument("--dtype", default="bf16", help="Torch dtype to save the transformer in.")
@@ -155,7 +245,8 @@ def get_args():
     dtype = DTYPE_MAPPING[args.dtype]
 
     if args.save_pipeline:
-        assert args.transformer_ckpt_path is not None and args.vae_ckpt_path is not None
+        assert args.transformer_ckpt_path is not None
+        assert args.vae_type is not None
         assert args.text_encoder_path is not None
         assert args.tokenizer_path is not None
         assert args.text_encoder_2_path is not None
@@ -166,10 +257,10 @@ def get_args():
     if not args.save_pipeline:
         transformer.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
-    # if args.vae_ckpt_path is not None:
-    #     vae = convert_vae(args.vae_ckpt_path)
-    #     if not args.save_pipeline:
-    #         vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
+    if args.vae_type is not None:
+        vae = convert_vae(args.vae_type)
+        if not args.save_pipeline:
+            vae.save_pretrained(args.output_path, safe_serialization=True, max_shard_size="5GB")
 
     if args.save_pipeline:
         text_encoder = T5EncoderModel.from_pretrained(args.text_encoder_path, torch_dtype=dtype)
@@ -184,6 +275,7 @@ def get_args():
         num_train_timesteps=1000,
         prediction_type="epsilon",
         rho=7.0,
+        final_sigmas_type="sigma_min",
     )
 
     # if args.save_pipeline:
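
For context, final_sigmas_type controls the last value appended to the sigma schedule. A minimal inspection sketch, assuming the flag behaves as in other diffusers schedulers ("sigma_min" ends the schedule at the smallest sigma instead of zero):

    from diffusers import EDMEulerScheduler

    scheduler = EDMEulerScheduler(
        num_train_timesteps=1000,
        prediction_type="epsilon",
        rho=7.0,
        final_sigmas_type="sigma_min",
    )
    scheduler.set_timesteps(num_inference_steps=35)
    # Expected: the trailing sigma is sigma_min rather than 0.0.
    print(scheduler.sigmas[-1])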

src/diffusers/models/autoencoders/autoencoder_kl_cosmos.py

Lines changed: 28 additions & 0 deletions
@@ -853,6 +853,34 @@ class AutoencoderKLCosmos(ModelMixin, ConfigMixin):
             Number of output channels.
         latent_channels (`int`, defaults to `16`):
             Number of latent channels.
+        encoder_block_out_channels (`Tuple[int, ...]`, defaults to `(128, 256, 512, 512)`):
+            Number of output channels for each encoder down block.
+        decode_block_out_channels (`Tuple[int, ...]`, defaults to `(256, 512, 512, 512)`):
+            Number of output channels for each decoder up block.
+        attention_resolutions (`Tuple[int, ...]`, defaults to `(32,)`):
+            List of image/video resolutions at which to apply attention.
+        resolution (`int`, defaults to `1024`):
+            Base image/video resolution used for computing whether a block should have attention layers.
+        num_layers (`int`, defaults to `2`):
+            Number of resnet blocks in each encoder/decoder block.
+        patch_size (`int`, defaults to `4`):
+            Patch size used for patching the input image/video.
+        patch_type (`str`, defaults to `haar`):
+            Patch type used for patching the input image/video. Can be either `haar` or `rearrange`.
+        scaling_factor (`float`, defaults to `1.0`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. Not applicable in Cosmos,
+            but we default to 1.0 for consistency.
+        spatial_compression_ratio (`int`, defaults to `8`):
+            The spatial compression ratio to apply in the VAE. The number of downsample blocks is determined using
+            this.
+        temporal_compression_ratio (`int`, defaults to `8`):
+            The temporal compression ratio to apply in the VAE. The number of downsample blocks is determined using
+            this.
     """
 
     _supports_gradient_checkpointing = True
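
A hypothetical instantiation using the CV8x8x8 configuration from the conversion script above (a sketch, not an official checkpoint-loading path):

    from diffusers import AutoencoderKLCosmos

    # Mirrors VAE_CONFIGS["CV8x8x8-1.0"]["diffusers_config"]; the conversion
    # script fills latents_mean/latents_std from mean_std.pt when available.
    vae = AutoencoderKLCosmos(
        in_channels=3,
        out_channels=3,
        latent_channels=16,
        encoder_block_out_channels=(128, 256, 512, 512),
        decode_block_out_channels=(256, 512, 512, 512),
        attention_resolutions=(32,),
        resolution=1024,
        num_layers=2,
        patch_size=4,
        patch_type="haar",
        scaling_factor=1.0,
        spatial_compression_ratio=8,
        temporal_compression_ratio=8,
    )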

src/diffusers/pipelines/cosmos/pipeline_cosmos.py

Lines changed: 14 additions & 10 deletions
@@ -595,16 +595,20 @@ def __call__(
         self._current_timestep = None
 
         if not output_type == "latent":
-            latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
-            latents_mean = torch.tensor(latents_mean).view(1, self.vae.config.latent_channels, -1, 1, 1)[
-                :, :, : latents.size(2)
-            ]
-            latents_std = torch.tensor(latents_std).view(1, self.vae.config.latent_channels, -1, 1, 1)[
-                :, :, : latents.size(2)
-            ]
-            latents = (
-                latents * self.vae.config.latent_std / self.scheduler.config.sigma_data + self.vae.config.latent_mean
-            )
+            if self.vae.config.latents_mean is not None:
+                latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
+                latents_mean = torch.tensor(latents_mean).view(1, self.vae.config.latent_channels, -1, 1, 1)[
+                    :, :, : latents.size(2)
+                ]
+                latents_std = torch.tensor(latents_std).view(1, self.vae.config.latent_channels, -1, 1, 1)[
+                    :, :, : latents.size(2)
+                ]
+                latents = (
+                    latents * latents_std / self.scheduler.config.sigma_data
+                    + latents_mean
+                )
+            else:
+                latents = latents / self.scheduler.config.sigma_data
             video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]
             video = self.video_processor.postprocess_video(video, output_type=output_type)
         else:
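
In isolation, the denormalization this hunk guards looks like the following sketch (tensors and sigma_data are hypothetical stand-ins for the pipeline's real values):

    import torch

    sigma_data = 0.5                             # illustrative; read from the scheduler config in the pipeline
    latents = torch.randn(1, 16, 8, 4, 4)        # (batch, channels, frames, height, width)
    latents_mean = torch.zeros(1, 16, 1, 1, 1)   # per-channel stats, when the VAE config provides them
    latents_std = torch.ones(1, 16, 1, 1, 1)

    # With stats: undo per-channel normalization, then rescale by sigma_data.
    denormalized = latents * latents_std / sigma_data + latents_mean
    # Without stats (latents_mean is None): only the sigma_data rescaling applies.
    fallback = latents / sigma_data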

tests/models/autoencoders/test_models_autoencoder_cosmos.py

Lines changed: 0 additions & 1 deletion
@@ -1,4 +1,3 @@
-# coding=utf-8
 # Copyright 2024 HuggingFace Inc.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
