7 changes: 6 additions & 1 deletion finetrainers/patches/__init__.py
@@ -17,7 +17,12 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
if parallel_backend.tensor_parallel_enabled:
patch.patch_apply_rotary_emb_for_tp_compatibility()

if args.model_name == ModelType.WAN:
from .models.wan import patch

patch.patch_time_text_image_embedding_forward()

if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
from dependencies.peft import patch
from .dependencies.peft import patch

patch.patch_peft_move_adapter_to_device_of_base_layer()
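As a quick, illustrative check (not part of the PR) that the new WAN branch takes effect, the patch added below in `finetrainers/patches/models/wan/patch.py` can be applied and inspected directly; this assumes finetrainers and a diffusers version that ships `transformer_wan` are installed:

```python
# Illustrative check only; not part of the PR.
import diffusers

from finetrainers.patches.models.wan import patch as wan_patch

wan_patch.patch_time_text_image_embedding_forward()
patched_forward = diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward
print(patched_forward.__name__)  # expected: _patched_WanTimeTextImageEmbedding_forward (defined in wan/patch.py below)
```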
4 changes: 2 additions & 2 deletions finetrainers/patches/models/ltx_video/patch.py
@@ -16,7 +16,7 @@ def patch_apply_rotary_emb_for_tp_compatibility() -> None:


def _perform_ltx_transformer_forward_patch() -> None:
LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward


def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
@@ -35,7 +35,7 @@ def apply_rotary_emb(x, freqs):
diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb


def _patched_LTXVideoTransformer3Dforward(
def _patched_LTXVideoTransformer3D_forward(
self,
hidden_states: torch.Tensor,
encoder_hidden_states: torch.Tensor,
33 changes: 33 additions & 0 deletions finetrainers/patches/models/wan/patch.py
@@ -0,0 +1,33 @@
from typing import Optional

import diffusers
import torch


def patch_time_text_image_embedding_forward() -> None:
_patch_time_text_image_embedding_forward()


def _patch_time_text_image_embedding_forward() -> None:
diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
_patched_WanTimeTextImageEmbedding_forward
)


def _patched_WanTimeTextImageEmbedding_forward(
self,
timestep: torch.Tensor,
encoder_hidden_states: torch.Tensor,
encoder_hidden_states_image: Optional[torch.Tensor] = None,
):
# Some code has been removed compared to the original implementation in Diffusers.
# Also, timestep is cast to the same dtype as encoder_hidden_states.
timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
timestep_proj = self.time_proj(self.act_fn(temb))

encoder_hidden_states = self.text_embedder(encoder_hidden_states)
if encoder_hidden_states_image is not None:
encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)

return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
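The detail worth noting in this new patch is `type_as`: the timestep embeddings follow the dtype of `encoder_hidden_states` instead of a fixed dtype, presumably to keep dtypes consistent under reduced-precision setups such as the layerwise upcasting exercised by the new tests. A minimal standalone illustration of that cast (plain PyTorch, not finetrainers code):

```python
# Standalone illustration of Tensor.type_as(); not finetrainers code.
import torch

timestep = torch.arange(4, dtype=torch.float32)
encoder_hidden_states = torch.randn(1, 77, 16, dtype=torch.bfloat16)

# type_as casts the tensor to the dtype of its argument.
timestep = timestep.type_as(encoder_hidden_states)
print(timestep.dtype)  # torch.bfloat16
```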
5 changes: 3 additions & 2 deletions finetrainers/trainer/sft_trainer/trainer.py
@@ -334,6 +334,7 @@ def _train(self) -> None:
parallel_backend = self.state.parallel_backend
train_state = self.state.train_state
device = parallel_backend.device
dtype = self.args.transformer_dtype

memory_statistics = utils.get_memory_statistics()
logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
@@ -447,8 +448,8 @@ def _train(self) -> None:

logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")

utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype)
utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype)
latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
latent_model_conditions = utils.make_contiguous(latent_model_conditions)
condition_model_conditions = utils.make_contiguous(condition_model_conditions)
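The fix above is about capturing the return value: the old code called `utils.align_device_and_dtype` and discarded the result. A hedged sketch of what a helper with that name plausibly does, to show why the assignment matters (an assumed shape, not the actual finetrainers implementation):

```python
# Assumed sketch of a helper like utils.align_device_and_dtype; not the actual implementation.
from typing import Any, Dict

import torch


def align_device_and_dtype(conditions: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> Dict[str, Any]:
    aligned = {}
    for key, value in conditions.items():
        if torch.is_tensor(value):
            # Cast floating-point tensors to the training dtype; everything else only moves device.
            value = value.to(device=device, dtype=dtype) if value.is_floating_point() else value.to(device=device)
        aligned[key] = value
    return aligned  # returns a new dict, so callers must use the return value
```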

2 changes: 2 additions & 0 deletions tests/README.md
@@ -7,10 +7,12 @@ TODO(aryan): everything here needs to be improved.
```
# world_size=1 tests
torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_1
torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___layerwise_upcasting___dp_degree_1___batch_size_1
torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_2

# world_size=2 tests
torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_1
torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___layerwise_upcasting___dp_degree_2___batch_size_1
torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_2
torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_1
torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_2
11 changes: 10 additions & 1 deletion tests/models/cogvideox/base_specification.py
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

def load_condition_models(self):
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = T5EncoderModel.from_pretrained(
"hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
return {"text_encoder": text_encoder, "tokenizer": tokenizer}

@@ -44,6 +46,10 @@ def load_latent_models(self):
norm_num_groups=2,
temporal_compression_ratio=4,
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
vae.to(self.vae_dtype)
self.vae_config = vae.config
return {"vae": vae}

def load_diffusion_models(self):
@@ -64,6 +70,9 @@ def load_diffusion_models(self):
max_text_seq_length=16,
use_rotary_positional_embeddings=True,
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
transformer.to(self.transformer_dtype)
self.transformer_config = transformer.config
scheduler = CogVideoXDDIMScheduler()
return {"transformer": transformer, "scheduler": scheduler}
35 changes: 11 additions & 24 deletions tests/models/cogview4/base_specification.py
@@ -3,7 +3,7 @@

import torch
from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
from transformers import AutoTokenizer, GlmConfig, GlmModel
from transformers import AutoTokenizer, GlmModel


project_root = pathlib.Path(__file__).resolve().parents[2]
@@ -17,39 +17,26 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

def load_condition_models(self):
text_encoder_config = GlmConfig(
hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
text_encoder = GlmModel.from_pretrained(
"hf-internal-testing/tiny-random-cogview4", subfolder="text_encoder", torch_dtype=self.text_encoder_dtype
)
tokenizer = AutoTokenizer.from_pretrained(
"hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True
)
text_encoder = GlmModel(text_encoder_config)
# TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
return {"text_encoder": text_encoder, "tokenizer": tokenizer}

def load_latent_models(self):
torch.manual_seed(0)
vae = AutoencoderKL(
block_out_channels=[32, 64],
in_channels=3,
out_channels=3,
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
latent_channels=4,
sample_size=128,
vae = AutoencoderKL.from_pretrained(
"hf-internal-testing/tiny-random-cogview4", subfolder="vae", torch_dtype=self.vae_dtype
)
self.vae_config = vae.config
return {"vae": vae}

def load_diffusion_models(self):
torch.manual_seed(0)
transformer = CogView4Transformer2DModel(
patch_size=2,
in_channels=4,
num_layers=2,
attention_head_dim=4,
num_attention_heads=4,
out_channels=4,
text_embed_dim=32,
time_embed_dim=8,
condition_dim=4,
transformer = CogView4Transformer2DModel.from_pretrained(
"hf-internal-testing/tiny-random-cogview4", subfolder="transformer", torch_dtype=self.transformer_dtype
)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}
10 changes: 10 additions & 0 deletions tests/models/hunyuan_video/base_specification.py
@@ -59,6 +59,9 @@ def load_condition_models(self):
text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

text_encoder.to(self.text_encoder_dtype)
text_encoder_2.to(self.text_encoder_2_dtype)

return {
"tokenizer": tokenizer,
"tokenizer_2": tokenizer_2,
@@ -93,6 +96,10 @@ def load_latent_models(self):
temporal_compression_ratio=4,
mid_block_add_attention=True,
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
vae.to(self.vae_dtype)
self.vae_config = vae.config
return {"vae": vae}

def load_diffusion_models(self):
Expand All @@ -112,5 +119,8 @@ def load_diffusion_models(self):
pooled_projection_dim=8,
rope_axes_dim=(2, 4, 4),
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
transformer.to(self.transformer_dtype)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}
11 changes: 10 additions & 1 deletion tests/models/ltx_video/base_specification.py
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

def load_condition_models(self):
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = T5EncoderModel.from_pretrained(
"hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
return {"text_encoder": text_encoder, "tokenizer": tokenizer}

@@ -42,6 +44,10 @@ def load_latent_models(self):
encoder_causal=True,
decoder_causal=False,
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
vae.to(self.vae_dtype)
self.vae_config = vae.config
return {"vae": vae}

def load_diffusion_models(self):
@@ -57,5 +63,8 @@ def load_diffusion_models(self):
num_layers=1,
caption_channels=32,
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
transformer.to(self.transformer_dtype)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}
11 changes: 10 additions & 1 deletion tests/models/wan/base_specification.py
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
super().__init__(**kwargs)

def load_condition_models(self):
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
text_encoder = T5EncoderModel.from_pretrained(
"hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
)
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
return {"text_encoder": text_encoder, "tokenizer": tokenizer}

@@ -30,6 +32,10 @@ def load_latent_models(self):
num_res_blocks=1,
temperal_downsample=[False, True, True],
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
vae.to(self.vae_dtype)
self.vae_config = vae.config
return {"vae": vae}

def load_diffusion_models(self):
@@ -48,5 +54,8 @@ def load_diffusion_models(self):
qk_norm="rms_norm_across_heads",
rope_max_seq_len=32,
)
# TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
# Doing so overrides things like _keep_in_fp32_modules
transformer.to(self.transformer_dtype)
scheduler = FlowMatchEulerDiscreteScheduler()
return {"transformer": transformer, "scheduler": scheduler}
26 changes: 25 additions & 1 deletion tests/trainer/test_sft_trainer.py
@@ -8,6 +8,7 @@
import unittest

import pytest
import torch
from diffusers.utils import export_to_video
from parameterized import parameterized
from PIL import Image
@@ -34,7 +35,7 @@ def slow_down_tests():
# Sleep between each test so that process groups are cleaned and resources are released.
# Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually.
# !!!Look into this in future!!!
time.sleep(3)
time.sleep(5)


class SFTTrainerFastTestsMixin:
@@ -79,6 +80,11 @@ def setUp(self):

def tearDown(self):
self.tmpdir.cleanup()
# For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually
# make sure to destroy it here.
if torch.distributed.is_initialized():
torch.distributed.destroy_process_group()
time.sleep(3)

def get_base_args(self) -> BaseArgs:
args = BaseArgs()
@@ -121,6 +127,15 @@ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
args.enable_precomputation = enable_precomputation
self._test_training(args)

@parameterized.expand([(False,), (True,)])
def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 1
args.batch_size = 1
args.enable_precomputation = enable_precomputation
args.layerwise_upcasting_modules = ["transformer"]
self._test_training(args)

@parameterized.expand([(False,), (True,)])
def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
@@ -137,6 +152,15 @@ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
args.enable_precomputation = enable_precomputation
self._test_training(args)

@parameterized.expand([(False,), (True,)])
def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
args = self.get_args()
args.dp_degree = 2
args.batch_size = 1
args.enable_precomputation = enable_precomputation
args.layerwise_upcasting_modules = ["transformer"]
self._test_training(args)

@parameterized.expand([(False,), (True,)])
def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
args = self.get_args()
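For readers unfamiliar with the feature the new `test___layerwise_upcasting___*` cases exercise: setting `args.layerwise_upcasting_modules = ["transformer"]` keeps the transformer's weights in a low-precision storage dtype and upcasts them around each submodule's forward pass. A rough, self-contained sketch of that idea (hook-based and illustrative only, not the finetrainers implementation; float16 stands in for the storage dtype to keep the example portable, whereas a float8 dtype is typical in practice):

```python
# Rough, illustrative sketch of layerwise upcasting; not the finetrainers code.
import torch
import torch.nn as nn

storage_dtype = torch.float16   # a float8 dtype in practice
compute_dtype = torch.bfloat16


def apply_layerwise_upcasting_sketch(module: nn.Module) -> None:
    def upcast(mod, args):
        # Bring weights up to the compute dtype just before the forward pass.
        mod.to(compute_dtype)

    def downcast(mod, args, output):
        # Drop weights back to the storage dtype once the forward pass is done.
        mod.to(storage_dtype)

    module.to(storage_dtype)
    module.register_forward_pre_hook(upcast)
    module.register_forward_hook(downcast)


layer = nn.Linear(8, 8)
apply_layerwise_upcasting_sketch(layer)
out = layer(torch.randn(2, 8, dtype=compute_dtype))
print(layer.weight.dtype, out.dtype)  # torch.float16 torch.bfloat16
```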