|
| 1 | +""" |
| 2 | +Run this test in Lora adpater checking: |
| 3 | +
|
| 4 | +```shell |
| 5 | +python3 test_lora_inference.py --prompt "A girl is ridding a bike." --model_path "THUDM/CogVideoX-5B" --lora_path "path/to/lora" --lora_name "lora_adapter" --output_file "output.mp4" --fps 8 |
| 6 | +``` |
| 7 | +
|
| 8 | +""" |
| 9 | + |
| 10 | +import argparse |
| 11 | +import torch |
| 12 | +from diffusers import CogVideoXPipeline |
| 13 | +from diffusers.utils import export_to_video |
| 14 | + |
| 15 | + |
| 16 | +def generate_video(model_path, prompt, lora_path, lora_name, output_file, fps): |
| 17 | + pipe = CogVideoXPipeline.from_pretrained(model_path, torch_dtype=torch.bfloat16).to("cuda") |
| 18 | + pipe.load_lora_weights(lora_path, weight_name="pytorch_lora_weights.safetensors", adapter_name=lora_name) |
| 19 | + pipe.set_adapters([lora_name], [1.0]) |
| 20 | + pipe.enable_model_cpu_offload() |
| 21 | + pipe.vae.enable_slicing() |
| 22 | + pipe.vae.enable_tiling() |
| 23 | + |
| 24 | + video = pipe(prompt=prompt).frames[0] |
| 25 | + export_to_video(video, output_file, fps=fps) |
| 26 | + |
| 27 | + |
| 28 | +def main(): |
| 29 | + parser = argparse.ArgumentParser(description="Generate video using CogVideoX and LoRA weights") |
| 30 | + parser.add_argument("--prompt", type=str, required=True, help="Text prompt for the video generation") |
| 31 | + parser.add_argument("--model_path", type=str, default="THUDM/CogVideoX-5B", help="Base Model path or HF ID") |
| 32 | + parser.add_argument("--lora_path", type=str, required=True, help="Path to the LoRA weights") |
| 33 | + parser.add_argument("--lora_name", type=str, default="lora_adapter", help="Name of the LoRA adapter") |
| 34 | + parser.add_argument("--output_file", type=str, default="output.mp4", help="Output video file name") |
| 35 | + parser.add_argument("--fps", type=int, default=8, help="Frames per second for the output video") |
| 36 | + |
| 37 | + args = parser.parse_args() |
| 38 | + |
| 39 | + generate_video(args.prompt, args.lora_path, args.lora_name, args.output_file, args.fps) |
| 40 | + |
| 41 | + |
| 42 | +if __name__ == "__main__": |
| 43 | + main() |
0 commit comments