Commit 226c98d

Fix Layerwise Casting (#316)
* update
* patch
1 parent afc1179 commit 226c98d

File tree: 11 files changed, +122 -33 lines

finetrainers/patches/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -17,7 +17,12 @@ def perform_patches_for_training(args: "BaseArgs", parallel_backend: "ParallelBa
     if parallel_backend.tensor_parallel_enabled:
         patch.patch_apply_rotary_emb_for_tp_compatibility()
 
+    if args.model_name == ModelType.WAN:
+        from .models.wan import patch
+
+        patch.patch_time_text_image_embedding_forward()
+
     if args.training_type == TrainingType.LORA and len(args.layerwise_upcasting_modules) > 0:
-        from dependencies.peft import patch
+        from .dependencies.peft import patch
 
         patch.patch_peft_move_adapter_to_device_of_base_layer()
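
Both changes feed the layerwise-casting fix: Wan models now get a patched `WanTimeTextImageEmbedding.forward` (the new module below), and the PEFT patch import becomes package-relative, since the absolute `from dependencies.peft import patch` only resolves if a top-level `dependencies` package happens to be importable. As a self-contained sketch of the dispatch pattern used here (the dataclass and prints are illustrative stand-ins, not finetrainers APIs):

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class DummyArgs:
    model_name: str = "wan"
    training_type: str = "lora"
    layerwise_upcasting_modules: List[str] = field(default_factory=lambda: ["transformer"])


def perform_patches_for_training(args: DummyArgs) -> None:
    # Patches are applied once, before the model is instantiated, keyed on the
    # model family and the training configuration.
    if args.model_name == "wan":
        print("patch WanTimeTextImageEmbedding.forward")
    if args.training_type == "lora" and len(args.layerwise_upcasting_modules) > 0:
        print("patch PEFT adapter device/dtype handling")


perform_patches_for_training(DummyArgs())
```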

finetrainers/patches/models/ltx_video/patch.py

Lines changed: 2 additions & 2 deletions
@@ -16,7 +16,7 @@ def patch_apply_rotary_emb_for_tp_compatibility() -> None:
 
 
 def _perform_ltx_transformer_forward_patch() -> None:
-    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3Dforward
+    LTXVideoTransformer3DModel.forward = _patched_LTXVideoTransformer3D_forward
 
 
 def _perform_ltx_apply_rotary_emb_tensor_parallel_compatibility_patch() -> None:
@@ -35,7 +35,7 @@ def apply_rotary_emb(x, freqs):
     diffusers.models.transformers.transformer_ltx.apply_rotary_emb = apply_rotary_emb
 
 
-def _patched_LTXVideoTransformer3Dforward(
+def _patched_LTXVideoTransformer3D_forward(
     self,
     hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,

finetrainers/patches/models/wan/patch.py

Lines changed: 33 additions & 0 deletions

@@ -0,0 +1,33 @@
+from typing import Optional
+
+import diffusers
+import torch
+
+
+def patch_time_text_image_embedding_forward() -> None:
+    _patch_time_text_image_embedding_forward()
+
+
+def _patch_time_text_image_embedding_forward() -> None:
+    diffusers.models.transformers.transformer_wan.WanTimeTextImageEmbedding.forward = (
+        _patched_WanTimeTextImageEmbedding_forward
+    )
+
+
+def _patched_WanTimeTextImageEmbedding_forward(
+    self,
+    timestep: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    encoder_hidden_states_image: Optional[torch.Tensor] = None,
+):
+    # Some code has been removed compared to original implementation in Diffusers
+    # Also, timestep is typed as that of encoder_hidden_states
+    timestep = self.timesteps_proj(timestep).type_as(encoder_hidden_states)
+    temb = self.time_embedder(timestep).type_as(encoder_hidden_states)
+    timestep_proj = self.time_proj(self.act_fn(temb))
+
+    encoder_hidden_states = self.text_embedder(encoder_hidden_states)
+    if encoder_hidden_states_image is not None:
+        encoder_hidden_states_image = self.image_embedder(encoder_hidden_states_image)
+
+    return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image
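
This new forward is the substance of the layerwise-casting fix for Wan. The inline comments say some dtype handling from the stock Diffusers implementation was removed and that `timestep` is now typed after `encoder_hidden_states`. A plausible reading (my inference, consistent with the commit title): under layerwise casting, submodule weights are stored in a low-precision storage dtype and upcast only inside their own forward, so any logic that derives a target dtype from stored parameters picks the storage dtype rather than the compute dtype, whereas `encoder_hidden_states` reliably carries the compute dtype. A standalone illustration of keeping the timestep branch in the compute dtype (the projection function and shapes are stand-ins, not the Wan modules):

```python
import torch
import torch.nn as nn

compute_dtype = torch.bfloat16


def timesteps_proj(t: torch.Tensor) -> torch.Tensor:
    # Stand-in for a sinusoidal timestep projection: emits float32 regardless of
    # the dtype the rest of the model computes in.
    freqs = torch.exp(-torch.arange(8, dtype=torch.float32))
    return torch.cat([torch.sin(t[:, None] * freqs), torch.cos(t[:, None] * freqs)], dim=-1)


time_embedder = nn.Linear(16, 32).to(compute_dtype)                 # weights held in bf16
encoder_hidden_states = torch.randn(1, 4, 32, dtype=compute_dtype)  # carries the compute dtype

proj = timesteps_proj(torch.tensor([500.0]))
print(proj.dtype)  # torch.float32

# Feeding float32 activations into bf16 weights raises a dtype-mismatch error;
# .type_as(encoder_hidden_states) keeps the whole branch in the compute dtype.
temb = time_embedder(proj.type_as(encoder_hidden_states))
print(temb.dtype)  # torch.bfloat16
```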

finetrainers/trainer/sft_trainer/trainer.py

Lines changed: 3 additions & 2 deletions
@@ -334,6 +334,7 @@ def _train(self) -> None:
         parallel_backend = self.state.parallel_backend
         train_state = self.state.train_state
         device = parallel_backend.device
+        dtype = self.args.transformer_dtype
 
         memory_statistics = utils.get_memory_statistics()
         logger.info(f"Memory before training start: {json.dumps(memory_statistics, indent=4)}")
@@ -447,8 +448,8 @@ def _train(self) -> None:
 
             logger.debug(f"Starting training step ({train_state.step}/{self.args.train_steps})")
 
-            utils.align_device_and_dtype(latent_model_conditions, device, self.args.transformer_dtype)
-            utils.align_device_and_dtype(condition_model_conditions, device, self.args.transformer_dtype)
+            latent_model_conditions = utils.align_device_and_dtype(latent_model_conditions, device, dtype)
+            condition_model_conditions = utils.align_device_and_dtype(condition_model_conditions, device, dtype)
             latent_model_conditions = utils.make_contiguous(latent_model_conditions)
             condition_model_conditions = utils.make_contiguous(condition_model_conditions)
 
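
The second hunk is the actual trainer bug: `utils.align_device_and_dtype` was called but its results were discarded, so the conditioning dicts never ended up on the right device or in `transformer_dtype`. The real helper lives in finetrainers' utils; the hypothetical minimal version below exists only to show why the return value has to be reassigned. The `dtype = self.args.transformer_dtype` alias added in the first hunk simply shortens the two call sites.

```python
from typing import Any, Dict

import torch


def align_device_and_dtype(inputs: Dict[str, Any], device: torch.device, dtype: torch.dtype) -> Dict[str, Any]:
    # Illustrative re-implementation: returns a *new* mapping rather than mutating
    # `inputs` in place, so callers must keep the return value.
    aligned = {}
    for key, value in inputs.items():
        if torch.is_tensor(value) and torch.is_floating_point(value):
            aligned[key] = value.to(device=device, dtype=dtype)
        elif torch.is_tensor(value):
            aligned[key] = value.to(device=device)
        else:
            aligned[key] = value
    return aligned


latent_model_conditions = {"latents": torch.randn(1, 4, 8, 8), "num_frames": 9}
latent_model_conditions = align_device_and_dtype(latent_model_conditions, torch.device("cpu"), torch.bfloat16)
print(latent_model_conditions["latents"].dtype)  # bfloat16 only because the result was assigned back
```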

tests/README.md

Lines changed: 2 additions & 0 deletions
@@ -7,10 +7,12 @@ TODO(aryan): everything here needs to be improved.
 ```
 # world_size=1 tests
 torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___layerwise_upcasting___dp_degree_1___batch_size_1
 torchrun --nnodes=1 --nproc_per_node 1 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_1___batch_size_2
 
 # world_size=2 tests
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_1
+torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___layerwise_upcasting___dp_degree_2___batch_size_1
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_degree_2___batch_size_2
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_1
 torchrun --nnodes=1 --nproc_per_node 2 -m pytest -s tests/trainer/test_sft_trainer.py -k test___dp_shards_2___batch_size_2

tests/models/cogvideox/base_specification.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def load_condition_models(self):
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+        )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}
 
@@ -44,6 +46,10 @@ def load_latent_models(self):
             norm_num_groups=2,
             temporal_compression_ratio=4,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}
 
     def load_diffusion_models(self):
@@ -64,6 +70,9 @@ def load_diffusion_models(self):
             max_text_seq_length=16,
             use_rotary_positional_embeddings=True,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         self.transformer_config = transformer.config
         scheduler = CogVideoXDDIMScheduler()
         return {"transformer": transformer, "scheduler": scheduler}

tests/models/cogview4/base_specification.py

Lines changed: 11 additions & 24 deletions
@@ -3,7 +3,7 @@
 
 import torch
 from diffusers import AutoencoderKL, CogView4Transformer2DModel, FlowMatchEulerDiscreteScheduler
-from transformers import AutoTokenizer, GlmConfig, GlmModel
+from transformers import AutoTokenizer, GlmModel
 
 
 project_root = pathlib.Path(__file__).resolve().parents[2]
@@ -17,39 +17,26 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def load_condition_models(self):
-        text_encoder_config = GlmConfig(
-            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
+        text_encoder = GlmModel.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="text_encoder", torch_dtype=self.text_encoder_dtype
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="tokenizer", trust_remote_code=True
         )
-        text_encoder = GlmModel(text_encoder_config)
-        # TODO(aryan): try to not rely on trust_remote_code by creating dummy tokenizer
-        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}
 
     def load_latent_models(self):
         torch.manual_seed(0)
-        vae = AutoencoderKL(
-            block_out_channels=[32, 64],
-            in_channels=3,
-            out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
-            latent_channels=4,
-            sample_size=128,
+        vae = AutoencoderKL.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="vae", torch_dtype=self.vae_dtype
         )
+        self.vae_config = vae.config
         return {"vae": vae}
 
     def load_diffusion_models(self):
         torch.manual_seed(0)
-        transformer = CogView4Transformer2DModel(
-            patch_size=2,
-            in_channels=4,
-            num_layers=2,
-            attention_head_dim=4,
-            num_attention_heads=4,
-            out_channels=4,
-            text_embed_dim=32,
-            time_embed_dim=8,
-            condition_dim=4,
+        transformer = CogView4Transformer2DModel.from_pretrained(
+            "hf-internal-testing/tiny-random-cogview4", subfolder="transformer", torch_dtype=self.transformer_dtype
         )
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}

tests/models/hunyuan_video/base_specification.py

Lines changed: 10 additions & 0 deletions
@@ -59,6 +59,9 @@ def load_condition_models(self):
         text_encoder_2 = CLIPTextModel(clip_text_encoder_config)
         tokenizer_2 = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
+        text_encoder.to(self.text_encoder_dtype)
+        text_encoder_2.to(self.text_encoder_2_dtype)
+
         return {
             "tokenizer": tokenizer,
             "tokenizer_2": tokenizer_2,
@@ -93,6 +96,10 @@ def load_latent_models(self):
             temporal_compression_ratio=4,
             mid_block_add_attention=True,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}
 
     def load_diffusion_models(self):
@@ -112,5 +119,8 @@ def load_diffusion_models(self):
             pooled_projection_dim=8,
             rope_axes_dim=(2, 4, 4),
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}

tests/models/ltx_video/base_specification.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def load_condition_models(self):
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+        )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}
 
@@ -42,6 +44,10 @@ def load_latent_models(self):
             encoder_causal=True,
             decoder_causal=False,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}
 
     def load_diffusion_models(self):
@@ -57,5 +63,8 @@ def load_diffusion_models(self):
             num_layers=1,
             caption_channels=32,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}

tests/models/wan/base_specification.py

Lines changed: 10 additions & 1 deletion
@@ -17,7 +17,9 @@ def __init__(self, **kwargs):
         super().__init__(**kwargs)
 
     def load_condition_models(self):
-        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
+        text_encoder = T5EncoderModel.from_pretrained(
+            "hf-internal-testing/tiny-random-t5", torch_dtype=self.text_encoder_dtype
+        )
         tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
         return {"text_encoder": text_encoder, "tokenizer": tokenizer}
 
@@ -30,6 +32,10 @@ def load_latent_models(self):
             num_res_blocks=1,
             temperal_downsample=[False, True, True],
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        vae.to(self.vae_dtype)
+        self.vae_config = vae.config
         return {"vae": vae}
 
     def load_diffusion_models(self):
@@ -48,5 +54,8 @@ def load_diffusion_models(self):
             qk_norm="rms_norm_across_heads",
             rope_max_seq_len=32,
         )
+        # TODO(aryan): Upload dummy checkpoints to the Hub so that we don't have to do this.
+        # Doing so overrides things like _keep_in_fp32_modules
+        transformer.to(self.transformer_dtype)
         scheduler = FlowMatchEulerDiscreteScheduler()
         return {"transformer": transformer, "scheduler": scheduler}
