CogVideoX 1.5 #9877
Changes from 11 commits
@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "post_attn1_layernorm": "norm2.norm",
     "time_embed.0": "time_embedding.linear_1",
     "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
     "mixins.patch_embed": "patch_embed",
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",

@@ -140,6 +142,7 @@ def convert_transformer(
     use_rotary_positional_embeddings: bool,
     i2v: bool,
     dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."

@@ -149,7 +152,9 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        use_learned_positional_embeddings=i2v,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):

@@ -163,6 +168,7 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)

     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer

@@ -187,6 +193,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae


+def get_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 768 // vae_scale_factor_spatial,
+            "sample_width": 1360 // vae_scale_factor_spatial,
+            "sample_frames": 81,  # TODO: Need Test with 161 for 10 seconds
Suggested change:
-            "sample_frames": 81,  # TODO: Need Test with 161 for 10 seconds
+            "sample_frames": 81,
This value only determines the default number of frames for sampling, so no modification is needed here (changing it would only affect the config.json of the converted transformer model). Users can still specify 161 frames in the pipeline call and generation will remain compatible without any changes here.
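As a user-side illustration (not part of this PR), here is a minimal sketch of overriding the default at call time; the checkpoint id below is an assumption and should be replaced with the actual converted CogVideoX 1.5 repository:

```python
import torch
from diffusers import CogVideoXPipeline

# Hypothetical repo id for a converted CogVideoX 1.5 checkpoint; substitute the real one.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX1.5-5B", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# `sample_frames` in the transformer config only provides the default; passing
# `num_frames` explicitly still lets users generate 161 frames (~10 seconds).
video = pipe(
    prompt="A panda playing a guitar in a bamboo forest",
    num_frames=161,
    num_inference_steps=50,
).frames[0]
```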
@zRzRzRzRzRzRzR This is a bit of a breaking change. The SAT VAE is in fp32, but the diffusers-format VAE is in bf16/fp16. This can lead to poorer quality, so it is best to just keep the VAE in fp32 and let users decide what configuration to use. I will open a PR with the updated VAE weights to the other CogVideoX model weight repositories soon.
cc @yiyixuxu @DN6 The VAE quality doesn't take too much of a hit, but it is best to have the default in FP32 and update all existing checkpoints. Apologies that this slipped through earlier, but I definitely notice very minor differences in quality (at least in training, cc @sayakpaul). The transformer weights don't use variants because there are no FP32 weights, since training is done in BF16.
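To make the intent concrete, a small sketch of requesting the VAE in full precision independently of the rest of the pipeline; the repo id is only an example:

```python
import torch
from diffusers import AutoencoderKLCogVideoX

# Example repo id only; the point is that the VAE can be loaded in fp32
# regardless of the precision used for the transformer or text encoder.
vae = AutoencoderKLCogVideoX.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="vae", torch_dtype=torch.float32
)
print(next(vae.parameters()).dtype)  # torch.float32
```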
Due to the explanation above, we shouldn't typecast all weights in the pipeline to a single dtype. The VAE is best kept in FP32, the text encoder could be saved in FP32 but also works well at lower precisions, and the transformer is in BF16 (or FP16 for CogVideoX-2B text-to-video).
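For example (a sketch, assuming the standard diffusers component layout and an example repo id), the pipeline can be assembled with a different precision per component:

```python
import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel

repo_id = "THUDM/CogVideoX-5b"  # example id; any CogVideoX checkpoint with the same layout

# Per-component precision: VAE in FP32, transformer in BF16 (FP16 for CogVideoX-2B),
# and the remaining components (text encoder, scheduler) loaded with the pipeline default.
vae = AutoencoderKLCogVideoX.from_pretrained(repo_id, subfolder="vae", torch_dtype=torch.float32)
transformer = CogVideoXTransformer3DModel.from_pretrained(
    repo_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = CogVideoXPipeline.from_pretrained(
    repo_id, vae=vae, transformer=transformer, torch_dtype=torch.bfloat16
)
```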
Understood, this is the right thing to do.