
Commit 4450b1c

Merge branch 'main' into layerwise-upcasting-hook
2 parents: b713511 + a1f9a71

File tree

3 files changed: +24 / -39 lines

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 17 additions & 9 deletions
@@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
@@ -221,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
     _precision_sensitive_module_patterns = ["patch_embed", "norm"]
 
     @register_to_config
@@ -289,8 +303,7 @@ def __init__(
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
@@ -463,13 +476,8 @@ def custom_forward(*inputs):
                 )
 
         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
 
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
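
For reference, a minimal standalone sketch of the modulation that the new SanaModulatedNorm.forward performs. The tensor shapes below are illustrative assumptions, not values taken from the model config.

    import torch
    import torch.nn as nn

    # Hypothetical shapes for illustration: batch of 2, 16 tokens, hidden size 32.
    batch, seq_len, inner_dim = 2, 16, 32

    norm = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
    scale_shift_table = torch.randn(2, inner_dim) / inner_dim**0.5  # same init as __init__ above
    hidden_states = torch.randn(batch, seq_len, inner_dim)
    embedded_timestep = torch.randn(batch, inner_dim)  # per-sample timestep embedding

    # Same computation as SanaModulatedNorm.forward:
    normed = norm(hidden_states)
    # (1, 2, D) + (B, 1, D) -> (B, 2, D), then split into shift and scale of shape (B, 1, D)
    shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
    out = normed * (1 + scale) + shift
    print(out.shape)  # torch.Size([2, 16, 32])

Folding the shift/scale modulation into a dedicated module is what lets it appear in _no_split_modules and match the "norm" pattern in _precision_sensitive_module_patterns, instead of living as loose tensor ops in the transformer's forward.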

tests/models/test_modeling_common.py

Lines changed: 6 additions & 5 deletions
@@ -31,7 +31,7 @@
 import requests_mock
 import torch
 import torch.nn as nn
-from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
+from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
@@ -1094,7 +1094,7 @@ def test_cpu_offload(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1124,7 +1124,7 @@ def test_disk_offload_without_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
@@ -1158,7 +1158,7 @@ def test_disk_offload_with_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
 
@@ -1186,7 +1186,7 @@ def test_model_parallelism(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1197,6 +1197,7 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
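
For context, accelerate's compute_module_sizes walks a model's parameters and buffers and returns a mapping from submodule name to size in bytes, with the empty key "" holding the total. A small sketch with an arbitrary toy model (not from this repository):

    import torch.nn as nn
    from accelerate.utils.modeling import compute_module_sizes

    # Arbitrary toy model, only to show the shape of the returned mapping.
    model = nn.Sequential(nn.Linear(8, 16), nn.LayerNorm(16), nn.Linear(16, 4))

    sizes = compute_module_sizes(model)
    model_size = sizes[""]  # total size of the model in bytes
    print(model_size, sizes["0"], sizes["2"])  # whole model, first Linear, last Linear

Unlike compute_module_persistent_sizes, which it replaces here, this total also counts non-persistent buffers, which is why the tests above recompute their per-device budgets from it.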

tests/models/transformers/test_models_transformer_sana.py

Lines changed: 1 addition & 25 deletions
@@ -14,7 +14,6 @@
 
 import unittest
 
-import pytest
 import torch
 
 from diffusers import SanaTransformer2DModel
@@ -33,6 +32,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.9]
 
     @property
     def dummy_input(self):
@@ -81,27 +81,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SanaTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_cpu_offload(self):
-        return super().test_cpu_offload()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_with_safetensors(self):
-        return super().test_disk_offload_with_safetensors()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_without_safetensors(self):
-        return super().test_disk_offload_without_safetensors()
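
For context on the new model_split_percents attribute: the common offload and parallelism tests shown above skip its first entry and use each remaining fraction of the total model size as a GPU memory cap, so part of the model is forced onto a second device or the CPU. A minimal sketch with made-up numbers:

    # Illustrative numbers only; model_size would come from compute_module_sizes(model)[""].
    model_size = 1_000_000
    model_split_percents = [0.7, 0.7, 0.9]  # the new Sana test attribute

    max_gpu_sizes = [int(p * model_size) for p in model_split_percents[1:]]
    print(max_gpu_sizes)  # [700000, 900000]

With these larger split fractions the previously xfail-marked CUDA offload tests pass, so the overrides are dropped and pytest is no longer needed in this file.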
