Skip to content

Commit 0b80dba

Browse files
adaptation for CogVideoX1.5 (#92)
* adaptation for CogVideoX1.5 * add patch_size_t in full finetuning of T2V and lora finetuning of I2V * Update training/args.py Co-authored-by: Sayak Paul <[email protected]> --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent d63a826 commit 0b80dba

File tree

5 files changed

+26
-7
lines changed

5 files changed

+26
-7
lines changed

training/args.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ def _get_dataset_args(parser: argparse.ArgumentParser) -> None:
7878
nargs="+",
7979
type=int,
8080
default=[49],
81+
help="CogVideoX1.5 needs to guarantee that ((num_frames - 1) // self.vae_scale_factor_temporal + 1) % patch_size_t == 0, e.g. 53"
8182
)
8283
parser.add_argument(
8384
"--load_tensors",

training/cogvideox_image_to_video_lora.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ def load_model_hook(models, input_dir):
787787
num_frames=num_frames,
788788
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
789789
patch_size=model_config.patch_size,
790+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
790791
attention_head_dim=model_config.attention_head_dim,
791792
device=accelerator.device,
792793
)

training/cogvideox_text_to_video_lora.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -696,6 +696,7 @@ def load_model_hook(models, input_dir):
696696
num_frames=num_frames,
697697
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
698698
patch_size=model_config.patch_size,
699+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
699700
attention_head_dim=model_config.attention_head_dim,
700701
device=accelerator.device,
701702
)

training/cogvideox_text_to_video_sft.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ def load_model_hook(models, input_dir):
662662
num_frames=num_frames,
663663
vae_scale_factor_spatial=VAE_SCALE_FACTOR_SPATIAL,
664664
patch_size=model_config.patch_size,
665+
patch_size_t=model_config.patch_size_t if hasattr(model_config, "patch_size_t") else None,
665666
attention_head_dim=model_config.attention_head_dim,
666667
device=accelerator.device,
667668
)

training/utils.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def prepare_rotary_positional_embeddings(
198198
num_frames: int,
199199
vae_scale_factor_spatial: int = 8,
200200
patch_size: int = 2,
201+
patch_size_t: int = None,
201202
attention_head_dim: int = 64,
202203
device: Optional[torch.device] = None,
203204
base_height: int = 480,
@@ -207,14 +208,28 @@ def prepare_rotary_positional_embeddings(
207208
grid_width = width // (vae_scale_factor_spatial * patch_size)
208209
base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
209210
base_size_height = base_height // (vae_scale_factor_spatial * patch_size)
211+
if patch_size_t is None:
212+
# CogVideoX 1.0
213+
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
214+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
215+
embed_dim=attention_head_dim,
216+
crops_coords=grid_crops_coords,
217+
grid_size=(grid_height, grid_width),
218+
temporal_size=num_frames,
219+
)
220+
else:
221+
# CogVideoX 1.5
222+
base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
223+
224+
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
225+
embed_dim=attention_head_dim,
226+
crops_coords=None,
227+
grid_size=(grid_height, grid_width),
228+
temporal_size=base_num_frames,
229+
grid_type="slice",
230+
max_size=(base_size_height, base_size_width),
231+
)
210232

211-
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size_width, base_size_height)
212-
freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
213-
embed_dim=attention_head_dim,
214-
crops_coords=grid_crops_coords,
215-
grid_size=(grid_height, grid_width),
216-
temporal_size=num_frames,
217-
)
218233

219234
freqs_cos = freqs_cos.to(device=device)
220235
freqs_sin = freqs_sin.to(device=device)

0 commit comments

Comments
 (0)