From 7c18e568cbe326f0094bb19122978a5516018f36 Mon Sep 17 00:00:00 2001
From: magicwang1111
Date: Fri, 23 May 2025 18:03:32 +0800
Subject: [PATCH] Fix: unwrap DDP before enabling gradient checkpointing for
 HF compatibility

---
 .gitignore                   |  1 +
 requirements.txt             | 74 ++++++++++++++++++++++++++++++++++++
 scripts/train_distributed.py | 25 +++++++-----
 src/ltxv_trainer/trainer.py  | 16 +++++++-
 4 files changed, 105 insertions(+), 11 deletions(-)
 create mode 100644 requirements.txt

diff --git a/.gitignore b/.gitignore
index d253580..9c35298 100644
--- a/.gitignore
+++ b/.gitignore
@@ -27,3 +27,4 @@ configs/*.yaml
 !configs/ltxv_2b_lora_low_vram.yaml
 outputs
 *.mov
+dataset/
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ec11747
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,74 @@
+accelerate==1.2.1
+annotated-types==0.7.0
+av==14.4.0
+bitsandbytes==0.45.2 ; sys_platform == 'linux'
+certifi==2024.12.14
+charset-normalizer==3.4.1
+click==8.1.8
+colorama==0.4.6 ; sys_platform == 'win32'
+decord==0.6.0 ; sys_platform == 'linux'
+diffusers==0.33.1
+filelock==3.16.1
+fsspec==2024.12.0
+huggingface-hub==0.31.4
+idna==3.10
+imageio==2.37.0
+imageio-ffmpeg==0.6.0
+importlib-metadata==8.5.0
+jinja2==3.1.5
+markdown-it-py==3.0.0
+markupsafe==2.1.5
+mdurl==0.1.2
+mpmath==1.3.0
+networkx==3.4.2
+ninja==1.11.1.3
+numpy==2.2.1
+nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+opencv-python==4.11.0.86
+optimum-quanto==0.2.6
+packaging==24.2
+pandas==2.2.3
+peft==0.14.0
+pillow==11.1.0
+pillow-heif==0.21.0
+protobuf==5.29.3
+psutil==6.1.1
+pydantic==2.10.4
+pydantic-core==2.27.2
+pygments==2.19.1
+python-dateutil==2.9.0.post0
+pytz==2025.1
+pyyaml==6.0.2
+regex==2024.11.6
+requests==2.32.3
+rich==13.9.4
+safetensors==0.5.0
+scenedetect==0.6.5.2
+sentencepiece==0.2.0
+setuptools==75.6.0
+shellingham==1.5.4
+six==1.17.0
+sympy==1.13.1
+tokenizers==0.21.0
+torch==2.6.0
+torchvision==0.21.0
+tqdm==4.67.1
+transformers==4.52.2
+triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
+typer==0.15.1
+typing-extensions==4.12.2
+tzdata==2025.1
+urllib3==2.3.0
+zipp==3.21.0
\ No newline at end of file
diff --git a/scripts/train_distributed.py b/scripts/train_distributed.py
index 919e2e6..074c769 100755
--- a/scripts/train_distributed.py
+++ b/scripts/train_distributed.py
@@ -32,10 +32,17 @@
     is_flag=True,
     help="Disable progress bars during training",
 )
+@click.option(
+    "--main_process_port",
+    type=int,
+    default=None,
+    help="Override master port for Accelerate distributed communication",
+)
 def main(
     config: str,
     num_processes: int | None,
     disable_progress_bars: bool,
+    main_process_port: int | None,
 ) -> None:
     # Get path to the training script
     script_dir = Path(__file__).parent
@@ -60,16 +67,16 @@ def main(
     # Get the accelerate launch parser
     launch_parser = launch_command_parser()
 
-    # Construct the launch arguments
-    launch_args = [
-        "--num_processes",
-        str(num_processes),
-        training_script,
-        *training_args,
-    ]
-
+    # Construct accelerate launch arguments
+    launch_args = []
     if num_processes > 1:
-        launch_args.insert(0, "--multi_gpu")
+        launch_args.append("--multi_gpu")
+    launch_args.extend(["--num_processes", str(num_processes)])
+    if main_process_port is not None:
+        launch_args.extend(["--main_process_port", str(main_process_port)])
+    # Add the actual training script and its args
+    launch_args.append(training_script)
+    launch_args.extend(training_args)
 
     # Parse the launch arguments
     launch_args = launch_parser.parse_args(launch_args)
diff --git a/src/ltxv_trainer/trainer.py b/src/ltxv_trainer/trainer.py
index 5726f7d..5ddebb7 100644
--- a/src/ltxv_trainer/trainer.py
+++ b/src/ltxv_trainer/trainer.py
@@ -547,6 +547,8 @@ def _load_checkpoint(self) -> None:
     def _prepare_models_for_training(self) -> None:
         """Prepare models for training with Accelerate."""
+        from torch.nn.parallel import DistributedDataParallel
+
         # Prepare and move models to the correct devices
         prepare = self._accelerator.prepare
         self._vae = prepare(self._vae).to("cpu")
         self._transformer = prepare(self._transformer)
@@ -555,9 +557,19 @@ def _prepare_models_for_training(self) -> None:
 
         if not self._config.acceleration.load_text_encoder_in_8bit:
             self._text_encoder = self._text_encoder.to("cpu")
 
-        # Enable gradient checkpointing if requested
+        # Enable gradient checkpointing on the base model if requested
         if self._config.optimization.enable_gradient_checkpointing:
-            self._transformer.enable_gradient_checkpointing()
+            model = self._transformer
+            if isinstance(model, DistributedDataParallel):
+                model = model.module
+            # Use Hugging Face's API if available
+            if hasattr(model, "gradient_checkpointing_enable"):
+                model.gradient_checkpointing_enable()
+            elif hasattr(model, "enable_gradient_checkpointing"):
+                model.enable_gradient_checkpointing()
+            else:
+                logger.warning("Model does not support gradient checkpointing.")
+
     @staticmethod
     def _find_checkpoint(checkpoint_path: str | Path) -> Path | None: