Skip to content

Commit ee02219

Browse files
a-r-r-o-w and sayakpaul
authored
readme updates + refactor (#14)
* refactor * update * update * Update README.md Co-authored-by: Sayak Paul <[email protected]> * Update README.md Co-authored-by: Sayak Paul <[email protected]> * address review comments part i * fix adapter name causing peft error * prompts -> prompt * update * add video * update --------- Co-authored-by: Sayak Paul <[email protected]>
1 parent 2a70b16 commit ee02219

11 files changed

+234
-198
lines changed

README.md

Lines changed: 179 additions & 112 deletions
Large diffs are not rendered by default.

assets/CogVideoX-LoRA.webm

472 KB
Binary file not shown.

assets/lora_2b.png

104 KB
Loading

assets/lora_5b.png

107 KB
Loading

assets/sft_2b.png

85.4 KB
Loading

assets/sft_5b.png

87.8 KB
Loading

prepare_dataset.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ NUM_GPUS=8
66

77
# For more details on the expected data format, please refer to the README.
88
DATA_ROOT="/path/to/my/datasets/video-dataset" # This needs to be the path to the base directory where your videos are located.
9-
CAPTION_COLUMN="prompts.txt"
9+
CAPTION_COLUMN="prompt.txt"
1010
VIDEO_COLUMN="videos.txt"
1111
OUTPUT_DIR="/path/to/my/datasets/preprocessed-dataset"
1212
HEIGHT=480

training/cogvideox_image_to_video_lora.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ def save_model_card(
107107
from diffusers.utils import export_to_video, load_image
108108
109109
pipe = CogVideoXImageToVideoPipeline.from_pretrained("THUDM/CogVideoX-5b-I2V", torch_dtype=torch.bfloat16).to("cuda")
110-
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
110+
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
111111
112112
# The LoRA adapter weights are determined by what was used for training.
113113
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
@@ -465,36 +465,25 @@ def load_model_hook(models, input_dir):
465465
)
466466

467467
# Dataset and DataLoader
468-
if not args.video_reshape_mode:
469-
train_dataset = VideoDatasetWithResizing(
470-
data_root=args.data_root,
471-
dataset_file=args.dataset_file,
472-
caption_column=args.caption_column,
473-
video_column=args.video_column,
474-
max_num_frames=args.max_num_frames,
475-
id_token=args.id_token,
476-
height_buckets=args.height_buckets,
477-
width_buckets=args.width_buckets,
478-
frame_buckets=args.frame_buckets,
479-
load_tensors=args.load_tensors,
480-
random_flip=args.random_flip,
481-
image_to_video=True,
482-
)
468+
dataset_init_kwargs = {
469+
"data_root": args.data_root,
470+
"dataset_file": args.dataset_file,
471+
"caption_column": args.caption_column,
472+
"video_column": args.video_column,
473+
"max_num_frames": args.max_num_frames,
474+
"id_token": args.id_token,
475+
"height_buckets": args.height_buckets,
476+
"width_buckets": args.width_buckets,
477+
"frame_buckets": args.frame_buckets,
478+
"load_tensors": args.load_tensors,
479+
"random_flip": args.random_flip,
480+
"image_to_video": True,
481+
}
482+
if args.video_reshape_mode is None:
483+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
483484
else:
484485
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
485-
video_reshape_mode=args.video_reshape_mode,
486-
data_root=args.data_root,
487-
dataset_file=args.dataset_file,
488-
caption_column=args.caption_column,
489-
video_column=args.video_column,
490-
max_num_frames=args.max_num_frames,
491-
id_token=args.id_token,
492-
height_buckets=args.height_buckets,
493-
width_buckets=args.width_buckets,
494-
frame_buckets=args.frame_buckets,
495-
load_tensors=args.load_tensors,
496-
random_flip=args.random_flip,
497-
image_to_video=True,
486+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
498487
)
499488

500489
def collate_fn(data):

training/cogvideox_text_to_video_lora.py

