CogVideoX 1.5 #9877
Changes from 40 commits
Changes to the CogVideoX conversion script:

```diff
@@ -80,6 +80,8 @@ def replace_up_keys_inplace(key: str, state_dict: Dict[str, Any]):
     "post_attn1_layernorm": "norm2.norm",
     "time_embed.0": "time_embedding.linear_1",
     "time_embed.2": "time_embedding.linear_2",
+    "ofs_embed.0": "ofs_embedding.linear_1",
+    "ofs_embed.2": "ofs_embedding.linear_2",
     "mixins.patch_embed": "patch_embed",
     "mixins.final_layer.norm_final": "norm_out.norm",
     "mixins.final_layer.linear": "proj_out",
```
```diff
@@ -140,6 +142,7 @@ def convert_transformer(
     use_rotary_positional_embeddings: bool,
     i2v: bool,
     dtype: torch.dtype,
+    init_kwargs: Dict[str, Any],
 ):
     PREFIX_KEY = "model.diffusion_model."
```
```diff
@@ -149,7 +152,9 @@ def convert_transformer(
         num_layers=num_layers,
         num_attention_heads=num_attention_heads,
         use_rotary_positional_embeddings=use_rotary_positional_embeddings,
-        use_learned_positional_embeddings=i2v,
+        ofs_embed_dim=512 if (i2v and init_kwargs["patch_size_t"] is not None) else None,  # CogVideoX1.5-5B-I2V
+        use_learned_positional_embeddings=i2v and init_kwargs["patch_size_t"] is None,  # CogVideoX-5B-I2V
+        **init_kwargs,
     ).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
```
```diff
@@ -163,13 +168,18 @@ def convert_transformer(
         if special_key not in key:
             continue
         handler_fn_inplace(key, original_state_dict)

     transformer.load_state_dict(original_state_dict, strict=True)
     return transformer


-def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
+def convert_vae(ckpt_path: str, scaling_factor: float, version: str, dtype: torch.dtype):
+    init_kwargs = {"scaling_factor": scaling_factor}
+    if version == "1.5":
+        init_kwargs.update({"invert_scale_latents": True})
+
     original_state_dict = get_state_dict(torch.load(ckpt_path, map_location="cpu", mmap=True))
-    vae = AutoencoderKLCogVideoX(scaling_factor=scaling_factor).to(dtype=dtype)
+    vae = AutoencoderKLCogVideoX(**init_kwargs).to(dtype=dtype)

     for key in list(original_state_dict.keys()):
         new_key = key[:]
```
```diff
@@ -187,6 +197,34 @@ def convert_vae(ckpt_path: str, scaling_factor: float, dtype: torch.dtype):
     return vae


+def get_transformer_init_kwargs(version: str):
+    if version == "1.0":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": None,
+            "patch_bias": True,
+            "sample_height": 480 // vae_scale_factor_spatial,
+            "sample_width": 720 // vae_scale_factor_spatial,
+            "sample_frames": 49,
+        }
+
+    elif version == "1.5":
+        vae_scale_factor_spatial = 8
+        init_kwargs = {
+            "patch_size": 2,
+            "patch_size_t": 2,
+            "patch_bias": False,
+            "sample_height": 300,
+            "sample_width": 300,
+            "sample_frames": 81,
+        }
+    else:
+        raise ValueError("Unsupported version of CogVideoX.")
+
+    return init_kwargs
+
+
 def get_args():
     parser = argparse.ArgumentParser()
     parser.add_argument(
```
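For context, here is a hedged sketch of how these version-specific kwargs end up in `CogVideoXTransformer3DModel` after this PR. The tiny `num_layers`/`num_attention_heads` values are illustrative only (the real 5B checkpoints use `num_layers=42` and `num_attention_heads=48`, per the argparse comments below), and it assumes the transformer class accepts the new `patch_size_t`/`patch_bias` arguments introduced in this PR:

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# The CogVideoX 1.5 defaults returned by get_transformer_init_kwargs("1.5")
init_kwargs = {
    "patch_size": 2,
    "patch_size_t": 2,
    "patch_bias": False,
    "sample_height": 300,
    "sample_width": 300,
    "sample_frames": 81,
}

# Deliberately tiny model just to show the constructor call; not a real checkpoint config.
transformer = CogVideoXTransformer3DModel(
    num_layers=2,
    num_attention_heads=4,
    attention_head_dim=64,
    use_rotary_positional_embeddings=True,
    **init_kwargs,
).to(dtype=torch.bfloat16)
```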
```diff
@@ -202,6 +240,12 @@ def get_args():
     parser.add_argument(
         "--text_encoder_cache_dir", type=str, default=None, help="Path to text encoder cache directory"
     )
+    parser.add_argument(
+        "--typecast_text_encoder",
+        action="store_true",
+        default=False,
+        help="Whether or not to apply fp16/bf16 precision to text_encoder",
+    )
     # For CogVideoX-2B, num_layers is 30. For 5B, it is 42
     parser.add_argument("--num_layers", type=int, default=30, help="Number of transformer blocks")
     # For CogVideoX-2B, num_attention_heads is 30. For 5B, it is 48
```
```diff
@@ -214,7 +258,18 @@ def get_args():
     parser.add_argument("--scaling_factor", type=float, default=1.15258426, help="Scaling factor in the VAE")
     # For CogVideoX-2B, snr_shift_scale is 3.0. For 5B, it is 1.0
     parser.add_argument("--snr_shift_scale", type=float, default=3.0, help="Scaling factor in the VAE")
-    parser.add_argument("--i2v", action="store_true", default=False, help="Whether to save the model weights in fp16")
+    parser.add_argument(
+        "--i2v",
+        action="store_true",
+        default=False,
+        help="Whether the model to be converted is the Image-to-Video version of CogVideoX.",
+    )
+    parser.add_argument(
+        "--version",
+        choices=["1.0", "1.5"],
+        default="1.0",
+        help="Which version of CogVideoX to use for initializing default modeling parameters.",
+    )
     return parser.parse_args()
```
```diff
@@ -230,21 +285,27 @@ def get_args():
     dtype = torch.float16 if args.fp16 else torch.bfloat16 if args.bf16 else torch.float32

     if args.transformer_ckpt_path is not None:
+        init_kwargs = get_transformer_init_kwargs(args.version)
         transformer = convert_transformer(
             args.transformer_ckpt_path,
             args.num_layers,
             args.num_attention_heads,
             args.use_rotary_positional_embeddings,
             args.i2v,
             dtype,
+            init_kwargs,
         )
     if args.vae_ckpt_path is not None:
-        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, dtype)
+        # Keep VAE in float32 for better quality
+        vae = convert_vae(args.vae_ckpt_path, args.scaling_factor, args.version, torch.float32)

     text_encoder_id = "google/t5-v1_1-xxl"
     tokenizer = T5Tokenizer.from_pretrained(text_encoder_id, model_max_length=TOKENIZER_MAX_LENGTH)
     text_encoder = T5EncoderModel.from_pretrained(text_encoder_id, cache_dir=args.text_encoder_cache_dir)

+    if args.typecast_text_encoder:
+        text_encoder = text_encoder.to(dtype=dtype)
+
     # Apparently, the conversion does not work anymore without this :shrug:
     for param in text_encoder.parameters():
         param.data = param.data.contiguous()
```
```diff
@@ -276,11 +337,6 @@ def get_args():
         scheduler=scheduler,
     )

-    if args.fp16:
-        pipe = pipe.to(dtype=torch.float16)
-    if args.bf16:
-        pipe = pipe.to(dtype=torch.bfloat16)

     # We don't use variant here because the model must be run in fp16 (2B) or bf16 (5B). It would be weird
     # for users to specify variant when the default is not fp32 and they want to run with the correct default (which
     # is either fp16/bf16 here).
```

Review comments on this change:

- Due to the explanation above, we shouldn't typecast all weights in the pipeline. The VAE is best in FP32, the text encoder could be saved in FP32 but works well at lower precisions as well, and the transformer is either in BF16, or in FP16 for CogVideoX-2B text-to-video.
- Understood, this is the right thing to do.
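As a usage-side note on the precision discussion above (not part of this PR's diff): a converted checkpoint can be loaded with the transformer in bf16 while only the VAE is upcast back to fp32. The `THUDM/CogVideoX-5b` model id is just an example:

```python
import torch
from diffusers import CogVideoXPipeline

# Example checkpoint id; substitute the converted model you want to run.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Per-component precision instead of a blanket `pipe.to(dtype=...)`:
# the transformer and text encoder stay in bf16, the VAE runs in fp32 for quality.
pipe.vae.to(torch.float32)
```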
Changes to `AutoencoderKLCogVideoX`:

```diff
@@ -1057,6 +1057,7 @@ def __init__(
         force_upcast: float = True,
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        invert_scale_latents: bool = False,
     ):
         super().__init__()
```

Review thread on `invert_scale_latents`:

- I don't think this change is needed; we can just adjust when we scale it, e.g. instead of …
- For the original CogVideoX 1.0 models, we need to do …
- This part is necessary because there is indeed a difference in execution between 1.5 and 1.0 here: the training process for 1.5 forgot to use `* scale_factor`.
- Why can't we just update the `scale_factor` here? For example, if it is 5, we just use 1/5 instead.
- We can't invert it because we use the …
- We cannot update the …
- So, one way or the other, we will need to determine when to multiply by the scale factor and when to divide. The `encode` method needs to support both (multiply for 1.0, divide for CogVideoX 1.5). In the `decode` method, we just need to support divide (which already exists). Even if we invert the scale factor during conversion of the VAE, we still need to be able to determine which version of the model is running in order for the `encode` method to work as expected. As of now, we can do this with a check like …
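To make the thread above concrete, here is a small illustrative helper (a sketch, not the actual diffusers implementation; `scale_latents` is a made-up name) showing how encode-side code can branch on `invert_scale_latents` to choose between multiplying (1.0) and dividing (1.5) by the scaling factor:

```python
import torch

def scale_latents(latents: torch.Tensor, scaling_factor: float, invert_scale_latents: bool) -> torch.Tensor:
    if invert_scale_latents:
        # CogVideoX 1.5: training skipped the `* scaling_factor` step, so divide instead.
        return latents / scaling_factor
    # CogVideoX 1.0: multiply by the scaling factor after encoding, as usual.
    return latents * scaling_factor
```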
Changes to the CogVideoX embedding utilities (`CogVideoXPatchEmbed` and `get_3d_rotary_pos_embed`):

```diff
@@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,
```
```diff
@@ -355,6 +356,7 @@ def __init__(
         super().__init__()

         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width
```
```diff
@@ -366,9 +368,15 @@ def __init__(
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings

-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
-        )
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

         self.text_proj = nn.Linear(text_embed_dim, embed_dim)

         if use_positional_embeddings or use_learned_positional_embeddings:
```
```diff
@@ -407,12 +415,24 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         """
         text_embeds = self.text_proj(text_embeds)

-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.reshape(-1, channels, height, width)
-        image_embeds = self.proj(image_embeds)
-        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
-        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
-        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(
+                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+            )
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)

         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
```
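As a sanity check on the new `patch_size_t` branch, this standalone sketch runs the same permute/reshape sequence on dummy tensors; all shape values and the `embed_dim` are illustrative, not tied to a particular checkpoint:

```python
import torch

batch_size, num_frames, channels, height, width = 1, 4, 16, 60, 90
p, p_t, embed_dim = 2, 2, 3072
proj = torch.nn.Linear(channels * p * p * p_t, embed_dim)

x = torch.randn(batch_size, num_frames, channels, height, width)
x = x.permute(0, 1, 3, 4, 2)  # [B, F, H, W, C]
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)  # [B, (F/p_t)(H/p)(W/p), C*p_t*p*p]
tokens = proj(x)
print(tokens.shape)  # torch.Size([1, 2700, 3072]); (4/2) * (60/2) * (90/2) = 2700 tokens
```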
```diff
@@ -497,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens
 def get_3d_rotary_pos_embed(
-    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.
```

Review comment on `grid_type`: cc @yiyixuxu
```diff
@@ -513,17 +540,30 @@ def get_3d_rotary_pos_embed(
             The size of the temporal dimension.
         theta (`float`):
             Scaling factor for frequency computation.
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.

     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
     if use_real is not True:
         raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
-    start, stop = crops_coords
-    grid_size_h, grid_size_w = grid_size
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.arange(max_h, dtype=np.float32)
+        grid_w = np.arange(max_w, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")

     # Compute dimensions for each axis
     dim_t = embed_dim // 4
```
```diff
@@ -559,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
     t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
     h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
     w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]

     cos = combine_time_height_width(t_cos, h_cos, w_cos)
     sin = combine_time_height_width(t_sin, h_sin, w_sin)
     return cos, sin
```
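For reference, a hedged usage sketch of the two `grid_type` modes; the argument values are illustrative, and the import path assumes the function lives in `diffusers.models.embeddings` as in this PR:

```python
from diffusers.models.embeddings import get_3d_rotary_pos_embed

# "slice" builds frequencies over the full (max_h, max_w) grid and slices them down,
# while the default "linspace" interpolates frequencies over the crop coordinates.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=((0, 0), (30, 45)),  # not used by the "slice" branch
    grid_size=(30, 45),
    temporal_size=13,
    grid_type="slice",
    max_size=(45, 90),
)
print(cos.shape, sin.shape)  # each is (13 * 30 * 45, 64) = (17550, 64)
```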