1 change: 1 addition & 0 deletions .gitignore
@@ -27,3 +27,4 @@ configs/*.yaml
 !configs/ltxv_2b_lora_low_vram.yaml
 outputs
 *.mov
+dataset/
74 changes: 74 additions & 0 deletions requirements.txt
@@ -0,0 +1,74 @@
accelerate==1.2.1
annotated-types==0.7.0
av==14.4.0
bitsandbytes==0.45.2 ; sys_platform == 'linux'
certifi==2024.12.14
charset-normalizer==3.4.1
click==8.1.8
colorama==0.4.6 ; sys_platform == 'win32'
decord==0.6.0 ; sys_platform == 'linux'
diffusers==0.33.1
filelock==3.16.1
fsspec==2024.12.0
huggingface-hub==0.31.4
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib-metadata==8.5.0
jinja2==3.1.5
markdown-it-py==3.0.0
markupsafe==2.1.5
mdurl==0.1.2
mpmath==1.3.0
networkx==3.4.2
ninja==1.11.1.3
numpy==2.2.1
nvidia-cublas-cu12==12.4.5.8 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-cupti-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-nvrtc-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cuda-runtime-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cudnn-cu12==9.1.0.70 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cufft-cu12==11.2.1.3 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-curand-cu12==10.3.5.147 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusolver-cu12==11.6.1.9 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparse-cu12==12.3.1.170 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-cusparselt-cu12==0.6.2 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nccl-cu12==2.21.5 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvjitlink-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
nvidia-nvtx-cu12==12.4.127 ; platform_machine == 'x86_64' and sys_platform == 'linux'
opencv-python==4.11.0.86
optimum-quanto==0.2.6
packaging==24.2
pandas==2.2.3
peft==0.14.0
pillow==11.1.0
pillow-heif==0.21.0
protobuf==5.29.3
psutil==6.1.1
pydantic==2.10.4
pydantic-core==2.27.2
pygments==2.19.1
python-dateutil==2.9.0.post0
pytz==2025.1
pyyaml==6.0.2
regex==2024.11.6
requests==2.32.3
rich==13.9.4
safetensors==0.5.0
scenedetect==0.6.5.2
sentencepiece==0.2.0
setuptools==75.6.0
shellingham==1.5.4
six==1.17.0
sympy==1.13.1
tokenizers==0.21.0
torch==2.6.0
torchvision==0.21.0
tqdm==4.67.1
transformers==4.52.2
triton==3.2.0 ; platform_machine == 'x86_64' and sys_platform == 'linux'
typer==0.15.1
typing-extensions==4.12.2
tzdata==2025.1
urllib3==2.3.0
zipp==3.21.0
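
As a quick sanity check that these pins resolved in the active environment, something like the following works (a sketch, not part of the PR; the three packages checked are chosen arbitrarily from the list above and assume the environment was installed from this file):

    # Sketch: compare a few installed versions against the pins above.
    from importlib.metadata import version

    pins = {"torch": "2.6.0", "diffusers": "0.33.1", "transformers": "4.52.2"}
    for pkg, pinned in pins.items():
        installed = version(pkg)
        marker = "OK" if installed == pinned else f"MISMATCH (pinned {pinned})"
        print(f"{pkg}=={installed} -> {marker}")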
25 changes: 16 additions & 9 deletions scripts/train_distributed.py
@@ -32,10 +32,17 @@
     is_flag=True,
     help="Disable progress bars during training",
 )
+@click.option(
+    "--main_process_port",
+    type=int,
+    default=None,
+    help="Override master port for Accelerate distributed communication",
+)
 def main(
     config: str,
     num_processes: int | None,
     disable_progress_bars: bool,
+    main_process_port: int | None,
 ) -> None:
     # Get path to the training script
     script_dir = Path(__file__).parent
@@ -60,16 +67,16 @@ def main(
     # Get the accelerate launch parser
     launch_parser = launch_command_parser()
 
-    # Construct the launch arguments
-    launch_args = [
-        "--num_processes",
-        str(num_processes),
-        training_script,
-        *training_args,
-    ]
-
+    # Construct accelerate launch arguments
+    launch_args = []
     if num_processes > 1:
-        launch_args.insert(0, "--multi_gpu")
+        launch_args.append("--multi_gpu")
+    launch_args.extend(["--num_processes", str(num_processes)])
+    if main_process_port is not None:
+        launch_args.extend(["--main_process_port", str(main_process_port)])
+    # Add the actual training script and its args
+    launch_args.append(training_script)
+    launch_args.extend(training_args)
 
     # Parse the launch arguments
     launch_args = launch_parser.parse_args(launch_args)
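
For illustration, here is what the new argument construction yields for a hypothetical two-GPU run with an overridden port (the port value, script path, and config path below are assumptions for the example, not from the PR):

    # Illustrative sketch of the launch_args the new logic builds for a
    # hypothetical run with num_processes=2 and main_process_port=29501.
    num_processes = 2
    main_process_port = 29501                                  # hypothetical
    training_script = "train.py"                               # hypothetical
    training_args = ["--config", "configs/ltxv_2b_lora.yaml"]  # hypothetical

    launch_args = []
    if num_processes > 1:
        launch_args.append("--multi_gpu")
    launch_args.extend(["--num_processes", str(num_processes)])
    if main_process_port is not None:
        launch_args.extend(["--main_process_port", str(main_process_port)])
    launch_args.append(training_script)
    launch_args.extend(training_args)

    print(launch_args)
    # ['--multi_gpu', '--num_processes', '2', '--main_process_port', '29501',
    #  'train.py', '--config', 'configs/ltxv_2b_lora.yaml']

The restructuring keeps every accelerate launcher flag ahead of the training script path, which matters because accelerate forwards anything after the script path to the script itself rather than treating it as a launcher option.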
16 changes: 14 additions & 2 deletions src/ltxv_trainer/trainer.py
@@ -547,6 +547,8 @@ def _load_checkpoint(self) -> None:
 
     def _prepare_models_for_training(self) -> None:
         """Prepare models for training with Accelerate."""
+        from torch.nn.parallel import DistributedDataParallel
+        # Prepare and move models to the correct devices
         prepare = self._accelerator.prepare
         self._vae = prepare(self._vae).to("cpu")
         self._transformer = prepare(self._transformer)
@@ -555,9 +557,19 @@ def _prepare_models_for_training(self) -> None:
         if not self._config.acceleration.load_text_encoder_in_8bit:
             self._text_encoder = self._text_encoder.to("cpu")
 
-        # Enable gradient checkpointing if requested
+        # Enable gradient checkpointing on the base model if requested
         if self._config.optimization.enable_gradient_checkpointing:
-            self._transformer.enable_gradient_checkpointing()
+            model = self._transformer
+            if isinstance(model, DistributedDataParallel):
+                model = model.module
+            # Use Hugging Face's API if available
+            if hasattr(model, "gradient_checkpointing_enable"):
+                model.gradient_checkpointing_enable()
+            elif hasattr(model, "enable_gradient_checkpointing"):
+                model.enable_gradient_checkpointing()
+            else:
+                logger.warning("Model does not support gradient checkpointing.")
+
 
     @staticmethod
     def _find_checkpoint(checkpoint_path: str | Path) -> Path | None:
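
The dispatch above follows a common pattern: in multi-GPU runs Accelerate can hand back a DistributedDataParallel wrapper, and DDP proxies forward() but not model-specific methods, so the underlying module must be unwrapped before calling them; transformers models expose gradient_checkpointing_enable() while diffusers models expose enable_gradient_checkpointing(). A minimal standalone sketch of the same pattern (the helper name is illustrative, and it raises where the trainer above merely warns):

    # Sketch: enable gradient checkpointing on a possibly-DDP-wrapped model.
    import torch.nn as nn
    from torch.nn.parallel import DistributedDataParallel

    def enable_checkpointing(model: nn.Module) -> None:
        # DDP proxies forward() but not custom methods, so unwrap it first.
        if isinstance(model, DistributedDataParallel):
            model = model.module
        if hasattr(model, "gradient_checkpointing_enable"):
            model.gradient_checkpointing_enable()    # transformers-style API
        elif hasattr(model, "enable_gradient_checkpointing"):
            model.enable_gradient_checkpointing()    # diffusers-style API
        else:
            raise ValueError("model exposes no gradient checkpointing API")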