Skip to content

Commit cbbac06

Browse files
zRzRzRzRzRzRzRa-r-r-o-wsayakpaul
authored
add some script of lora test (#66)
* multi resolutions support * full chinese readme * Update README.md * Update README.md Co-authored-by: Sayak Paul <[email protected]> * reformat from pycharm * dataset.md * Update README.md * for merge * mergeing * torch update for use * add test lora script * Update test_lora_inference.py * Update requirements.txt * Update requirements.txt --------- Co-authored-by: Aryan <[email protected]> Co-authored-by: Sayak Paul <[email protected]>
1 parent bd40627 commit cbbac06

File tree

5 files changed

+57
-4
lines changed

5 files changed

+57
-4
lines changed

README.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ video = pipe("<my-awesome-prompt>").frames[0]
5353
export_to_video(video, "output.mp4", fps=8)
5454
```
5555

56+
You can also check if your LoRA is correctly mounted [here](tests/test_lora_inference.py).
57+
5658
**Note:** For Image-to-Video finetuning, you must install diffusers from [this](https://github.com/huggingface/diffusers/pull/9482) branch (which adds lora loading support in CogVideoX image-to-video) until it is merged.
5759

5860
Below we provide additional sections detailing on more options explored in this repository. They all attempt to make fine-tuning for video models as accessible as possible by reducing memory requirements as much as possible.

README_zh.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ video = pipe("<my-awesome-prompt>").frames[0]
5151
export_to_video(video, "output.mp4", fps=8)
5252
```
5353

54+
你也可以在[这里](tests/test_lora_inference.py)来检查你的Lora是否正常挂载。
55+
5456
**注意:** 对于图像到视频的微调,您必须从 [这个分支](https://github.com/huggingface/diffusers/pull/9482) 安装
5557
diffusers(该分支为 CogVideoX 的图像到视频添加了 LoRA 加载支持)直到它被合并。
5658

requirements.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@ peft>=0.12.0
88
decord>=0.6.0
99
wandb
1010
pandas
11-
torch>=2.4.0
12-
torchvision>=0.19.0
11+
torch<2.5.0
12+
torchvision<0.20.0
1313
torchao>=0.5.0
1414
sentencepiece>=0.2.0
1515
imageio-ffmpeg>=0.5.1

tests/test_lora_inference.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""
2+
Run this test in Lora adpater checking:
3+
4+
```shell
5+
python3 test_lora_inference.py --prompt "A girl is ridding a bike." --model_path "THUDM/CogVideoX-5B" --lora_path "path/to/lora" --lora_name "lora_adapter" --output_file "output.mp4" --fps 8
6+
```
7+
8+
"""
9+
10+
import argparse
11+
import torch
12+
from diffusers import CogVideoXPipeline
13+
from diffusers.utils import export_to_video
14+
15+
16+
def generate_video(model_path, prompt, lora_path, lora_name, output_file, fps):
17+
pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda")
18+
pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=lora_name)
19+
pipe.set_adapters([lora_name], [1.0])
20+
pipe.enable_model_cpu_offload()
21+
pipe.vae.enable_slicing()
22+
pipe.vae.enable_tiling()
23+
24+
video = pipe(prompt=prompt).frames[0]
25+
export_to_video(video, output_file, fps=fps)
26+
27+
28+
def main():
29+
parser = argparse.ArgumentParser(description="Generate video using CogVideoX and LoRA weights")
30+
parser.add_argument("--prompt", type=str, required=True, help="Text prompt for the video generation")
31+
parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5B", help="Base Model path or HF ID")
32+
parser.add_argument("--lora_path", type=str, required=True, help="Path to the LoRA weights")
33+
parser.add_argument("--lora_name", type=str, default="lora_adapter", help="Name of the LoRA adapter")
34+
parser.add_argument("--output_file", type=str, default="output.mp4", help="Output video file name")
35+
parser.add_argument("--fps", type=int, default=8, help="Frames per second for the output video")
36+
37+
args = parser.parse_args()
38+
39+
generate_video(args.prompt, args.lora_path, args.lora_name, args.output_file, args.fps)
40+
41+
42+
if __name__ == "__main__":
43+
main()

train_text_to_video_lora.sh

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,22 @@ ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
2020
# This example assumes you downloaded an already prepared dataset from HF CLI as follows:
2121
# huggingface-cli download --repo-type dataset Wild-Heart/Disney-VideoGeneration-Dataset --local-dir /path/to/my/datasets/disney-dataset
2222
DATA_ROOT="/path/to/my/datasets/disney-dataset"
23+
2324
CAPTION_COLUMN="prompt.txt"
2425
VIDEO_COLUMN="videos.txt"
26+
MODEL_PATH="THUDM/CogVideoX-5b"
2527

28+
# Set ` --load_tensors ` to load tensors from disk instead of recomputing the encoder process.
2629
# Launch experiments with different hyperparameters
30+
2731
for learning_rate in "${LEARNING_RATES[@]}"; do
2832
for lr_schedule in "${LR_SCHEDULES[@]}"; do
2933
for optimizer in "${OPTIMIZERS[@]}"; do
3034
for steps in "${MAX_TRAIN_STEPS[@]}"; do
31-
output_dir="/path/to/my/models/cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
35+
output_dir="./cogvideox-lora__optimizer_${optimizer}__steps_${steps}__lr-schedule_${lr_schedule}__learning-rate_${learning_rate}/"
3236

3337
cmd="accelerate launch --config_file $ACCELERATE_CONFIG_FILE --gpu_ids $GPU_IDS training/cogvideox_text_to_video_lora.py \
34-
--pretrained_model_name_or_path THUDM/CogVideoX-5b \
38+
--pretrained_model_name_or_path $MODEL_PATH \
3539
--data_root $DATA_ROOT \
3640
--caption_column $CAPTION_COLUMN \
3741
--video_column $VIDEO_COLUMN \
@@ -62,6 +66,8 @@ for learning_rate in "${LEARNING_RATES[@]}"; do
6266
--lr_num_cycles 1 \
6367
--enable_slicing \
6468
--enable_tiling \
69+
--enable_model_cpu_offload \
70+
--load_tensors \
6571
--optimizer $optimizer \
6672
--beta1 0.9 \
6773
--beta2 0.95 \

0 commit comments

Comments
 (0)