CogVideoX 1.5 #9877
Diff in the CogVideoX VAE (`AutoencoderKLCogVideoX.__init__`): a new `invert_scale_latents` config flag is added.

@@ -1057,6 +1057,7 @@ def __init__(
         force_upcast: float = True,
         use_quant_conv: bool = False,
         use_post_quant_conv: bool = False,
+        invert_scale_latents: bool = False,
     ):
         super().__init__()

Inline review discussion on `invert_scale_latents`:

- I don't think this change is needed - we can just adjust when we scale it, e.g. instead of
- For the original CogVideoX 1.0 models, we need to do
- This part is necessary because there is indeed a difference in execution between 1.5 and 1.0 here. The training process for 1.5 forgot to use *scale_factor.
- Why can't we just update the scale_factor here? For example, if it is 5, we just use 1/5 instead.
- We can't invert it because we use the
- We cannot update the
- So, one way or the other, we will need to determine when to multiply by the scale factor and when to divide. The encode method needs to support both (multiply for 1.0, divide for Cog 1.5). In the decode method, we just need to support divide (already exists). Even if we invert the scale factor during conversion of the VAE, we still need to be able to determine which version of the model is running in order for the encode method to work as expected. As of now, we can do this with a check like
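To make the scaling discussion concrete, here is a minimal sketch of the encode-side behaviour described above. The helper name is invented for illustration; only the `invert_scale_latents` flag and the multiply/divide asymmetry come from the diff and the comments.

def scale_encoded_latents(latents, vae):
    # Hypothetical helper. `vae.config.scaling_factor` is the usual diffusers
    # VAE config entry; `invert_scale_latents` is the flag added in this diff.
    scaling_factor = vae.config.scaling_factor
    if not vae.config.invert_scale_latents:
        # CogVideoX 1.0: encoded latents are multiplied by the scaling factor
        return latents * scaling_factor
    # CogVideoX 1.5: training skipped that multiplication, so divide instead
    return latents / scaling_factor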
Diff in `src/diffusers/models/embeddings.py`: `CogVideoXPatchEmbed` learns an optional temporal patch size, and `get_3d_rotary_pos_embed` gains a `grid_type` option.

@@ -338,6 +338,7 @@ class CogVideoXPatchEmbed(nn.Module):
     def __init__(
         self,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         in_channels: int = 16,
         embed_dim: int = 1920,
         text_embed_dim: int = 4096,

@@ -355,6 +356,7 @@ def __init__(
         super().__init__()

         self.patch_size = patch_size
+        self.patch_size_t = patch_size_t
         self.embed_dim = embed_dim
         self.sample_height = sample_height
         self.sample_width = sample_width

@@ -366,9 +368,15 @@ def __init__(
         self.use_positional_embeddings = use_positional_embeddings
         self.use_learned_positional_embeddings = use_learned_positional_embeddings

-        self.proj = nn.Conv2d(
-            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
-        )
+        if patch_size_t is None:
+            # CogVideoX 1.0 checkpoints
+            self.proj = nn.Conv2d(
+                in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+            )
+        else:
+            # CogVideoX 1.5 checkpoints
+            self.proj = nn.Linear(in_channels * patch_size * patch_size * patch_size_t, embed_dim)

         self.text_proj = nn.Linear(text_embed_dim, embed_dim)

         if use_positional_embeddings or use_learned_positional_embeddings:
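For reference, a quick sketch of why the 1.5 projection is an `nn.Linear` over flattened 3D patches: its input size is the number of values in one `patch_size_t x patch_size x patch_size` patch across all latent channels. The sizes below are illustrative only and not taken from the PR.

import torch
import torch.nn as nn

# Illustrative sizes: 16 latent channels, spatial patch 2, temporal patch 2,
# hidden size 3072.
in_channels, p, p_t, embed_dim = 16, 2, 2, 3072

proj = nn.Linear(in_channels * p * p * p_t, embed_dim)
one_patch = torch.randn(1, in_channels * p * p * p_t)  # one flattened 3D patch
print(proj(one_patch).shape)  # torch.Size([1, 3072])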
@@ -407,12 +415,24 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         """
         text_embeds = self.text_proj(text_embeds)

-        batch, num_frames, channels, height, width = image_embeds.shape
-        image_embeds = image_embeds.reshape(-1, channels, height, width)
-        image_embeds = self.proj(image_embeds)
-        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
-        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
-        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        batch_size, num_frames, channels, height, width = image_embeds.shape
+
+        if self.patch_size_t is None:
+            image_embeds = image_embeds.reshape(-1, channels, height, width)
+            image_embeds = self.proj(image_embeds)
+            image_embeds = image_embeds.view(batch_size, num_frames, *image_embeds.shape[1:])
+            image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+            image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+        else:
+            p = self.patch_size
+            p_t = self.patch_size_t
+
+            image_embeds = image_embeds.permute(0, 1, 3, 4, 2)
+            image_embeds = image_embeds.reshape(
+                batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels
+            )
+            image_embeds = image_embeds.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
+            image_embeds = self.proj(image_embeds)

         embeds = torch.cat(
             [text_embeds, image_embeds], dim=1
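As a sanity check on the 1.5 branch above, here is a standalone walk-through of the permute/reshape sequence with made-up sizes; the shape comments are what the code produces.

import torch

# Made-up sizes: 4 latent frames, 16 channels, a 16x16 latent grid,
# spatial patch p=2, temporal patch p_t=2.
batch_size, num_frames, channels, height, width = 1, 4, 16, 16, 16
p, p_t = 2, 2

x = torch.randn(batch_size, num_frames, channels, height, width)
x = x.permute(0, 1, 3, 4, 2)  # [B, F, H, W, C]
x = x.reshape(batch_size, num_frames // p_t, p_t, height // p, p, width // p, p, channels)
x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).flatten(4, 7).flatten(1, 3)
print(x.shape)  # [1, 128, 128]: (4//2)*(16//2)*(16//2) = 128 patch tokens,
                # each holding 16*2*2*2 = 128 values, ready for the nn.Linear proj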
@@ -497,7 +517,14 @@ def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tens


 def get_3d_rotary_pos_embed(
-    embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
+    embed_dim,
+    crops_coords,
+    grid_size,
+    temporal_size,
+    theta: int = 10000,
+    use_real: bool = True,
+    grid_type: str = "linspace",
+    max_size: Optional[Tuple[int, int]] = None,
 ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
     """
     RoPE for video tokens with 3D structure.

Inline review comment here: cc @yiyixuxu
@@ -513,17 +540,30 @@ def get_3d_rotary_pos_embed(
             The size of the temporal dimension.
         theta (`float`):
             Scaling factor for frequency computation.
+        grid_type (`str`):
+            Whether to use "linspace" or "slice" to compute grids.

     Returns:
         `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim/2)`.
     """
     if use_real is not True:
         raise ValueError(" `use_real = False` is not currently supported for get_3d_rotary_pos_embed")
-    start, stop = crops_coords
-    grid_size_h, grid_size_w = grid_size
-    grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
-    grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
-    grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+
+    if grid_type == "linspace":
+        start, stop = crops_coords
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
+        grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+        grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
+    elif grid_type == "slice":
+        max_h, max_w = max_size
+        grid_size_h, grid_size_w = grid_size
+        grid_h = np.arange(max_h, dtype=np.float32)
+        grid_w = np.arange(max_w, dtype=np.float32)
+        grid_t = np.arange(temporal_size, dtype=np.float32)
+    else:
+        raise ValueError("Invalid value passed for `grid_type`.")

     # Compute dimensions for each axis
     dim_t = embed_dim // 4
@@ -559,6 +599,12 @@ def combine_time_height_width(freqs_t, freqs_h, freqs_w):
     t_cos, t_sin = freqs_t  # both t_cos and t_sin has shape: temporal_size, dim_t
     h_cos, h_sin = freqs_h  # both h_cos and h_sin has shape: grid_size_h, dim_h
     w_cos, w_sin = freqs_w  # both w_cos and w_sin has shape: grid_size_w, dim_w
+
+    if grid_type == "slice":
+        t_cos, t_sin = t_cos[:temporal_size], t_sin[:temporal_size]
+        h_cos, h_sin = h_cos[:grid_size_h], h_sin[:grid_size_h]
+        w_cos, w_sin = w_cos[:grid_size_w], w_sin[:grid_size_w]
+
     cos = combine_time_height_width(t_cos, h_cos, w_cos)
     sin = combine_time_height_width(t_sin, h_sin, w_sin)
     return cos, sin
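A usage sketch for the new grid_type="slice" path, written against the signature as it appears in this diff. The sizes are illustrative; `crops_coords` is not read by the slice branch, so `None` is passed.

from diffusers.models.embeddings import get_3d_rotary_pos_embed

# Illustrative sizes: attention head dim 64, a 30x45 patch grid, 13 latent
# frames; max_size is the base grid the "slice" branch builds and then trims.
cos, sin = get_3d_rotary_pos_embed(
    embed_dim=64,
    crops_coords=None,      # unused by the "slice" branch
    grid_size=(30, 45),     # current height/width in patches
    temporal_size=13,
    grid_type="slice",
    max_size=(30, 45),      # base height/width in patches
)
print(cos.shape, sin.shape)  # both: [13 * 30 * 45, 64] = [17550, 64]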
Review comments on dtype handling:

- Due to the explanation above, we shouldn't typecast all weights in the pipeline. The VAE is best kept in FP32, the text encoder could be saved in FP32 but works well at lower precisions as well, and the transformer is either in BF16, or FP16 for CogVideoX-2B text-to-video.
- Understood, this is the right thing to do.
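A hedged sketch of what per-component precision can look like when assembling the pipeline, following the comment above; the repo id is only an example.

import torch
from diffusers import AutoencoderKLCogVideoX, CogVideoXPipeline, CogVideoXTransformer3DModel

model_id = "THUDM/CogVideoX-5b"  # example repo id

# Keep the VAE in FP32 and run the transformer in BF16; the text encoder loads
# at its default precision here and can be cast lower if memory is tight.
vae = AutoencoderKLCogVideoX.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
transformer = CogVideoXTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = CogVideoXPipeline.from_pretrained(model_id, vae=vae, transformer=transformer)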