Skip to content

Commit e7b820b

Browse files
committed
trainer: save training config YAML in output dir.
1 parent 2c6db04 commit e7b820b

File tree

1 file changed

+56
-41
lines changed

1 file changed

+56
-41
lines changed

src/ltxv_trainer/trainer.py

Lines changed: 56 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import rich
1212
import torch
1313
import wandb
14+
import yaml
1415
from accelerate import Accelerator
1516
from accelerate.utils import set_seed
1617
from diffusers.utils import export_to_video
@@ -109,50 +110,10 @@ def __init__(self, trainer_config: LtxvTrainerConfig) -> None:
109110
self._init_wandb()
110111
self._training_strategy = get_training_strategy(self._config.conditioning)
111112

112-
def _init_wandb(self) -> None:
113-
"""Initialize Weights & Biases run."""
114-
if not self._config.wandb.enabled or not IS_MAIN_PROCESS:
115-
self._wandb_run = None
116-
return
117-
118-
wandb_config = self._config.wandb
119-
run = wandb.init(
120-
project=wandb_config.project,
121-
entity=wandb_config.entity,
122-
name=Path(self._config.output_dir).name,
123-
tags=wandb_config.tags,
124-
config=self._config.model_dump(),
125-
)
126-
self._wandb_run = run
127-
128-
def _log_metrics(self, metrics: dict[str, float]) -> None:
129-
"""Log metrics to Weights & Biases."""
130-
if self._wandb_run is not None:
131-
self._wandb_run.log(metrics)
132-
133-
def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None:
134-
"""Log validation videos to Weights & Biases."""
135-
if not self._config.wandb.log_validation_videos or self._wandb_run is None:
136-
return
137-
138-
# Create lists of videos with their captions
139-
validation_videos = [
140-
wandb.Video(str(video_path), caption=prompt)
141-
for video_path, prompt in zip(video_paths, prompts, strict=False)
142-
]
143-
144-
# Log all videos at once
145-
self._wandb_run.log(
146-
{
147-
"validation_videos": validation_videos,
148-
},
149-
step=self._global_step,
150-
)
151-
152113
def train( # noqa: PLR0912, PLR0915
153114
self,
154115
disable_progress_bars: bool = False,
155-
step_callback: StepCallback = None,
116+
step_callback: StepCallback | None = None,
156117
) -> tuple[Path, TrainingStats]:
157118
"""
158119
Start the training process.
@@ -182,6 +143,9 @@ def train( # noqa: PLR0912, PLR0915
182143

183144
Path(cfg.output_dir).mkdir(parents=True, exist_ok=True)
184145

146+
# Save the training configuration as YAML
147+
self._save_config()
148+
185149
logger.info("🚀 Starting training...")
186150

187151
# Create progress columns with simplified styling
@@ -890,3 +854,54 @@ def _cleanup_checkpoints(self) -> None:
890854
logger.debug(f"Removed old checkpoints: {old_checkpoint}")
891855
# Update the list to only contain kept checkpoints
892856
self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :]
857+
858+
def _save_config(self) -> None:
859+
"""Save the training configuration as a YAML file in the output directory."""
860+
if not IS_MAIN_PROCESS:
861+
return
862+
863+
config_path = Path(self._config.output_dir) / "training_config.yaml"
864+
with open(config_path, "w") as f:
865+
yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2)
866+
867+
logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}")
868+
869+
def _init_wandb(self) -> None:
870+
"""Initialize Weights & Biases run."""
871+
if not self._config.wandb.enabled or not IS_MAIN_PROCESS:
872+
self._wandb_run = None
873+
return
874+
875+
wandb_config = self._config.wandb
876+
run = wandb.init(
877+
project=wandb_config.project,
878+
entity=wandb_config.entity,
879+
name=Path(self._config.output_dir).name,
880+
tags=wandb_config.tags,
881+
config=self._config.model_dump(),
882+
)
883+
self._wandb_run = run
884+
885+
def _log_metrics(self, metrics: dict[str, float]) -> None:
886+
"""Log metrics to Weights & Biases."""
887+
if self._wandb_run is not None:
888+
self._wandb_run.log(metrics)
889+
890+
def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None:
891+
"""Log validation videos to Weights & Biases."""
892+
if not self._config.wandb.log_validation_videos or self._wandb_run is None:
893+
return
894+
895+
# Create lists of videos with their captions
896+
validation_videos = [
897+
wandb.Video(str(video_path), caption=prompt, format="mp4")
898+
for video_path, prompt in zip(video_paths, prompts, strict=False)
899+
]
900+
901+
# Log all videos at once
902+
self._wandb_run.log(
903+
{
904+
"validation_videos": validation_videos,
905+
},
906+
step=self._global_step,
907+
)

0 commit comments

Comments
 (0)