Skip to content

Commit 4513375

Browse files
add push to HuggingFace Hub capability (#28)
* add `push_to_hub` support --------- Co-authored-by: linoytsaban <linoy@huggingface.co>
1 parent 144769c commit 4513375

File tree

6 files changed

+326
-6
lines changed

6 files changed

+326
-6
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -302,6 +302,27 @@ The trainer loads your configuration, initializes models, applies optimizations,
302302
For LoRA training, the weights will be saved as `lora_weights.safetensors` in your output directory.
303303
For full model fine-tuning, the weights will be saved as `model_weights.safetensors`.
304304

305+
### 🤗 Pushing Models to Hugging Face Hub
306+
307+
You can automatically push your trained models to the Hugging Face Hub by adding the following to your configuration YAML:
308+
309+
```yaml
310+
hub:
311+
push_to_hub: true
312+
hub_model_id: "your-username/your-model-name" # Your HF username and desired repo name
313+
```
314+
315+
Before pushing, make sure you:
316+
1. Have a Hugging Face account
317+
2. Are logged in via `huggingface-cli login` or have set the `HF_TOKEN` environment variable (the legacy `HUGGING_FACE_HUB_TOKEN` name is also recognized)
318+
3. Have write access to the specified repository (it will be created if it doesn't exist)
319+
320+
The trainer will:
321+
- Create a model card with training details and sample outputs
322+
- Upload the model weights (both original and ComfyUI-compatible versions)
323+
- Convert sample videos to GIFs and embed them in the model card
324+
- Include training configuration and prompts
325+
305326
---
306327

307328
## Fast and simple: Running the Complete Pipeline as one command

src/ltxv_trainer/config.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from pathlib import Path
22
from typing import Literal
33

4-
from pydantic import BaseModel, ConfigDict, Field, field_validator
4+
from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator
55

66
from ltxv_trainer.model_loader import LtxvModelVersion
77
from ltxv_trainer.quantization import QuantizationOptions
@@ -246,6 +246,22 @@ class CheckpointsConfig(ConfigBaseModel):
246246
)
247247

248248

249+
class HubConfig(ConfigBaseModel):
    """Configuration for Hugging Face Hub integration"""

    # Opt-in flag; pushing is disabled by default.
    push_to_hub: bool = Field(
        default=False,
        description="Whether to push the model weights to the Hugging Face Hub",
    )
    # Target repository; only required when `push_to_hub` is enabled (enforced below).
    hub_model_id: str | None = Field(
        default=None,
        description="Hugging Face Hub repository ID (e.g., 'username/repo-name')",
    )

    @model_validator(mode="after")
    def validate_hub_config(self) -> "HubConfig":
        """Reject configurations that enable pushing without naming a repository.

        Returns:
            The validated model instance, unchanged.

        Raises:
            ValueError: If `push_to_hub` is set while `hub_model_id` is None or empty.
        """
        repo_missing = not self.hub_model_id
        if repo_missing and self.push_to_hub:
            raise ValueError("hub_model_id must be specified when push_to_hub is True")
        return self
263+
264+
249265
class FlowMatchingConfig(ConfigBaseModel):
250266
"""Configuration for flow matching training"""
251267

@@ -271,6 +287,7 @@ class LtxvTrainerConfig(ConfigBaseModel):
271287
data: DataConfig = Field(default_factory=DataConfig)
272288
validation: ValidationConfig = Field(default_factory=ValidationConfig)
273289
checkpoints: CheckpointsConfig = Field(default_factory=CheckpointsConfig)
290+
hub: HubConfig = Field(default_factory=HubConfig)
274291
flow_matching: FlowMatchingConfig = Field(default_factory=FlowMatchingConfig)
275292

276293
# General configuration

src/ltxv_trainer/hub_utils.py

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
import tempfile
2+
from pathlib import Path
3+
from typing import List, Union
4+
5+
import imageio
6+
from huggingface_hub import HfApi, create_repo
7+
from loguru import logger
8+
9+
from ltxv_trainer.config import LtxvTrainerConfig
10+
from ltxv_trainer.model_loader import try_parse_version
11+
from scripts.convert_checkpoint import convert_checkpoint
12+
13+
14+
def convert_video_to_gif(video_path: Path, output_path: Path) -> None:
    """Convert a video file to an infinitely looping GIF.

    The conversion is best-effort: any failure is logged and swallowed, so the
    function always returns ``None`` and never raises.

    Args:
        video_path: Video file to read frames from.
        output_path: Destination path of the GIF.
    """
    try:
        # Context managers release the reader and writer even when decoding or
        # encoding fails part-way through; previously both handles leaked on error.
        with imageio.get_reader(str(video_path)) as reader:
            fps = reader.get_meta_data()["fps"]

            with imageio.get_writer(
                str(output_path),
                fps=min(fps, 15),  # Cap FPS at 15 for reasonable file size
                loop=0,  # 0 means infinite loop
            ) as writer:
                for frame in reader:
                    writer.append_data(frame)
    except Exception as e:
        logger.error(f"Failed to convert video to GIF: {e}")
    return None
36+
37+
38+
def create_model_card(
    output_dir: Union[str, Path],
    videos: List[Path] | None,
    config: LtxvTrainerConfig,
) -> Path | None:
    """Generate a model card (``README.md``) and sample GIFs in ``output_dir``.

    The card is rendered from ``templates/model_card.md`` (resolved three
    directories above this file). Each validation prompt is paired with the
    corresponding sample video; existing videos are converted to GIFs under
    ``output_dir / "samples"`` and laid out in a Markdown grid of at most
    four columns.

    Args:
        output_dir: Directory that receives ``README.md`` and ``samples/``.
        videos: Sample videos, in the same order as the validation prompts.
            May be ``None`` or empty, in which case no samples are embedded.
        config: Trainer configuration supplying the repo id, base model,
            validation prompts and optimization settings.

    Returns:
        Path of the written model card, or ``None`` when the template file is
        missing and no card was generated.
    """
    repo_id = config.hub.hub_model_id
    pretrained_model_name_or_path = config.model.model_source
    validation_prompts = config.validation.prompts
    output_dir = Path(output_dir)
    template_path = Path(__file__).parent.parent.parent / "templates" / "model_card.md"

    if not template_path.exists():
        # There is no built-in fallback template, so say what actually happens.
        logger.warning("⚠️ Model card template not found, skipping model card generation")
        return None

    # Explicit UTF-8: the card contains emoji and must not depend on the platform default encoding
    template = template_path.read_text(encoding="utf-8")

    # Get model name from repo_id
    model_name = repo_id.split("/")[-1]

    # Get base model information
    version = try_parse_version(pretrained_model_name_or_path)
    if version:
        base_model_link = version.safetensors_url
        base_model_name = str(version)
    else:
        base_model_link = f"https://huggingface.co/{pretrained_model_name_or_path}"
        base_model_name = pretrained_model_name_or_path

    # Format validation prompts and create grid layout
    prompts_text = ""
    sample_grid = []

    if validation_prompts and videos:
        prompts_text = "Example prompts used during validation:\n\n"

        # Create samples directory
        samples_dir = output_dir / "samples"
        samples_dir.mkdir(exist_ok=True, parents=True)

        # Process videos and create cells
        cells = []
        for i, (prompt, video) in enumerate(zip(validation_prompts, videos, strict=False)):
            if not video.exists():
                continue

            # Add prompt to text section
            prompts_text += f"- `{prompt}`\n"

            # Convert video to GIF
            gif_path = samples_dir / f"sample_{i}.gif"
            try:
                convert_video_to_gif(video, gif_path)
            except Exception as e:
                logger.error(f"Failed to process video {video}: {e}")
                continue

            # The converter logs and swallows its own errors, so the only reliable
            # success signal is the output file; skip the cell rather than emit a
            # broken image link.
            if not gif_path.exists():
                continue

            # A raw "|" or newline would terminate the Markdown table cell/row early
            cell_prompt = " ".join(str(prompt).split()).replace("|", "\\|")

            # Create grid cell with collapsible description
            cells.append(
                f"![example{i + 1}](./samples/sample_{i}.gif)"
                "<br>"
                '<details style="max-width: 300px; margin: auto;">'
                "<summary>Prompt</summary>"
                f"{cell_prompt}"
                "</details>"
            )

        # Lay the cells out in rows of at most 4 columns
        num_cells = len(cells)
        if num_cells > 0:
            num_cols = min(4, num_cells)
            for start_idx in range(0, num_cells, num_cols):
                row_cells = cells[start_idx : start_idx + num_cols]
                sample_grid.append("| " + " | ".join(row_cells) + " |")

    # Join grid rows with just the content, no headers needed
    grid_text = "\n".join(sample_grid) if sample_grid else ""

    # Fill in the template
    model_card_content = template.format(
        base_model=base_model_name,
        base_model_link=base_model_link,
        model_name=model_name,
        training_type="LoRA fine-tuning" if config.model.training_mode == "lora" else "Full model fine-tuning",
        training_steps=config.optimization.steps,
        learning_rate=config.optimization.learning_rate,
        batch_size=config.optimization.batch_size,
        validation_prompts=prompts_text,
        sample_grid=grid_text,
    )

    # Save the model card directly
    model_card_path = output_dir / "README.md"
    model_card_path.write_text(model_card_content, encoding="utf-8")

    return model_card_path
143+
144+
145+
def push_to_hub(weights_path: Path, sampled_videos_paths: List[Path] | None, config: LtxvTrainerConfig) -> None:
    """Push trained weights, a model card and sample GIFs to the Hugging Face Hub.

    Steps, each best-effort (failures are logged, never raised):
      1. Create the target repository if it does not exist (abort on failure).
      2. Upload the original weights file.
      3. Generate the model card plus sample GIFs in a temporary directory and
         upload that folder.
      4. Convert the weights to ComfyUI format and upload them as
         ``comfyui_<weights file name>``.

    Args:
        weights_path: Path of the saved weights file to upload.
        sampled_videos_paths: Validation sample videos to embed in the model
            card; may be ``None`` when no samples were generated.
        config: Trainer configuration; ``config.hub`` controls whether and
            where to push.
    """
    if not config.hub.push_to_hub:
        return

    repo_id = config.hub.hub_model_id
    if not repo_id:
        logger.warning("⚠️ HuggingFace hub_model_id not specified, skipping push to hub")
        return

    api = HfApi()

    # Try to create repo if it doesn't exist
    try:
        create_repo(
            repo_id=repo_id,
            repo_type="model",
            exist_ok=True,  # Don't raise error if repo exists
        )
    except Exception as e:
        logger.error(f"❌ Failed to create repository: {e}")
        return

    # Upload the original weights file
    try:
        api.upload_file(
            path_or_fileobj=str(weights_path),
            path_in_repo=weights_path.name,
            repo_id=repo_id,
            repo_type="model",
        )
    except Exception as e:
        logger.error(f"❌ Failed to push {weights_path.name} to HuggingFace Hub: {e}")
    else:
        # Report success only when the upload really happened (this used to be
        # logged unconditionally, even right after a failure).
        logger.info(f"✅ Successfully pushed original LoRA weights to {repo_id}")

    # Create a temporary directory for the files we want to upload
    with tempfile.TemporaryDirectory() as temp_dir:
        temp_path = Path(temp_dir)

        try:
            # Save model card and sample GIFs to temp directory
            create_model_card(
                output_dir=temp_path,
                videos=sampled_videos_paths,
                config=config,
            )

            # create_model_card writes nothing when its template is missing;
            # don't upload an empty folder and claim success.
            if any(temp_path.iterdir()):
                api.upload_folder(
                    folder_path=str(temp_path),  # Convert to string for compatibility
                    repo_id=repo_id,
                    repo_type="model",
                )
                logger.info(f"✅ Successfully uploaded model card and sample videos to {repo_id}")
            else:
                logger.warning("⚠️ No model card was generated, skipping model card upload")
        except Exception as e:
            logger.error(f"❌ Failed to save/upload model card and videos: {e}")

    # Convert and upload ComfyUI version
    try:
        # Create a temporary directory for the converted file
        with tempfile.TemporaryDirectory() as temp_dir:
            # Convert the weights to ComfyUI format
            comfy_path = Path(temp_dir) / f"{weights_path.stem}_comfy{weights_path.suffix}"

            convert_checkpoint(
                input_path=str(weights_path),
                to_comfy=True,
                output_path=str(comfy_path),
            )

            # Find the converted file
            converted_files = list(Path(temp_dir).glob("*.safetensors"))
            if not converted_files:
                logger.warning("⚠️ No converted ComfyUI weights found")
                return

            converted_file = converted_files[0]
            comfy_filename = f"comfyui_{weights_path.name}"

            # Upload the converted file
            api.upload_file(
                path_or_fileobj=str(converted_file),
                path_in_repo=comfy_filename,
                repo_id=repo_id,
                repo_type="model",
            )
            logger.info(f"✅ Successfully pushed ComfyUI LoRA weights to {repo_id}")

    except Exception as e:
        logger.error(f"❌ Failed to convert and push ComfyUI version: {e}")

src/ltxv_trainer/model_loader.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ def load_vae(
160160
"""
161161
if isinstance(source, str): # noqa: SIM102
162162
# Try to parse as version first
163-
if version := _try_parse_version(source):
163+
if version := try_parse_version(source):
164164
source = version
165165

166166
if isinstance(source, LtxvModelVersion):
@@ -217,7 +217,7 @@ def load_transformer(
217217
"""
218218
if isinstance(source, str): # noqa: SIM102
219219
# Try to parse as version first
220-
if version := _try_parse_version(source):
220+
if version := try_parse_version(source):
221221
source = version
222222

223223
if isinstance(source, LtxvModelVersion):
@@ -285,7 +285,7 @@ def load_ltxv_components(
285285
)
286286

287287

288-
def _try_parse_version(source: str | Path) -> LtxvModelVersion | None:
288+
def try_parse_version(source: str | Path) -> LtxvModelVersion | None:
289289
"""
290290
Try to parse a string as an LtxvModelVersion.
291291

src/ltxv_trainer/trainer.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646

4747
from ltxv_trainer.config import LtxvTrainerConfig
4848
from ltxv_trainer.datasets import PrecomputedDataset
49+
from ltxv_trainer.hub_utils import push_to_hub
4950
from ltxv_trainer.model_loader import load_ltxv_components
5051
from ltxv_trainer.quantization import quantize_model
5152
from ltxv_trainer.timestep_samplers import SAMPLERS
@@ -155,6 +156,8 @@ def train( # noqa: PLR0912, PLR0915
155156
# Track when actual training starts (after compilation)
156157
actual_training_start = None
157158

159+
sampled_videos_paths = None
160+
158161
with Live(Panel(Group(train_progress, sample_progress)), refresh_per_second=2):
159162
task = train_progress.add_task(
160163
"Training",
@@ -165,7 +168,7 @@ def train( # noqa: PLR0912, PLR0915
165168
)
166169

167170
if cfg.validation.interval:
168-
self._sample_videos(sample_progress)
171+
sampled_videos_paths = self._sample_videos(sample_progress)
169172

170173
for step in range(cfg.optimization.steps):
171174
# Get next batch, reset the dataloader if needed
@@ -202,7 +205,6 @@ def train( # noqa: PLR0912, PLR0915
202205

203206
if self._lr_scheduler is not None:
204207
self._lr_scheduler.step()
205-
206208
# Run validation if needed
207209
if (
208210
cfg.validation.interval
@@ -291,6 +293,10 @@ def train( # noqa: PLR0912, PLR0915
291293
if self._accelerator.is_main_process:
292294
saved_path = self._save_checkpoint()
293295

296+
# Upload artifacts to hub if enabled
297+
if cfg.hub.push_to_hub:
298+
push_to_hub(saved_path, sampled_videos_paths, self._config)
299+
294300
# Log the training statistics
295301
self._log_training_stats(stats)
296302

0 commit comments

Comments
 (0)