3 changes: 0 additions & 3 deletions docs/models/hunyuan_video.md
@@ -21,9 +21,6 @@ On Windows, you will have to modify the script to a compatible format to run it.
Assuming your LoRA is saved and pushed to the HF Hub, and named `my-awesome-name/my-awesome-lora`, we can now use the finetuned model for inference:

```py
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
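# A minimal sketch of how the example likely continues, mirroring the
# get_hunyuan_inference template added to finetrainers/utils/hub.py below;
# the adapter name, prompt, and generation settings are illustrative assumptions.
model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.float16)
pipe.load_lora_weights("my-awesome-name/my-awesome-lora", adapter_name="hunyuanvideo-lora")
pipe.set_adapters(["hunyuanvideo-lora"], [0.6])
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
    prompt="A cat walks on the grass, realistic style",
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
```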
139 changes: 131 additions & 8 deletions finetrainers/utils/hub.py
@@ -1,5 +1,5 @@
import os
from typing import List, Union, Callable

import numpy as np
import wandb
@@ -8,12 +8,127 @@
from PIL import Image


# Define inference examples as template functions to allow customization
def get_ltx_inference(pretrained_model_name_or_path, repo_id, validation_prompt):
return f"""
import torch
from diffusers import LTXPipeline
from diffusers.utils import export_to_video

pipe = LTXPipeline.from_pretrained(
"{pretrained_model_name_or_path}", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("{repo_id}", adapter_name="ltxv-lora")
pipe.set_adapters(["ltxv-lora"], [0.75])

video = pipe("{validation_prompt}").frames[0]  # Custom prompt used here
export_to_video(video, "output.mp4", fps=8)
"""


def get_hunyuan_inference(pretrained_model_name_or_path, repo_id, validation_prompt):
return f"""
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video

model_id = "hunyuanvideo-community/HunyuanVideo"
transformer = HunyuanVideoTransformer3DModel.from_pretrained(
model_id, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained("{pretrained_model_name_or_path}", transformer=transformer, torch_dtype=torch.float16)
pipe.load_lora_weights("{repo_id}", adapter_name="hunyuanvideo-lora")
pipe.set_adapters(["hunyuanvideo-lora"], [0.6])
pipe.vae.enable_tiling()
pipe.to("cuda")

output = pipe(
prompt="{validation_prompt}", # Custom prompt used here
height=320,
width=512,
num_frames=61,
num_inference_steps=30,
).frames[0]
export_to_video(output, "output.mp4", fps=15)
"""


def get_wan_inference(pretrained_model_name_or_path, repo_id, validation_prompt):
return f"""
import torch
from diffusers import WanPipeline
from diffusers.utils import export_to_video

pipe = WanPipeline.from_pretrained(
"{pretrained_model_name_or_path}", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("{repo_id}", adapter_name="wan-lora")
pipe.set_adapters(["wan-lora"], [0.75])

video = pipe("{validation_prompt}").frames[0]  # Custom prompt used here
export_to_video(video, "output.mp4", fps=8)
"""


def get_cog_video_inference(pretrained_model_name_or_path, repo_id, validation_prompt):
return f"""
import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained(
"{pretrained_model_name_or_path}", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("{repo_id}", adapter_name="cogvideox-lora")
pipe.set_adapters(["cogvideox-lora"], [0.75])

video = pipe("{validation_prompt}").frames[0]  # Custom prompt used here
export_to_video(video, "output.mp4")
"""


def get_cogview4_inference(pretrained_model_name_or_path, repo_id, validation_prompt):
return f"""
import torch
from diffusers import CogView4Pipeline

pipe = CogView4Pipeline.from_pretrained(
"{pretrained_model_name_or_path}", torch_dtype=torch.bfloat16
).to("cuda")
pipe.load_lora_weights("{repo_id}", adapter_name="cogview4-lora")
pipe.set_adapters(["cogview4-lora"], [0.9])

image = pipe("{validation_prompt}").images[0]  # Custom prompt used here; CogView4 is a text-to-image model
image.save("output.png")
"""


# Maps a pretrained model path to the function that generates its inference example
MODEL_INFERENCE_MAP = {
"Lightricks/LTX-Video": get_ltx_inference,
"hunyuanvideo-community/HunyuanVideo": get_hunyuan_inference,
"Wan-AI/Wan2.1-T2V-1.3B-Diffusers": get_wan_inference,
"THUDM/CogVideoX-5b": get_cog_video_inference,
"THUDM/CogView4-6B": get_cogview4_inference,
}


# Get the appropriate inference example based on the model path
def get_inference_example(model_path, repo_id, validation_prompt):
    if model_path in MODEL_INFERENCE_MAP:
        return MODEL_INFERENCE_MAP[model_path](model_path, repo_id, validation_prompt)
    return ""


def save_model_card(
args,
repo_id: str,
videos: Union[List[str], Union[List[Image.Image], List[np.ndarray]]],
validation_prompts: List[str],
fps: int = 30,
) -> None:
widget_dict = []
output_dir = str(args.output_dir)
@@ -28,6 +143,14 @@ def save_model_card(
}
)

# get the appropriate inference example based on the model path and parameters
validation_prompt = validation_prompts[0] if validation_prompts else "my-awesome-prompt"
inference_example = get_inference_example(
args.pretrained_model_name_or_path,
repo_id,
validation_prompt
)

model_description = f"""
# LoRA Finetune

@@ -48,7 +171,7 @@ def save_model_card(
Requires the [🧨 Diffusers library](https://github.com/huggingface/diffusers) installed.

```py
{inference_example}
```

For more details, including weighting, merging and fusing LoRAs, check the [documentation](https://huggingface.co/docs/diffusers/main/en/using-diffusers/loading_adapters) on loading LoRAs in diffusers.
@@ -74,4 +197,4 @@ def save_model_card(
]

model_card = populate_model_card(model_card, tags=tags)
model_card.save(os.path.join(args.output_dir, "README.md"))
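
The lookup helper is easy to exercise on its own; below is a minimal sketch, assuming `finetrainers` is importable, using the placeholder repo id from the docs above and an illustrative prompt (any model path missing from `MODEL_INFERENCE_MAP` falls back to an empty string):

```py
from finetrainers.utils.hub import get_inference_example

# Known model path: returns the LTX-Video template with the repo id and prompt filled in.
example = get_inference_example(
    "Lightricks/LTX-Video",
    "my-awesome-name/my-awesome-lora",
    "A cat walks on the grass, realistic style",
)
print(example)

# Unknown model path: falls back to an empty string, so the generated model card
# simply renders an empty code block instead of raising an error.
assert get_inference_example("some/unknown-model", "my-awesome-name/my-awesome-lora", "prompt") == ""
```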