|
11 | 11 | import rich |
12 | 12 | import torch |
13 | 13 | import wandb |
| 14 | +import yaml |
14 | 15 | from accelerate import Accelerator |
15 | 16 | from accelerate.utils import set_seed |
16 | 17 | from diffusers.utils import export_to_video |
@@ -109,50 +110,10 @@ def __init__(self, trainer_config: LtxvTrainerConfig) -> None: |
109 | 110 | self._init_wandb() |
110 | 111 | self._training_strategy = get_training_strategy(self._config.conditioning) |
111 | 112 |
|
112 | | - def _init_wandb(self) -> None: |
113 | | - """Initialize Weights & Biases run.""" |
114 | | - if not self._config.wandb.enabled or not IS_MAIN_PROCESS: |
115 | | - self._wandb_run = None |
116 | | - return |
117 | | - |
118 | | - wandb_config = self._config.wandb |
119 | | - run = wandb.init( |
120 | | - project=wandb_config.project, |
121 | | - entity=wandb_config.entity, |
122 | | - name=Path(self._config.output_dir).name, |
123 | | - tags=wandb_config.tags, |
124 | | - config=self._config.model_dump(), |
125 | | - ) |
126 | | - self._wandb_run = run |
127 | | - |
128 | | - def _log_metrics(self, metrics: dict[str, float]) -> None: |
129 | | - """Log metrics to Weights & Biases.""" |
130 | | - if self._wandb_run is not None: |
131 | | - self._wandb_run.log(metrics) |
132 | | - |
133 | | - def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None: |
134 | | - """Log validation videos to Weights & Biases.""" |
135 | | - if not self._config.wandb.log_validation_videos or self._wandb_run is None: |
136 | | - return |
137 | | - |
138 | | - # Create lists of videos with their captions |
139 | | - validation_videos = [ |
140 | | - wandb.Video(str(video_path), caption=prompt) |
141 | | - for video_path, prompt in zip(video_paths, prompts, strict=False) |
142 | | - ] |
143 | | - |
144 | | - # Log all videos at once |
145 | | - self._wandb_run.log( |
146 | | - { |
147 | | - "validation_videos": validation_videos, |
148 | | - }, |
149 | | - step=self._global_step, |
150 | | - ) |
151 | | - |
152 | 113 | def train( # noqa: PLR0912, PLR0915 |
153 | 114 | self, |
154 | 115 | disable_progress_bars: bool = False, |
155 | | - step_callback: StepCallback = None, |
| 116 | + step_callback: StepCallback | None = None, |
156 | 117 | ) -> tuple[Path, TrainingStats]: |
157 | 118 | """ |
158 | 119 | Start the training process. |
@@ -182,6 +143,9 @@ def train( # noqa: PLR0912, PLR0915 |
182 | 143 |
|
183 | 144 | Path(cfg.output_dir).mkdir(parents=True, exist_ok=True) |
184 | 145 |
|
| 146 | + # Save the training configuration as YAML |
| 147 | + self._save_config() |
| 148 | + |
185 | 149 | logger.info("🚀 Starting training...") |
186 | 150 |
|
187 | 151 | # Create progress columns with simplified styling |
@@ -890,3 +854,54 @@ def _cleanup_checkpoints(self) -> None: |
890 | 854 | logger.debug(f"Removed old checkpoints: {old_checkpoint}") |
891 | 855 | # Update the list to only contain kept checkpoints |
892 | 856 | self._checkpoint_paths = self._checkpoint_paths[-self._config.checkpoints.keep_last_n :] |
| 857 | + |
| 858 | + def _save_config(self) -> None: |
| 859 | + """Save the training configuration as a YAML file in the output directory.""" |
| 860 | + if not IS_MAIN_PROCESS: |
| 861 | + return |
| 862 | + |
| 863 | + config_path = Path(self._config.output_dir) / "training_config.yaml" |
| 864 | + with open(config_path, "w") as f: |
| 865 | + yaml.dump(self._config.model_dump(), f, default_flow_style=False, indent=2) |
| 866 | + |
| 867 | + logger.info(f"💾 Training configuration saved to: {config_path.relative_to(self._config.output_dir)}") |
| 868 | + |
| 869 | + def _init_wandb(self) -> None: |
| 870 | + """Initialize Weights & Biases run.""" |
| 871 | + if not self._config.wandb.enabled or not IS_MAIN_PROCESS: |
| 872 | + self._wandb_run = None |
| 873 | + return |
| 874 | + |
| 875 | + wandb_config = self._config.wandb |
| 876 | + run = wandb.init( |
| 877 | + project=wandb_config.project, |
| 878 | + entity=wandb_config.entity, |
| 879 | + name=Path(self._config.output_dir).name, |
| 880 | + tags=wandb_config.tags, |
| 881 | + config=self._config.model_dump(), |
| 882 | + ) |
| 883 | + self._wandb_run = run |
| 884 | + |
| 885 | + def _log_metrics(self, metrics: dict[str, float]) -> None: |
| 886 | + """Log metrics to Weights & Biases.""" |
| 887 | + if self._wandb_run is not None: |
| 888 | + self._wandb_run.log(metrics) |
| 889 | + |
| 890 | + def _log_validation_videos(self, video_paths: list[Path], prompts: list[str]) -> None: |
| 891 | + """Log validation videos to Weights & Biases.""" |
| 892 | + if not self._config.wandb.log_validation_videos or self._wandb_run is None: |
| 893 | + return |
| 894 | + |
| 895 | + # Create lists of videos with their captions |
| 896 | + validation_videos = [ |
| 897 | + wandb.Video(str(video_path), caption=prompt, format="mp4") |
| 898 | + for video_path, prompt in zip(video_paths, prompts, strict=False) |
| 899 | + ] |
| 900 | + |
| 901 | + # Log all videos at once |
| 902 | + self._wandb_run.log( |
| 903 | + { |
| 904 | + "validation_videos": validation_videos, |
| 905 | + }, |
| 906 | + step=self._global_step, |
| 907 | + ) |
0 commit comments