diff --git a/examples/formats/hunyuan_video/convert_to_original_format.py b/examples/formats/hunyuan_video/convert_to_original_format.py
index c776fe16..7e013fa3 100644
--- a/examples/formats/hunyuan_video/convert_to_original_format.py
+++ b/examples/formats/hunyuan_video/convert_to_original_format.py
@@ -108,6 +108,17 @@ def convert_lora_sd(diffusers_lora_sd):
         elif "proj_out" in key:
             new_key = key.replace("proj_out", "linear2").replace(single_block_pattern, prefix + "single_blocks")
             converted_lora_sd[new_key] = diffusers_lora_sd[key]
+        elif "x_embedder" in key:
+            new_key = key.replace("x_embedder", "img_in").replace(double_block_pattern, prefix + "")
+            if "lora_A" in key:
+                embed = diffusers_lora_sd[key]
+                sizes = embed.size()
+                x_reshaped = embed.view(sizes[0], 2, 16, sizes[2], sizes[3], sizes[4])
+                x_meaned = x_reshaped.mean(dim=1)
+                converted_lora_sd[new_key] = x_meaned
+            else:
+                converted_lora_sd[new_key] = diffusers_lora_sd[key]
+                print(new_key, diffusers_lora_sd[key].size())
         else:
             print(f"unknown or not implemented: {key}")
 
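Note on the `x_embedder` branch above: it folds a LoRA-A weight trained against the control transformer's doubled input channels (2 x 16) back down to the 16 channels the original-format `img_in` layer expects. A minimal sketch of that folding, assuming the expanded `Conv3d` LoRA-A weight has shape `[rank, 32, kT, kH, kW]` with the original latent channels in the first half and the control channels in the second (the helper name and the choice of averaging are illustrative, not part of the script):

```python
import torch


def fold_control_lora_A(weight: torch.Tensor, base_channels: int = 16) -> torch.Tensor:
    """Fold a control-LoRA A weight with 2 * base_channels input channels down to base_channels.

    Averaging the two channel halves is one simple (lossy) option; keeping only the first
    half (the original latent channels) is another.
    """
    rank, in_channels, kt, kh, kw = weight.shape
    assert in_channels == 2 * base_channels, "expected the channel-doubled control layer"
    halves = weight.reshape(rank, 2, base_channels, kt, kh, kw)  # [original half, control half]
    return halves.mean(dim=1)  # [rank, base_channels, kt, kh, kw]
```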
diff --git a/examples/training/control/hunyuan_video/image_condition/train.sh b/examples/training/control/hunyuan_video/image_condition/train.sh
new file mode 100755
index 00000000..34ef61a8
--- /dev/null
+++ b/examples/training/control/hunyuan_video/image_condition/train.sh
@@ -0,0 +1,175 @@
+#!/bin/bash
+
+set -e -x
+
+# export TORCH_LOGS="+dynamo,recompiles,graph_breaks"
+# export TORCHDYNAMO_VERBOSE=1
+export WANDB_MODE="offline"
+export NCCL_P2P_DISABLE=1
+export NCCL_IB_DISABLE=1
+export TORCH_NCCL_ENABLE_MONITORING=0
+export FINETRAINERS_LOG_LEVEL="INFO"
+
+# Download the validation dataset
+if [ ! -d "examples/training/control/wan/image_condition/validation_dataset" ]; then
+  echo "Downloading validation dataset..."
+  huggingface-cli download --repo-type dataset finetrainers/OpenVid-1k-split-validation --local-dir examples/training/control/wan/image_condition/validation_dataset
+else
+  echo "Validation dataset already exists. Skipping download."
+fi
+
+# Finetrainers supports multiple backends for distributed training. Select your favourite and benchmark the differences!
+# BACKEND="accelerate"
+BACKEND="ptd"
+
+# In this setting, I'm using 1 GPU on a 4-GPU node for training
+NUM_GPUS=1
+CUDA_VISIBLE_DEVICES="3"
+
+# Check the JSON files for the expected JSON format
+TRAINING_DATASET_CONFIG="examples/training/control/hunyuan_video/image_condition/training.json"
+VALIDATION_DATASET_FILE="examples/training/control/hunyuan_video/image_condition/validation.json"
+
+# Depending on how many GPUs you have available, choose your degree of parallelism and technique!
+DDP_1="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 1 --cp_degree 1 --tp_degree 1"
+DDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 1 --cp_degree 1 --tp_degree 1"
+DDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 4 --dp_shards 1 --cp_degree 1 --tp_degree 1"
+DDP_8="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 8 --dp_shards 1 --cp_degree 1 --tp_degree 1"
+FSDP_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 2 --cp_degree 1 --tp_degree 1"
+FSDP_4="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 1 --dp_shards 4 --cp_degree 1 --tp_degree 1"
+HSDP_2_2="--parallel_backend $BACKEND --pp_degree 1 --dp_degree 2 --dp_shards 2 --cp_degree 1 --tp_degree 1"
+
+# Parallel arguments
+parallel_cmd=(
+  $DDP_1
+)
+
+# Model arguments
+model_cmd=(
+  --model_name "hunyuan_video"
+  --pretrained_model_name_or_path "hunyuanvideo-community/HunyuanVideo"
+  --compile_modules transformer
+)
+
+# Control arguments
+control_cmd=(
+  --control_type none
+  --rank 128
+  --lora_alpha 128
+  --target_modules "blocks.*(to_q|to_k|to_v|to_out.0|ff.net.0.proj|ff.net.2)"
+  --frame_conditioning_type index
+  --frame_conditioning_index 0
+)
+
+# Dataset arguments
+dataset_cmd=(
+  --dataset_config $TRAINING_DATASET_CONFIG
+  --dataset_shuffle_buffer_size 32
+)
+
+# Dataloader arguments
+dataloader_cmd=(
+  --dataloader_num_workers 0
+)
+
+# Diffusion arguments
+diffusion_cmd=(
+  --flow_weighting_scheme "logit_normal"
+)
+
+# Training arguments
+# We target the attention projection and feed-forward layers for LoRA training here.
+# You can modify as you please and target any layer (regex is supported)
+training_cmd=(
+  --training_type control-lora
+  --seed 42
+  --batch_size 1
+  --train_steps 10000
+  --gradient_accumulation_steps 1
+  --gradient_checkpointing
+  --checkpointing_steps 1000
+  --checkpointing_limit 2
+  # --resume_from_checkpoint 3000
+  --enable_slicing
+  --enable_tiling
+)
+
+# Optimizer arguments
+optimizer_cmd=(
+  --optimizer "adamw"
+  --lr 2e-5
+  --lr_scheduler "constant_with_warmup"
+  --lr_warmup_steps 1000
+  --lr_num_cycles 1
+  --beta1 0.9
+  --beta2 0.99
+  --weight_decay 1e-4
+  --epsilon 1e-8
+  --max_grad_norm 1.0
+)
+
+# Validation arguments
+validation_cmd=(
+  --validation_dataset_file "$VALIDATION_DATASET_FILE"
+  --validation_steps 501
+)
+
+# Miscellaneous arguments
+miscellaneous_cmd=(
+  --tracker_name "finetrainers-hunyuan_video-control"
+  --output_dir "/raid/aryan/hunyuan_video-control-image-condition"
+  --init_timeout 600
+  --nccl_timeout 600
+  --report_to "wandb"
+)
+
+# Execute the training script
+if [ "$BACKEND" == "accelerate" ]; then
+
+  ACCELERATE_CONFIG_FILE=""
+  if [ "$NUM_GPUS" == 1 ]; then
+    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_1.yaml"
+  elif [ "$NUM_GPUS" == 2 ]; then
+    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_2.yaml"
+  elif [ "$NUM_GPUS" == 4 ]; then
+    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_4.yaml"
+  elif [ "$NUM_GPUS" == 8 ]; then
+    ACCELERATE_CONFIG_FILE="accelerate_configs/uncompiled_8.yaml"
+  fi
+
+  accelerate launch --config_file "$ACCELERATE_CONFIG_FILE" --gpu_ids $CUDA_VISIBLE_DEVICES train.py \
+    "${parallel_cmd[@]}" \
+    "${model_cmd[@]}" \
+    "${control_cmd[@]}" \
+    "${dataset_cmd[@]}" \
+    "${dataloader_cmd[@]}" \
+    "${diffusion_cmd[@]}" \
+    "${training_cmd[@]}" \
+    "${optimizer_cmd[@]}" \
+    "${validation_cmd[@]}" \
+    "${miscellaneous_cmd[@]}"
+
+elif [ "$BACKEND" == "ptd" ]; then
+
+  export CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES
+
+  torchrun \
+    --standalone \
+    --nnodes=1 \
+    --nproc_per_node=$NUM_GPUS \
+    --rdzv_backend c10d \
+    --rdzv_endpoint="localhost:19242" \
+    train.py \
+    "${parallel_cmd[@]}" \
+    "${model_cmd[@]}" \
+    "${control_cmd[@]}" \
+    "${dataset_cmd[@]}" \
+    "${dataloader_cmd[@]}" \
+    "${diffusion_cmd[@]}" \
+    "${training_cmd[@]}" \
+    "${optimizer_cmd[@]}" \
+    "${validation_cmd[@]}" \
+    "${miscellaneous_cmd[@]}"
+fi
+
+echo -ne "-------------------- Finished executing script --------------------\n\n"
diff --git a/finetrainers/config.py b/finetrainers/config.py
index 46e713e9..dbc17647 100644
--- a/finetrainers/config.py
+++ b/finetrainers/config.py
@@ -5,7 +5,7 @@
 from .models.cogvideox import CogVideoXModelSpecification
 from .models.cogview4 import CogView4ControlModelSpecification, CogView4ModelSpecification
 from .models.flux import FluxModelSpecification
-from .models.hunyuan_video import HunyuanVideoModelSpecification
+from .models.hunyuan_video import HunyuanVideoControlModelSpecification, HunyuanVideoModelSpecification
 from .models.ltx_video import LTXVideoModelSpecification
 from .models.wan import WanControlModelSpecification, WanModelSpecification
 
@@ -49,6 +49,7 @@ class TrainingType(str, Enum):
     ModelType.HUNYUAN_VIDEO: {
         TrainingType.LORA: HunyuanVideoModelSpecification,
         TrainingType.FULL_FINETUNE: HunyuanVideoModelSpecification,
+        TrainingType.CONTROL_LORA: HunyuanVideoControlModelSpecification,
     },
     ModelType.LTX_VIDEO: {
         TrainingType.LORA: LTXVideoModelSpecification,
diff --git a/finetrainers/models/hunyuan_video/__init__.py b/finetrainers/models/hunyuan_video/__init__.py
index 518a4286..b35bf2e6 100644
--- a/finetrainers/models/hunyuan_video/__init__.py
+++ b/finetrainers/models/hunyuan_video/__init__.py
@@ -1 +1,2 @@
 from .base_specification import HunyuanVideoModelSpecification
+from .control_specification import HunyuanVideoControlModelSpecification
diff --git a/finetrainers/models/hunyuan_video/base_specification.py b/finetrainers/models/hunyuan_video/base_specification.py
index 80d02c93..8de71c08 100644
--- a/finetrainers/models/hunyuan_video/base_specification.py
+++ b/finetrainers/models/hunyuan_video/base_specification.py
@@ -58,6 +58,7 @@ def forward(
         video = video.to(device=device, dtype=vae.dtype)
         video = video.permute(0, 2, 1, 3, 4).contiguous()  # [B, F, C, H, W] -> [B, C, F, H, W]
 
+        compute_posterior = False
         if compute_posterior:
             latents = vae.encode(video).latent_dist.sample(generator=generator)
             latents = latents.to(dtype=dtype)
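For context on the `compute_posterior = False` override above: with the posterior path disabled, the latent processor returns the raw posterior parameters instead of sampled latents, and sampling is deferred to the training step (the control specification added below rebuilds a `DiagonalGaussianDistribution` in `forward`). An illustrative sketch of the two paths, not the actual processor code:

```python
import torch
from diffusers import AutoencoderKLHunyuanVideo


@torch.no_grad()
def encode_video(vae: AutoencoderKLHunyuanVideo, video: torch.Tensor, compute_posterior: bool) -> torch.Tensor:
    """Either sample latents immediately, or keep the posterior parameters (mean and logvar
    concatenated along the channel dim) so the trainer can sample them later in forward()."""
    latent_dist = vae.encode(video).latent_dist
    if compute_posterior:
        return latent_dist.sample()
    return latent_dist.parameters  # later: DiagonalGaussianDistribution(parameters).sample()
```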
diff --git a/finetrainers/models/hunyuan_video/control_specification.py b/finetrainers/models/hunyuan_video/control_specification.py
new file mode 100644
index 00000000..88ba6ea8
--- /dev/null
+++ b/finetrainers/models/hunyuan_video/control_specification.py
@@ -0,0 +1,445 @@
+import functools
+import os
+from typing import Any, Dict, List, Optional, Tuple
+
+import safetensors
+import torch
+from accelerate import init_empty_weights
+from diffusers import (
+    AutoencoderKLHunyuanVideo,
+    FlowMatchEulerDiscreteScheduler,
+    HunyuanVideoPipeline,
+    HunyuanVideoTransformer3DModel,
+)
+from diffusers.models.autoencoders.vae import DiagonalGaussianDistribution
+from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, LlamaModel
+
+from finetrainers.data._artifact import VideoArtifact
+from finetrainers.models.utils import _expand_conv3d_with_zeroed_weights
+from finetrainers.trainer.control_trainer.config import FrameConditioningType
+from finetrainers.utils.serialization import safetensors_torch_save_function
+
+from ... import functional as FF
+from ...logging import get_logger
+from ...patches.dependencies.diffusers.control import control_channel_concat
+from ...processors import CLIPPooledProcessor, LlamaProcessor, ProcessorMixin
+from ...typing import ArtifactType, SchedulerType
+from ...utils import _enable_vae_memory_optimizations, get_non_null_items
+from ..modeling_utils import ControlModelSpecification
+from .base_specification import HunyuanLatentEncodeProcessor
+
+
+logger = get_logger()
+
+
+class HunyuanVideoControlModelSpecification(ControlModelSpecification):
+    def __init__(
+        self,
+        pretrained_model_name_or_path: str = "hunyuanvideo-community/HunyuanVideo",
+        tokenizer_id: Optional[str] = None,
+        text_encoder_id: Optional[str] = None,
+        transformer_id: Optional[str] = None,
+        vae_id: Optional[str] = None,
+        text_encoder_dtype: torch.dtype = torch.bfloat16,
+        transformer_dtype: torch.dtype = torch.bfloat16,
+        vae_dtype: torch.dtype = torch.bfloat16,
+        revision: Optional[str] = None,
+        cache_dir: Optional[str] = None,
+        condition_model_processors: List[ProcessorMixin] = None,
+        latent_model_processors: List[ProcessorMixin] = None,
+        control_model_processors: List[ProcessorMixin] = None,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            tokenizer_id=tokenizer_id,
+            text_encoder_id=text_encoder_id,
+            transformer_id=transformer_id,
+            vae_id=vae_id,
+            text_encoder_dtype=text_encoder_dtype,
+            transformer_dtype=transformer_dtype,
+            vae_dtype=vae_dtype,
+            revision=revision,
+            cache_dir=cache_dir,
+        )
+
+        if condition_model_processors is None:
+            condition_model_processors = [
+                LlamaProcessor(["encoder_hidden_states", "encoder_attention_mask"]),
+                CLIPPooledProcessor(
+                    ["pooled_projections"],
+                    input_names={"tokenizer_2": "tokenizer", "text_encoder_2": "text_encoder"},
+                ),
+            ]
+        if latent_model_processors is None:
+            latent_model_processors = [HunyuanLatentEncodeProcessor(["latents"])]
+        if control_model_processors is None:
+            control_model_processors = [HunyuanLatentEncodeProcessor(["control_latents", "__drop__", "__drop__"])]
+
+        self.condition_model_processors = condition_model_processors
+        self.latent_model_processors = latent_model_processors
+        self.control_model_processors = control_model_processors
+
+    @property
+    def control_injection_layer_name(self) -> str:
+        return "x_embedder.proj"
+
+    @property
+    def _resolution_dim_keys(self):
+        return {"latents": (2, 3, 4)}
+
+    def load_condition_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.tokenizer_id is not None:
+            tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_id, **common_kwargs)
+        else:
+            tokenizer = AutoTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer", **common_kwargs
+            )
+
+        if self.tokenizer_2_id is not None:
+            tokenizer_2 = AutoTokenizer.from_pretrained(self.tokenizer_2_id, **common_kwargs)
+        else:
+            tokenizer_2 = CLIPTokenizer.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="tokenizer_2", **common_kwargs
+            )
+
+        if self.text_encoder_id is not None:
+            text_encoder = LlamaModel.from_pretrained(
+                self.text_encoder_id, torch_dtype=self.text_encoder_dtype, **common_kwargs
+            )
+        else:
+            text_encoder = LlamaModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder",
+                torch_dtype=self.text_encoder_dtype,
+                **common_kwargs,
+            )
+
+        if self.text_encoder_2_id is not None:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.text_encoder_2_id, torch_dtype=self.text_encoder_2_dtype, **common_kwargs
+            )
+        else:
+            text_encoder_2 = CLIPTextModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="text_encoder_2",
+                torch_dtype=self.text_encoder_2_dtype,
+                **common_kwargs,
+            )
+
+        return {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+        }
+
+    def load_latent_models(self) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.vae_id is not None:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(self.vae_id, torch_dtype=self.vae_dtype, **common_kwargs)
+        else:
+            vae = AutoencoderKLHunyuanVideo.from_pretrained(
+                self.pretrained_model_name_or_path, subfolder="vae", torch_dtype=self.vae_dtype, **common_kwargs
+            )
+
+        return {"vae": vae}
+
+    def load_diffusion_models(self, new_in_features: int) -> Dict[str, torch.nn.Module]:
+        common_kwargs = {"revision": self.revision, "cache_dir": self.cache_dir}
+
+        if self.transformer_id is not None:
+            transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+                self.transformer_id, torch_dtype=self.transformer_dtype, **common_kwargs
+            )
+        else:
+            transformer = HunyuanVideoTransformer3DModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                subfolder="transformer",
+                torch_dtype=self.transformer_dtype,
+                **common_kwargs,
+            )
+
+        transformer.x_embedder.proj = _expand_conv3d_with_zeroed_weights(
+            transformer.x_embedder.proj, new_in_channels=new_in_features
+        )
+        transformer.register_to_config(in_channels=new_in_features)
+        scheduler = FlowMatchEulerDiscreteScheduler()
+
+        return {"transformer": transformer, "scheduler": scheduler}
+
+    def load_pipeline(
+        self,
+        tokenizer: Optional[AutoTokenizer] = None,
+        tokenizer_2: Optional[CLIPTokenizer] = None,
+        text_encoder: Optional[LlamaModel] = None,
+        text_encoder_2: Optional[CLIPTextModel] = None,
+        transformer: Optional[HunyuanVideoTransformer3DModel] = None,
+        vae: Optional[AutoencoderKLHunyuanVideo] = None,
+        scheduler: Optional[FlowMatchEulerDiscreteScheduler] = None,
+        enable_slicing: bool = False,
+        enable_tiling: bool = False,
+        enable_model_cpu_offload: bool = False,
+        training: bool = False,
+        **kwargs,
+    ) -> HunyuanVideoPipeline:
+        components = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "transformer": transformer,
+            "vae": vae,
+            "scheduler": scheduler,
+        }
+        components = get_non_null_items(components)
+
+        pipe = HunyuanVideoPipeline.from_pretrained(
+            self.pretrained_model_name_or_path, **components, revision=self.revision, cache_dir=self.cache_dir
+        )
+        pipe.text_encoder.to(self.text_encoder_dtype)
+        pipe.text_encoder_2.to(self.text_encoder_2_dtype)
+        pipe.vae.to(self.vae_dtype)
+
+        _enable_vae_memory_optimizations(pipe.vae, enable_slicing, enable_tiling)
+        if not training:
+            pipe.transformer.to(self.transformer_dtype)
+
+        # TODO(aryan): add support in diffusers
+        # if enable_slicing:
+        #     pipe.vae.enable_slicing()
+        # if enable_tiling:
+        #     pipe.vae.enable_tiling()
+        if enable_model_cpu_offload:
+            pipe.enable_model_cpu_offload()
+
+        return pipe
+
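`load_diffusion_models` above widens the patch-embedding convolution via `_expand_conv3d_with_zeroed_weights` so the pretrained transformer accepts the concatenated control channels without changing its initial behaviour. A minimal sketch of that kind of expansion, assuming the extra input channels are zero-initialised (the real helper lives in `finetrainers/models/utils.py` and may differ in detail):

```python
import torch
import torch.nn as nn


def expand_conv3d_in_channels(conv: nn.Conv3d, new_in_channels: int) -> nn.Conv3d:
    """Return a Conv3d that accepts more input channels; the new weights start at zero,
    so the expanded layer initially ignores the control channels."""
    expanded = nn.Conv3d(
        new_in_channels,
        conv.out_channels,
        kernel_size=conv.kernel_size,
        stride=conv.stride,
        padding=conv.padding,
        bias=conv.bias is not None,
    )
    with torch.no_grad():
        expanded.weight.zero_()
        expanded.weight[:, : conv.in_channels].copy_(conv.weight)
        if conv.bias is not None:
            expanded.bias.copy_(conv.bias)
    return expanded.to(device=conv.weight.device, dtype=conv.weight.dtype)
```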
+    @torch.no_grad()
+    def prepare_conditions(
+        self,
+        tokenizer: AutoTokenizer,
+        tokenizer_2: CLIPTokenizer,
+        text_encoder: LlamaModel,
+        text_encoder_2: CLIPTextModel,
+        caption: str,
+        max_sequence_length: int = 256,
+        **kwargs,
+    ) -> Dict[str, Any]:
+        conditions = {
+            "tokenizer": tokenizer,
+            "tokenizer_2": tokenizer_2,
+            "text_encoder": text_encoder,
+            "text_encoder_2": text_encoder_2,
+            "caption": caption,
+            "max_sequence_length": max_sequence_length,
+            **kwargs,
+        }
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_conditions(**conditions)
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+        return conditions
+
+    @torch.no_grad()
+    def prepare_latents(
+        self,
+        vae: AutoencoderKLHunyuanVideo,
+        image: Optional[torch.Tensor] = None,
+        video: Optional[torch.Tensor] = None,
+        control_image: Optional[torch.Tensor] = None,
+        control_video: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Dict[str, torch.Tensor]:
+        common_kwargs = {
+            "vae": vae,
+            "generator": generator,
+            "compute_posterior": compute_posterior,
+            **kwargs,
+        }
+        conditions = {"image": image, "video": video, **common_kwargs}
+        input_keys = set(conditions.keys())
+        conditions = super().prepare_latents(**conditions)
+        conditions = {k: v for k, v in conditions.items() if k not in input_keys}
+
+        control_conditions = {"image": control_image, "video": control_video, **common_kwargs}
+        input_keys = set(control_conditions.keys())
+        control_conditions = ControlModelSpecification.prepare_latents(
+            self, self.control_model_processors, **control_conditions
+        )
+        control_conditions = {k: v for k, v in control_conditions.items() if k not in input_keys}
+
+        return {**control_conditions, **conditions}
+
+    def forward(
+        self,
+        transformer: HunyuanVideoTransformer3DModel,
+        condition_model_conditions: Dict[str, torch.Tensor],
+        latent_model_conditions: Dict[str, torch.Tensor],
+        sigmas: torch.Tensor,
+        guidance: float = 1.0,
+        generator: Optional[torch.Generator] = None,
+        compute_posterior: bool = True,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, ...]:
+        from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents
+
+        if compute_posterior:
+            latents = latent_model_conditions.pop("latents")
+            control_latents = latent_model_conditions.pop("control_latents")
+        else:
+            posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("latents"))
+            latents = posterior.sample(generator=generator)
+            control_posterior = DiagonalGaussianDistribution(latent_model_conditions.pop("control_latents"))
+            control_latents = control_posterior.sample(generator=generator)
+            del posterior, control_posterior
+
+        noise = torch.zeros_like(latents).normal_(generator=generator)
+        timesteps = (sigmas.flatten() * 1000.0).long()
+        guidance = latents.new_full((latents.size(0),), fill_value=guidance) * 1000.0
+
+        noisy_latents = FF.flow_match_xt(latents, noise, sigmas)
+        control_latents = apply_frame_conditioning_on_latents(
+            control_latents,
+            noisy_latents.shape[2],
+            channel_dim=1,
+            frame_dim=2,
+            frame_conditioning_type=self.frame_conditioning_type,
+            frame_conditioning_index=self.frame_conditioning_index,
+            concatenate_mask=self.frame_conditioning_concatenate_mask,
+        )
+        noisy_latents = torch.cat([noisy_latents, control_latents], dim=1)
+
+        latent_model_conditions["hidden_states"] = noisy_latents.to(latents)
+        latent_model_conditions["guidance"] = guidance
+
+        pred = transformer(
+            **latent_model_conditions,
+            **condition_model_conditions,
+            timestep=timesteps,
+            return_dict=False,
+        )[0]
+        target = FF.flow_match_target(noise, latents)
+
+        return pred, target, sigmas
+
+    def validation(
+        self,
+        pipeline: HunyuanVideoPipeline,
+        prompt: str,
+        control_image: Optional[torch.Tensor] = None,
+        control_video: Optional[torch.Tensor] = None,
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
+        num_inference_steps: int = 50,
+        generator: Optional[torch.Generator] = None,
+        frame_conditioning_type: "FrameConditioningType" = "full",
+        frame_conditioning_index: int = 0,
+        **kwargs,
+    ) -> List[ArtifactType]:
+        from finetrainers.trainer.control_trainer.data import apply_frame_conditioning_on_latents
+
+        with torch.no_grad():
+            dtype = pipeline.vae.dtype
+            device = pipeline._execution_device
+            in_channels = self.transformer_config.in_channels  # We need to use the original in_channels
+            latents = pipeline.prepare_latents(1, in_channels, height, width, num_frames, dtype, device, generator)
+            latents_mean = (
+                torch.tensor(self.vae_config.latents_mean)
+                .view(1, self.vae_config.z_dim, 1, 1, 1)
+                .to(latents.device, latents.dtype)
+            )
+            latents_std = 1.0 / torch.tensor(self.vae_config.latents_std).view(1, self.vae_config.z_dim, 1, 1, 1).to(
+                latents.device, latents.dtype
+            )
+
+            if control_image is not None:
+                control_video = pipeline.video_processor.preprocess(
+                    control_image, height=height, width=width
+                ).unsqueeze(2)
+            else:
+                control_video = pipeline.video_processor.preprocess_video(control_video, height=height, width=width)
+
+            control_video = control_video.to(device=device, dtype=dtype)
+            control_latents = pipeline.vae.encode(control_video).latent_dist.mode()
+            control_latents = self._normalize_latents(control_latents, latents_mean, latents_std)
+            control_latents = apply_frame_conditioning_on_latents(
+                control_latents,
+                latents.shape[2],
+                channel_dim=1,
+                frame_dim=2,
+                frame_conditioning_type=frame_conditioning_type,
+                frame_conditioning_index=frame_conditioning_index,
+                concatenate_mask=self.frame_conditioning_concatenate_mask,
+            )
+
+        generation_kwargs = {
+            "latents": latents,
+            "prompt": prompt,
+            "height": height,
+            "width": width,
+            "num_frames": num_frames,
+            "num_inference_steps": num_inference_steps,
+            "generator": generator,
+            "return_dict": True,
+            "output_type": "pil",
+        }
+        generation_kwargs = get_non_null_items(generation_kwargs)
+
+        with control_channel_concat(pipeline.transformer, ["hidden_states"], [control_latents], dims=[1]):
+            video = pipeline(**generation_kwargs).frames[0]
+
+        return [VideoArtifact(value=video)]
+
+    def _save_lora_weights(
+        self,
+        directory: str,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        norm_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+        metadata: Optional[Dict[str, str]] = None,
+        *args,
+        **kwargs,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            HunyuanVideoPipeline.save_lora_weights(
+                directory,
+                transformer_state_dict,
+                save_function=functools.partial(safetensors_torch_save_function, metadata=metadata),
+                safe_serialization=True,
+            )
+        if norm_state_dict is not None:
+            safetensors.torch.save_file(norm_state_dict, os.path.join(directory, "norm_state_dict.safetensors"))
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    def _save_model(
+        self,
+        directory: str,
+        transformer: HunyuanVideoTransformer3DModel,
+        transformer_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        scheduler: Optional[SchedulerType] = None,
+    ) -> None:
+        # TODO(aryan): this needs refactoring
+        if transformer_state_dict is not None:
+            with init_empty_weights():
+                transformer_copy = HunyuanVideoTransformer3DModel.from_config(transformer.config)
+            transformer_copy.load_state_dict(transformer_state_dict, strict=True, assign=True)
+            transformer_copy.save_pretrained(os.path.join(directory, "transformer"))
+        if scheduler is not None:
+            scheduler.save_pretrained(os.path.join(directory, "scheduler"))
+
+    @property
+    def _original_control_layer_in_features(self):
+        return self.transformer_config.in_channels
+
+    @property
+    def _original_control_layer_out_features(self):
+        return self.transformer_config.num_attention_heads * self.transformer_config.attention_head_dim
+
+    @property
+    def _qk_norm_identifiers(self):
+        return ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]
diff --git a/finetrainers/trainer/control_trainer/trainer.py b/finetrainers/trainer/control_trainer/trainer.py
index f4d829b4..853e8d98 100644
--- a/finetrainers/trainer/control_trainer/trainer.py
+++ b/finetrainers/trainer/control_trainer/trainer.py
@@ -11,7 +11,6 @@
 import safetensors.torch
 import torch
 import torch.backends
-import wandb
 from diffusers import DiffusionPipeline
 from diffusers.hooks import apply_layerwise_casting
 from diffusers.training_utils import cast_training_params
@@ -20,8 +19,9 @@
 from peft import LoraConfig, get_peft_model_state_dict
 from tqdm import tqdm
 
+import wandb
 from finetrainers import data, logging, optimizer, parallel, patches, utils
-from finetrainers.config import TrainingType
+
 from finetrainers.patches import load_lora_weights
 from finetrainers.state import State, TrainState
 
@@ -124,6 +124,8 @@ def _prepare_trainable_parameters(self) -> None:
         parallel_backend = self.state.parallel_backend
         model_spec = self.model_specification
 
+        from finetrainers.config import TrainingType
+
         if self.args.training_type == TrainingType.CONTROL_FULL_FINETUNE:
             logger.info("Finetuning transformer with no additional parameters")
             utils.set_requires_grad([self.transformer], True)
@@ -335,6 +337,7 @@ def _prepare_dataset(self) -> None:
 
     def _prepare_checkpointing(self) -> None:
         parallel_backend = self.state.parallel_backend
+        from finetrainers.config import TrainingType
 
         def save_model_hook(state_dict: Dict[str, Any]) -> None:
             state_dict = utils.get_unwrapped_model_state_dict(state_dict)
@@ -910,6 +913,7 @@ def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
             )
         else:
             self._delete_components()
+            from finetrainers.config import TrainingType
 
         # TODO(aryan): allow multiple control conditions instead of just one if there's a use case for it
         new_in_features = self.model_specification._original_control_layer_in_features * 2
diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py
index f9892e36..32fd6b40 100644
--- a/finetrainers/trainer/sft_trainer/trainer.py
+++ b/finetrainers/trainer/sft_trainer/trainer.py
@@ -19,7 +19,6 @@
 from tqdm import tqdm
 
 from finetrainers import data, logging, optimizer, parallel, patches, utils
-from finetrainers.config import TrainingType
 from finetrainers.state import State, TrainState
 
 from .config import SFTFullRankConfig, SFTLowRankConfig
@@ -112,6 +111,7 @@ def _prepare_trainable_parameters(self) -> None:
         logger.info("Initializing trainable parameters")
 
         parallel_backend = self.state.parallel_backend
+        from finetrainers.config import TrainingType
 
         if self.args.training_type == TrainingType.FULL_FINETUNE:
             logger.info("Finetuning transformer with no additional parameters")
@@ -297,6 +297,7 @@ def _prepare_dataset(self) -> None:
 
     def _prepare_checkpointing(self) -> None:
         parallel_backend = self.state.parallel_backend
+        from finetrainers.config import TrainingType
 
         def save_model_hook(state_dict: Dict[str, Any]) -> None:
             state_dict = utils.get_unwrapped_model_state_dict(state_dict)
@@ -842,6 +843,7 @@ def _init_pipeline(self, final_validation: bool = False) -> DiffusionPipeline:
             )
         else:
             self._delete_components()
+            from finetrainers.config import TrainingType
 
         # Load the transformer weights from the final checkpoint if performing full-finetune
         transformer = None
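The `validation` method in the new HunyuanVideo control specification relies on `control_channel_concat` from `finetrainers.patches` to splice the control latents into the pipeline's transformer call. The real patch lives in `finetrainers/patches/dependencies/diffusers/control.py`; the sketch below is only an illustrative stand-in for the idea of temporarily wrapping `forward` so a named input gets a control tensor concatenated along a given dimension:

```python
import contextlib

import torch


@contextlib.contextmanager
def channel_concat_patch(module: torch.nn.Module, input_names, tensors, dims):
    """Illustrative stand-in for `control_channel_concat`: while the context is active,
    each named keyword input to `module.forward` has its control tensor concatenated
    along the corresponding dimension."""
    original_forward = module.forward

    def patched_forward(*args, **kwargs):
        for name, tensor, dim in zip(input_names, tensors, dims):
            if name in kwargs:
                kwargs[name] = torch.cat([kwargs[name], tensor.to(kwargs[name])], dim=dim)
        return original_forward(*args, **kwargs)

    module.forward = patched_forward
    try:
        yield
    finally:
        module.forward = original_forward
```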