# Control lora trainer hunyuan #373

Base: `main` · Changes from 58 commits
**New file** (`@@ -0,0 +1,175 @@`) — an example launch script for training a HunyuanVideo control LoRA:

```bash
#!/bin/bash

set -e -x

# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
# export TORCHDYNAMO_VERBOSE=1
export WANDB_MODE="offline"
export NCCL_P2P_DISABLE=1
export NCCL_IB_DISABLE=1
export TORCH_NCCL_ENABLE_MONITORING=0
export FINETRAINERS_LOG_LEVEL="INFO"

# Download the validation dataset
if [ ! -d "examples/training/control/hunyuan_video/image_condition/validation_dataset" ]; then
  echo "Downloading validation dataset..."
  huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/training/control/hunyuan_video/image_condition/validation_dataset
else
  echo "Validation dataset already exists. Skipping download."
fi

# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
# BACKEND="accelerate"
BACKEND="ptd"

# In this setting, I'm using 1 GPU on a 4-GPU node for training
NUM_GPUS=1
CUDA_VISIBLE_DEVICES="3"

# Check the JSON files for the expected JSON format
TRAINING_DATASET_CONFIG="examples/training/control/hunyuan_video/image_condition/training.json"
VALIDATION_DATASET_FILE="examples/training/control/hunyuan_video/image_condition/validation.json"

# Depending on how many GPUs you have available, choose your degree of parallelism and technique!
DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1"
FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"
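# Note: the config you pick must match NUM_GPUS above; the expected world size
# is the product pp_degree * dp_degree * dp_shards * cp_degree * tp_degree
# (e.g. HSDP_2_2 needs 2 * 2 = 4 GPUs).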
|  | ||
| # Parallel arguments | ||
| parallel_cmd=( | ||
| $DDP_1 | ||
| ) | ||
|  | ||
| # Model arguments | ||
| model_cmd=( | ||
| --model_name "hunyuan_video" | ||
| --pretrained_model_name_or_path "hunyuanvideo-community/HunyuanVideo" | ||
| --compile_modules transformer | ||
| ) | ||
|  | ||
| # Control arguments | ||
| control_cmd=( | ||
| --control_type none | ||
| --rank 128 | ||
| --lora_alpha 128 | ||
| --target_modules "blocks.*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)" | ||
| --frame_conditioning_type index | ||
| --frame_conditioning_index 0 | ||
| ) | ||
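# (Conditioning on frame index 0, the first frame, is what makes this an
# image-condition example: the model learns to generate a clip from one image.)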

# Dataset arguments
dataset_cmd=(
  --dataset_config $TRAINING_DATASET_CONFIG
  --dataset_shuffle_buffer_size 32
)

# Dataloader arguments
dataloader_cmd=(
  --dataloader_num_workers 0
)

# Diffusion arguments
diffusion_cmd=(
  --flow_weighting_scheme "logit_normal"
)

# Training arguments
# We target just the attention projection layers for LoRA training here (see
# --target_modules above). You can modify this and target any layer (regex is supported).
training_cmd=(
  --training_type control-lora
  --seed 42
  --batch_size 1
  --train_steps 10000
  --gradient_accumulation_steps 1
  --gradient_checkpointing
  --checkpointing_steps 1000
  --checkpointing_limit 2
  # --resume_from_checkpoint 3000
  --enable_slicing
  --enable_tiling
)

# Optimizer arguments
optimizer_cmd=(
  --optimizer "adamw"
  --lr 2e-5
  --lr_scheduler "constant_with_warmup"
  --lr_warmup_steps 1000
  --lr_num_cycles 1
  --beta1 0.9
  --beta2 0.99
  --weight_decay 1e-4
  --epsilon 1e-8
  --max_grad_norm 1.0
)

# Validation arguments
validation_cmd=(
  --validation_dataset_file "$VALIDATION_DATASET_FILE"
  --validation_steps 501
)

# Miscellaneous arguments
miscellaneous_cmd=(
  --tracker_name "finetrainers-hunyuan_video-control"
  --output_dir "/raid/aryan/hunyuan_video-control-image-condition"
  --init_timeout 600
  --nccl_timeout 600
  --report_to "wandb"
)

# Execute the training script
if [ "$BACKEND" == "accelerate" ]; then

  ACCELERATE_CONFIG_FILE=""
  if [ "$NUM_GPUS" == 1 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
  elif [ "$NUM_GPUS" == 2 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
  elif [ "$NUM_GPUS" == 4 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
  elif [ "$NUM_GPUS" == 8 ]; then
    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
  fi

  accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
    "${parallel_cmd[@]}" \
    "${model_cmd[@]}" \
    "${control_cmd[@]}" \
    "${dataset_cmd[@]}" \
    "${dataloader_cmd[@]}" \
    "${diffusion_cmd[@]}" \
    "${training_cmd[@]}" \
    "${optimizer_cmd[@]}" \
    "${validation_cmd[@]}" \
    "${miscellaneous_cmd[@]}"

elif [ "$BACKEND" == "ptd" ]; then

  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES

  torchrun \
    --standalone \
    --nnodes=1 \
    --nproc_per_node=$NUM_GPUS \
    --rdzv_backend c10d \
    --rdzv_endpoint="localhost:19242" \
    train.py \
    "${parallel_cmd[@]}" \
    "${model_cmd[@]}" \
    "${control_cmd[@]}" \
    "${dataset_cmd[@]}" \
    "${dataloader_cmd[@]}" \
    "${diffusion_cmd[@]}" \
    "${training_cmd[@]}" \
    "${optimizer_cmd[@]}" \
    "${validation_cmd[@]}" \
    "${miscellaneous_cmd[@]}"
fi

echo -ne "-------------------- Finished executing script --------------------\n\n"
```
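The `training.json` and `validation.json` files referenced by the script are not part of this diff, and the script's comment defers to them for the exact schema. Purely as a hypothetical sketch (every key and value below is an illustrative assumption, not confirmed by this PR), such a config typically enumerates datasets with a data root, a dataset type, and resolution buckets:

```python
# Hypothetical sketch of a training dataset config. Key names are illustrative
# assumptions -- consult the training.json shipped with this example for the
# actual schema expected by finetrainers.
import json

config = {
    "datasets": [
        {
            "data_root": "finetrainers/OpenVid-1k-split",  # assumption: HF dataset id or local path
            "dataset_type": "video",                       # assumption
            "video_resolution_buckets": [[49, 480, 832]],  # assumption: (frames, height, width)
        }
    ]
}

with open("training.json", "w") as f:
    json.dump(config, f, indent=2)
```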
**Model package `__init__`** (`@@ -1 +1,2 @@`) — export the new control specification:

```diff
 from .base_specification import HunyuanVideoModelSpecification
+from .control_specification import HunyuanVideoControlModelSpecification
```
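With the added export, both specifications can be imported from the model package. A one-line usage sketch (the absolute package path is an assumption inferred from the relative imports; it is not shown in this diff):

```python
# Package path assumed from the relative imports above.
from finetrainers.models.hunyuan_video import (
    HunyuanVideoControlModelSpecification,
    HunyuanVideoModelSpecification,
)
```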
**`base_specification.py`** — `HunyuanLatentEncodeProcessor` now returns the raw VAE moments together with `latents_mean` and `latents_std` tensors, so it emits three named outputs instead of one:

```diff
@@ -38,7 +38,7 @@ class HunyuanLatentEncodeProcessor(ProcessorMixin):
     def __init__(self, output_names: List[str]):
         super().__init__()
         self.output_names = output_names
-        assert len(self.output_names) == 1
+        assert len(self.output_names) == 3

     def forward(
         self,
@@ -58,18 +58,24 @@ def forward(
         video = video.to(device=device, dtype=vae.dtype)
         video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]

         compute_posterior = False
         if compute_posterior:
             latents = vae.encode(video).latent_dist.sample(generator=generator)
             latents = latents.to(dtype=dtype)
         else:
-            if vae.use_slicing and video.shape[0] > 1:
-                encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
-                moments = torch.cat(encoded_slices)
-            else:
-                moments = vae._encode(video)
+            # TODO(aryan): refactor in diffusers to have use_slicing attribute
+            # if vae.use_slicing and video.shape[0] > 1:
+            #     encoded_slices = [vae._encode(x_slice) for x_slice in video.split(1)]
+            #     moments = torch.cat(encoded_slices)
+            # else:
+            #     moments = vae._encode(video)
+            moments = vae._encode(video)
             latents = moments.to(dtype=dtype)

-        return {self.output_names[0]: latents}
+        latents_mean = torch.tensor(vae.latent_channels)
+        latents_std = 1.0 / torch.tensor(vae.latent_channels)
+
+        return {self.output_names[0]: latents, self.output_names[1]: latents_mean, self.output_names[2]: latents_std}


 class HunyuanVideoModelSpecification(ModelSpecification):
```
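Because `compute_posterior` is false in this path, the processor caches raw encoder moments (mean and log-variance stacked along the channel dimension) instead of a sampled latent, deferring sampling to the model's `forward`. A minimal standalone sketch of that deferred step, using `diffusers`' `DiagonalGaussianDistribution` with made-up shapes (not the finetrainers code):

```python
import torch
from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution

# Made-up moments tensor: batch 1, 2 * 16 channels (mean + logvar), 5 frames, 8x8.
moments = torch.randn(1, 32, 5, 8, 8)

# DiagonalGaussianDistribution chunks the tensor into mean and logvar along
# dim=1 and samples with the reparameterization trick: mean + std * noise.
posterior = DiagonalGaussianDistribution(moments)
sample = posterior.sample(generator=torch.Generator().manual_seed(42))
assert sample.shape == (1, 16, 5, 8, 8)  # channel dim halves after sampling

# The same split done by hand, as the PR's forward() does before normalizing:
mu, logvar = torch.chunk(moments, 2, dim=1)
std = torch.exp(0.5 * logvar)
```

One common reason to cache moments rather than samples is that the precomputed latents stay deterministic across epochs while fresh posterior noise is still drawn at each training step.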
In `HunyuanVideoModelSpecification`, the default latent processor is wired up with the three output names, and `forward` reconstructs the posterior from the cached moments, normalizing the mean and log-variance halves before sampling:

```diff
@@ -115,7 +121,7 @@ def __init__(
             ),
         ]
         if latent_model_processors is None:
-            latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])]
+            latent_model_processors = [HunyuanLatentEncodeProcessor(["latents", "latents_mean", "latents_std"])]

         self.condition_model_processors = condition_model_processors
         self.latent_model_processors = latent_model_processors
@@ -305,7 +311,16 @@ def forward(
         if compute_posterior:
             latents = latent_model_conditions.pop("latents")
         else:
-            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = latent_model_conditions.pop("latents")
+            latents_mean = latent_model_conditions.pop("latents_mean")
+            latents_std = latent_model_conditions.pop("latents_std")
+
+            mu, logvar = torch.chunk(latents, 2, dim=1)
+            mu = self._normalize_latents(mu, latents_mean, latents_std)
+            logvar = self._normalize_latents(logvar, latents_mean, latents_std)
+            latents = torch.cat([mu, logvar], dim=1)
+
+            posterior = DiagonalGaussianDistribution(latents)
             latents = posterior.sample(generator=generator)
             del posterior
```
> **Review note:** So far this has only been made to work with `compute_posterior` set to false.