Lines changed: 18 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -103,10 +103,10 @@ def save_model_card(
103103
```py
104104
import torch
105105
from diffusers import CogVideoXPipeline
106-
from diffusers import export_to_video
106+
from diffusers.utils import export_to_video
107107
108108
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")
109-
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name=["cogvideox-lora"])
109+
pipe.load_lora_weights("{repo_id}", weight_name="pytorch_lora_weights.safetensors", adapter_name="cogvideox-lora")
110110
111111
# The LoRA adapter weights are determined by what was used for training.
112112
# In this case, we assume `--lora_alpha` is 32 and `--rank` is 64.
@@ -462,34 +462,24 @@ def load_model_hook(models, input_dir):
462462
)
463463

464464
# Dataset and DataLoader
465-
if not args.video_reshape_mode:
466-
train_dataset = VideoDatasetWithResizing(
467-
data_root=args.data_root,
468-
dataset_file=args.dataset_file,
469-
caption_column=args.caption_column,
470-
video_column=args.video_column,
471-
max_num_frames=args.max_num_frames,
472-
id_token=args.id_token,
473-
height_buckets=args.height_buckets,
474-
width_buckets=args.width_buckets,
475-
frame_buckets=args.frame_buckets,
476-
load_tensors=args.load_tensors,
477-
random_flip=args.random_flip,
478-
)
465+
dataset_init_kwargs = {
466+
"data_root": args.data_root,
467+
"dataset_file": args.dataset_file,
468+
"caption_column": args.caption_column,
469+
"video_column": args.video_column,
470+
"max_num_frames": args.max_num_frames,
471+
"id_token": args.id_token,
472+
"height_buckets": args.height_buckets,
473+
"width_buckets": args.width_buckets,
474+
"frame_buckets": args.frame_buckets,
475+
"load_tensors": args.load_tensors,
476+
"random_flip": args.random_flip,
477+
}
478+
if args.video_reshape_mode is None:
479+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
479480
else:
480481
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
481-
video_reshape_mode=args.video_reshape_mode,
482-
data_root=args.data_root,
483-
dataset_file=args.dataset_file,
484-
caption_column=args.caption_column,
485-
video_column=args.video_column,
486-
max_num_frames=args.max_num_frames,
487-
id_token=args.id_token,
488-
height_buckets=args.height_buckets,
489-
width_buckets=args.width_buckets,
490-
frame_buckets=args.frame_buckets,
491-
load_tensors=args.load_tensors,
492-
random_flip=args.random_flip,
482+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
493483
)
494484

495485
def collate_fn(data):

training/cogvideox_text_to_video_sft.py

Lines changed: 16 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -426,34 +426,24 @@ def load_model_hook(models, input_dir):
426426
)
427427

428428
# Dataset and DataLoader
429-
if not args.video_reshape_mode:
430-
train_dataset = VideoDatasetWithResizing(
431-
data_root=args.data_root,
432-
dataset_file=args.dataset_file,
433-
caption_column=args.caption_column,
434-
video_column=args.video_column,
435-
max_num_frames=args.max_num_frames,
436-
id_token=args.id_token,
437-
height_buckets=args.height_buckets,
438-
width_buckets=args.width_buckets,
439-
frame_buckets=args.frame_buckets,
440-
load_tensors=args.load_tensors,
441-
random_flip=args.random_flip,
442-
)
429+
dataset_init_kwargs = {
430+
"data_root": args.data_root,
431+
"dataset_file": args.dataset_file,
432+
"caption_column": args.caption_column,
433+
"video_column": args.video_column,
434+
"max_num_frames": args.max_num_frames,
435+
"id_token": args.id_token,
436+
"height_buckets": args.height_buckets,
437+
"width_buckets": args.width_buckets,
438+
"frame_buckets": args.frame_buckets,
439+
"load_tensors": args.load_tensors,
440+
"random_flip": args.random_flip,
441+
}
442+
if args.video_reshape_mode is None:
443+
train_dataset = VideoDatasetWithResizing(**dataset_init_kwargs)
443444
else:
444445
train_dataset = VideoDatasetWithResizeAndRectangleCrop(
445-
video_reshape_mode=args.video_reshape_mode,
446-
data_root=args.data_root,
447-
dataset_file=args.dataset_file,
448-
caption_column=args.caption_column,
449-
video_column=args.video_column,
450-
max_num_frames=args.max_num_frames,
451-
id_token=args.id_token,
452-
height_buckets=args.height_buckets,
453-
width_buckets=args.width_buckets,
454-
frame_buckets=args.frame_buckets,
455-
load_tensors=args.load_tensors,
456-
random_flip=args.random_flip,
446+
video_reshape_mode=args.video_reshape_mode, **dataset_init_kwargs
457447
)
458448

459449
def collate_fn(data):

0 commit comments

Comments (0)