diff --git a/finetrainers/patches/__init__.py b/finetrainers/patches/__init__.py
index 0d499f4d..1deb0729 100644
--- a/finetrainers/patches/__init__.py
+++ b/finetrainers/patches/__init__.py
@@ -17,7 +17,12 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
         if parallel_backend.tensor_parallel_enabled:
             patch.patch_apply_rotary_emb_for_tp_compatibility()

+    if args.model_name == ModelType.WAN:
+        from .models.wan import patch
+
+        patch.patch_time_text_image_embedding_forward()
+
     if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
-        from dependencies.peft import patch
+        from .dependencies.peft import patch

         patch.patch_peft_move_adapter_to_device_of_base_layer()
diff --git a/finetrainers/patches/models/ltx_video/patch.py b/finetrainers/patches/models/ltx_video/patch.py
index 851da6e7..9e8caa80 100644
--- a/finetrainers/patches/models/ltx_video/patch.py
+++ b/finetrainers/patches/models/ltx_video/patch.py
@@ -16,7 +16,7 @@ def patch_apply_rotary_emb_for_tp_compatibility() -> None:


 def _perform_ltx_transformer_forward_patch() -> None:
-    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
+    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward


 def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
@@ -35,7 +35,7 @@ def apply_rotary_emb(x, freqs):
     diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb


-def _patched_LTXVideoTransformer3Dforward(
+def _patched_LTXVideoTransformer3D_forward(
     self,
     hidden_states: torch.Tensor,
     encoder_hidden_states: torch.Tensor,
diff --git a/finetrainers/patches/models/wan/patch.py b/finetrainers/patches/models/wan/patch.py
new file mode 100644
index 00000000..e5c44ae4
--- /dev/null
+++ b/finetrainers/patches/models/wan/patch.py
@@ -0,0 +1,33 @@
+from typing import Optional
+
+import diffusers
+import torch
+
+
+def patch_time_text_image_embedding_forward() -> None:
+    _patch_time_text_image_embedding_forward()
+
+
+def _patch_time_text_image_embedding_forward() -> None:
+    diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
+        _patched_WanTimeTextImageEmbedding_forward
+    )
+
+
+def _patched_WanTimeTextImageEmbedding_forward(
+    self,
+    timestep: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states_image: Optional[torch.Tensor] = None,
+):
+    # Some code has been removed compared to original implementation in Diffusers
+    # Also, timestep is typed as that of encoder_hidden_states
+    timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
+    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+    timestep_proj = self.time_proj(self.act_fn(temb))
+
+    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+    if encoder_hidden_states_image is not None:
+        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
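The new Wan patch above trims the original `WanTimeTextImageEmbedding.forward` down to a version that casts the time-embedding path to the dtype of `encoder_hidden_states` via `type_as`. A minimal, self-contained sketch of that dtype-alignment idea (illustrative module, shapes, and dtypes only; this is not finetrainers code):

```python
import torch
import torch.nn as nn

# Hypothetical setup: the time-embedding MLP holds float32 parameters while the
# surrounding activations run in bfloat16 (as they might under layerwise upcasting).
time_embedder = nn.Linear(8, 16)
encoder_hidden_states = torch.randn(1, 4, 16, dtype=torch.bfloat16)
timestep = torch.randn(1, 8)  # float32, e.g. a sinusoidal projection

temb = time_embedder(timestep)
print(temb.dtype)  # torch.float32 -> would mix with bfloat16 tensors downstream

# What the patched forward does instead: match the reference tensor's dtype.
temb = time_embedder(timestep).type_as(encoder_hidden_states)
print(temb.dtype)  # torch.bfloat16
```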
diff --git a/finetrainers/trainer/sft_trainer/trainer.py b/finetrainers/trainer/sft_trainer/trainer.py
index 34d2a528..0ec83c81 100644
--- a/finetrainers/trainer/sft_trainer/trainer.py
+++ b/finetrainers/trainer/sft_trainer/trainer.py
@@ -334,6 +334,7 @@ def _train(self) -> None:
         parallel_backend = self.state.parallel_backend
         train_state = self.state.train_state
         device = parallel_backend.device
+        dtype = self.args.transformer_dtype

         memory_statistics = utils.get_memory_statistics()
         logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
@@ -447,8 +448,8 @@ def _train(self) -> None:

             logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")

-            utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype)
-            utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype)
+            latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
+            condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
             latent_model_conditions = utils.make_contiguous(latent_model_conditions)
             condition_model_conditions = utils.make_contiguous(condition_model_conditions)

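The reassignment above fixes the call sites rather than the helper: `utils.align_device_and_dtype` evidently returns converted tensors instead of converting its argument in place, so the old calls discarded the result and the conditions stayed on their original device and dtype. A rough sketch of what such a helper looks like and why the return value must be used (a simplified stand-in, not the actual `finetrainers.utils` implementation):

```python
from typing import Any, Dict, Union

import torch


def align_device_and_dtype(
    x: Union[Dict[str, Any], torch.Tensor], device: torch.device, dtype: torch.dtype
) -> Union[Dict[str, Any], torch.Tensor]:
    # Tensor.to() is not in-place: it returns a new tensor, so a dict of tensors
    # has to be rebuilt and returned to the caller.
    if isinstance(x, torch.Tensor):
        return x.to(device=device, dtype=dtype)
    if isinstance(x, dict):
        return {k: align_device_and_dtype(v, device, dtype) for k, v in x.items()}
    return x


conditions = {"latents": torch.randn(1, 4, 8, 8)}
align_device_and_dtype(conditions, torch.device("cpu"), torch.bfloat16)  # result dropped: no effect
conditions = align_device_and_dtype(conditions, torch.device("cpu"), torch.bfloat16)  # correct usage
```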
diff --git a/tests/README.md b/tests/README.md
index bd189b1b..b3efa278 100644
--- a/tests/README.md
+++ b/tests/README.md
@@ -7,10 +7,12 @@ TODO(aryan): everything here needs to be improved.
 ```
 # world_size=1 tests
 torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___layerwise_upcasting___dp_degree_1___batch_size_1
 torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_2

 # world_size=2 tests
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___layerwise_upcasting___dp_degree_2___batch_size_1
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_2
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_1
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_2
diff --git a/tests/models/cogvideox/base_specification.py b/tests/models/cogvideox/base_specification.py
index 0a66a9fb..338a4c4b 100644
--- a/tests/models/cogvideox/base_specification.py
+++ b/tests/models/cogvideox/base_specification.py
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def load_condition_models(self):
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+        )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}

@@ -44,6 +46,10 @@ def load_latent_models(self):
             norm_num_groups=2,
             temporal_compression_ratio=4,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}

     def load_diffusion_models(self):
@@ -64,6 +70,9 @@ def load_diffusion_models(self):
             max_text_seq_length=16,
             use_rotary_positional_embeddings=True,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         self.transformer_config = transformer.config
         scheduler = CogVideoXDDIMScheduler()
         return {"transformer": transformer, "scheduler": scheduler}
diff --git a/tests/models/cogview4/base_specification.py b/tests/models/cogview4/base_specification.py
index 2cc906e5..90178be4 100644
--- a/tests/models/cogview4/base_specification.py
+++ b/tests/models/cogview4/base_specification.py
@@ -3,7 +3,7 @@
 import torch
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from transformers import AutoTokenizer, GlmConfig, GlmModel
+from transformers import AutoTokenizer, GlmModel


 project_root = pathlib.Path(__file__).resolve().parents[2]

@@ -17,39 +17,26 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def load_condition_models(self):
-        text_encoder_config = GlmConfig(
-            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
+        text_encoder = GlmModel.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="text_encoder", torch_dtype=self.text_encoder_dtype
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True
         )
-        text_encoder = GlmModel(text_encoder_config)
-        # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
-        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}

     def load_latent_models(self):
         torch.manual_seed(0)
-        vae = AutoencoderKL(
-            block_out_channels=[32, 64],
-            in_channels=3,
-            out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            latent_channels=4,
-            sample_size=128,
+        vae = AutoencoderKL.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="vae", torch_dtype=self.vae_dtype
         )
+        self.vae_config = vae.config
         return {"vae": vae}

     def load_diffusion_models(self):
         torch.manual_seed(0)
-        transformer = CogView4Transformer2DModel(
-            patch_size=2,
-            in_channels=4,
-            num_layers=2,
-            attention_head_dim=4,
-            num_attention_heads=4,
-            out_channels=4,
-            text_embed_dim=32,
-            time_embed_dim=8,
-            condition_dim=4,
+        transformer = CogView4Transformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="transformer", torch_dtype=self.transformer_dtype
         )
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}
diff --git a/tests/models/hunyuan_video/base_specification.py b/tests/models/hunyuan_video/base_specification.py
index e76b749e..b35064a2 100644
--- a/tests/models/hunyuan_video/base_specification.py
+++ b/tests/models/hunyuan_video/base_specification.py
@@ -59,6 +59,9 @@ def load_condition_models(self):
         text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
         tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

+        text_encoder.to(self.text_encoder_dtype)
+        text_encoder_2.to(self.text_encoder_2_dtype)
+
         return {
             "tokenizer": tokenizer,
             "tokenizer_2": tokenizer_2,
@@ -93,6 +96,10 @@ def load_latent_models(self):
             temporal_compression_ratio=4,
             mid_block_add_attention=True,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}

     def load_diffusion_models(self):
@@ -112,5 +119,8 @@ def load_diffusion_models(self):
             pooled_projection_dim=8,
             rope_axes_dim=(2, 4, 4),
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}
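The recurring TODO comment in the test specifications points at a real difference in how dtypes are applied: `from_pretrained(..., torch_dtype=...)` goes through the library's own casting logic, which honours hints such as `_keep_in_fp32_modules`, while calling `.to(dtype)` on a freshly constructed model converts every parameter unconditionally. A hedged sketch of that difference (the tiny model below is made up purely for illustration; the attribute only takes effect when `from_pretrained` consults it):

```python
import torch
import torch.nn as nn


class TinyModel(nn.Module):
    # In Diffusers/Transformers, a class attribute like this tells from_pretrained
    # to leave the listed modules in float32 when applying torch_dtype.
    _keep_in_fp32_modules = ["norm"]

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4, 4)
        self.norm = nn.LayerNorm(4)


model = TinyModel()

# A blanket cast ignores the hint and downcasts everything, including `norm`.
model.to(torch.bfloat16)
print(model.norm.weight.dtype)  # torch.bfloat16

# Loading a checkpoint via from_pretrained(torch_dtype=torch.bfloat16) would have
# kept `norm` in float32, which is why the tests prefer pretrained dummy checkpoints
# over constructing a random model and calling .to().
```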
diff --git a/tests/models/ltx_video/base_specification.py b/tests/models/ltx_video/base_specification.py
index e21bebf7..6fc6e689 100644
--- a/tests/models/ltx_video/base_specification.py
+++ b/tests/models/ltx_video/base_specification.py
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def load_condition_models(self):
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+        )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}

@@ -42,6 +44,10 @@ def load_latent_models(self):
             encoder_causal=True,
             decoder_causal=False,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}

     def load_diffusion_models(self):
@@ -57,5 +63,8 @@ def load_diffusion_models(self):
             num_layers=1,
             caption_channels=32,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}
diff --git a/tests/models/wan/base_specification.py b/tests/models/wan/base_specification.py
index cf1fc32e..2f046d35 100644
--- a/tests/models/wan/base_specification.py
+++ b/tests/models/wan/base_specification.py
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)

     def load_condition_models(self):
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+        )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}

@@ -30,6 +32,10 @@ def load_latent_models(self):
             num_res_blocks=1,
             temperal_downsample=[False, True, True],
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}

     def load_diffusion_models(self):
@@ -48,5 +54,8 @@ def load_diffusion_models(self):
             qk_norm="rms_norm_across_heads",
             rope_max_seq_len=32,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}
diff --git a/tests/trainer/test_sft_trainer.py b/tests/trainer/test_sft_trainer.py
index 591cce18..498f52f1 100644
--- a/tests/trainer/test_sft_trainer.py
+++ b/tests/trainer/test_sft_trainer.py
@@ -8,6 +8,7 @@
 import unittest

 import pytest
+import torch
 from diffusers.utils import export_to_video
 from parameterized import parameterized
 from PIL import Image
@@ -34,7 +35,7 @@ def slow_down_tests():
     # Sleep between each test so that process groups are cleaned and resources are released.
     # Not doing so seems to randomly trigger some test failures, which wouldn't fail if run individually.
     # !!!Look into this in future!!!
-    time.sleep(3)
+    time.sleep(5)


 class SFTTrainerFastTestsMixin:
@@ -79,6 +80,11 @@ def setUp(self):

     def tearDown(self):
         self.tmpdir.cleanup()
+        # For some reason, if the process group is not destroyed, the tests that follow will fail. Just manually
+        # make sure to destroy it here.
+        if torch.distributed.is_initialized():
+            torch.distributed.destroy_process_group()
+        time.sleep(3)

     def get_base_args(self) -> BaseArgs:
         args = BaseArgs()
@@ -121,6 +127,15 @@ def test___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
         args.enable_precomputation = enable_precomputation
         self._test_training(args)

+    @parameterized.expand([(False,), (True,)])
+    def test___layerwise_upcasting___dp_degree_1___batch_size_1(self, enable_precomputation: bool):
+        args = self.get_args()
+        args.dp_degree = 1
+        args.batch_size = 1
+        args.enable_precomputation = enable_precomputation
+        args.layerwise_upcasting_modules = ["transformer"]
+        self._test_training(args)
+
     @parameterized.expand([(False,), (True,)])
     def test___dp_degree_1___batch_size_2(self, enable_precomputation: bool):
         args = self.get_args()
@@ -137,6 +152,15 @@ def test___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
         args.enable_precomputation = enable_precomputation
         self._test_training(args)

+    @parameterized.expand([(False,), (True,)])
+    def test___layerwise_upcasting___dp_degree_2___batch_size_1(self, enable_precomputation: bool):
+        args = self.get_args()
+        args.dp_degree = 2
+        args.batch_size = 1
+        args.enable_precomputation = enable_precomputation
+        args.layerwise_upcasting_modules = ["transformer"]
+        self._test_training(args)
+
     @parameterized.expand([(False,), (True,)])
     def test___dp_degree_2___batch_size_2(self, enable_precomputation: bool):
         args = self.get_args()
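The `tearDown` change follows a common pattern in distributed test suites: if one test leaves the default process group initialized, later tests that try to set up their own group can fail in confusing ways. A minimal sketch of that pattern outside finetrainers (single process, `gloo` backend so it runs on CPU; not the actual test harness):

```python
import os
import unittest

import torch
import torch.distributed as dist


class DistributedTestCase(unittest.TestCase):
    def setUp(self):
        # Under torchrun these variables are already set; the defaults let the
        # sketch run as a standalone single process.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        if not dist.is_initialized():
            dist.init_process_group(backend="gloo", rank=0, world_size=1)

    def tearDown(self):
        # Mirror the change above: always destroy the group so the next test
        # starts from a clean slate.
        if dist.is_initialized():
            dist.destroy_process_group()

    def test_all_reduce(self):
        t = torch.ones(1)
        dist.all_reduce(t)
        self.assertEqual(t.item(), 1.0)


if __name__ == "__main__":
    unittest.main()
```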