24 changes: 16 additions & 8 deletions src/diffusers/models/transformers/sana_transformer.py
@@ -82,6 +82,20 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (
+            scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)
+        ).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
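Note that the new module receives scale_shift_table from the caller rather than owning it, so the table stays a parameter of the root model. A minimal shape-level sketch of the call, with toy sizes chosen for illustration (the import path assumes this PR's revision):

import torch
import torch.nn as nn
from diffusers.models.transformers.sana_transformer import SanaModulatedNorm  # post-PR path

dim = 8
norm_out = SanaModulatedNorm(dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = nn.Parameter(torch.randn(2, dim) / dim**0.5)

hidden_states = torch.randn(2, 16, dim)    # (batch, seq_len, dim)
embedded_timestep = torch.randn(2, dim)    # (batch, dim) timestep embedding

# (1, 2, dim) table + (batch, 1, dim) embedding -> (batch, 2, dim), chunked into shift/scale
out = norm_out(hidden_states, embedded_timestep, scale_shift_table)
print(out.shape)  # torch.Size([2, 16, 8])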
@@ -288,8 +302,7 @@ def __init__(
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
@@ -462,13 +475,8 @@ def custom_forward(*inputs):
                 )
 
         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
 
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
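The call-site change is numerics-preserving: SanaModulatedNorm performs the same normalize-then-modulate computation that was previously inlined here, with the separate "# 4. Modulation" step folded into the norm call. A quick self-contained check under assumed toy shapes:

import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
scale_shift_table = torch.randn(2, dim) / dim**0.5
hidden_states = torch.randn(2, 16, dim)
embedded_timestep = torch.randn(2, dim)

# Old inline path: normalize, derive shift/scale from the table + timestep, modulate
shift, scale = (scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
expected = norm(hidden_states) * (1 + scale) + shift

# New path, mirroring SanaModulatedNorm.forward
def modulated_norm(x, temb, table):
    shift, scale = (table[None] + temb[:, None].to(table.device)).chunk(2, dim=1)
    return norm(x) * (1 + scale) + shift

assert torch.allclose(expected, modulated_norm(hidden_states, embedded_timestep, scale_shift_table))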
11 changes: 6 additions & 5 deletions tests/models/test_modeling_common.py
@@ -29,7 +29,7 @@
 import requests_mock
 import torch
 import torch.nn as nn
-from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size
+from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size, compute_module_sizes
 from huggingface_hub import ModelCard, delete_repo, snapshot_download
 from huggingface_hub.utils import is_jinja_available
 from parameterized import parameterized
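The import swap matters because the two helpers count different tensors: the persistent variant used before only counts tensors that land in the state dict, while compute_module_sizes also includes non-persistent buffers, which still occupy device memory during offloading. A small sketch of the difference, using a made-up toy module:

import torch
import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

class ToyModule(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(4, 4)
        # Non-persistent buffer: excluded from the state dict, but still resident in memory
        self.register_buffer("cache", torch.zeros(1024), persistent=False)

module = ToyModule()
# The root module's total size lives under the empty-string key
print(compute_module_sizes(module)[""])  # counts parameters and the non-persistent buffer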
@@ -1080,7 +1080,7 @@ def test_cpu_offload(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
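For reference, this is how the offload tests turn the computed size into per-device budgets: each split percent caps GPU 0 at a fraction of the total model size, forcing the remainder to offload. A standalone sketch with a hypothetical stand-in model (the real tests use self.model_class and a temporary directory, as in the diff below):

import torch.nn as nn
from accelerate.utils.modeling import compute_module_sizes

model = nn.Sequential(*[nn.Linear(64, 64) for _ in range(4)])  # toy stand-in
model_size = compute_module_sizes(model)[""]                   # total bytes, root key is ""

model_split_percents = [0.5, 0.7, 0.9]                         # hypothetical values
max_gpu_sizes = [int(p * model_size) for p in model_split_percents[1:]]
for max_size in max_gpu_sizes:
    # Cap GPU 0 below the full model size; the overflow must be offloaded to CPU
    max_memory = {0: max_size, "cpu": model_size * 2}
    # In the real test: self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
    print(max_memory)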
@@ -1110,7 +1110,7 @@ def test_disk_offload_without_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir, safe_serialization=False)
 
@@ -1144,7 +1144,7 @@ def test_disk_offload_with_safetensors(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         with tempfile.TemporaryDirectory() as tmp_dir:
             model.cpu().save_pretrained(tmp_dir)
 
@@ -1172,7 +1172,7 @@ def test_model_parallelism(self):
         torch.manual_seed(0)
         base_output = model(**inputs_dict)
 
-        model_size = compute_module_persistent_sizes(model)[""]
+        model_size = compute_module_sizes(model)[""]
         # We test several splits of sizes to make sure it works.
         max_gpu_sizes = [int(p * model_size) for p in self.model_split_percents[1:]]
         with tempfile.TemporaryDirectory() as tmp_dir:
@@ -1183,6 +1183,7 @@ def test_model_parallelism(self):
             new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
             # Making sure part of the model will actually end up offloaded
             self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1})
+            print(f" new_model.hf_device_map:{new_model.hf_device_map}")
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
25 changes: 1 addition & 24 deletions tests/models/transformers/test_models_transformer_sana.py
@@ -33,6 +33,7 @@ class SanaTransformerTests(ModelTesterMixin, unittest.TestCase):
     model_class = SanaTransformer2DModel
     main_input_name = "hidden_states"
     uses_custom_attn_processor = True
+    model_split_percents = [0.7, 0.7, 0.9]
 
     @property
     def dummy_input(self):
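The new model_split_percents attribute feeds the offloading tests in test_modeling_common.py shown above; by my reading of the mixin, the first entry sizes the disk-offload budget and the remaining entries are swept as GPU caps, so [0.7, 0.7, 0.9] keeps at least roughly 30% of the model off GPU 0. An illustrative sketch with a placeholder size:

model_split_percents = [0.7, 0.7, 0.9]
model_size = 1_000_000  # placeholder for compute_module_sizes(model)[""]

# Disk-offload tests appear to use the first entry as the memory budget
disk_budget = int(model_split_percents[0] * model_size)

# CPU-offload / model-parallelism tests sweep the remaining entries as GPU caps
gpu_budgets = [int(p * model_size) for p in model_split_percents[1:]]
print(disk_budget, gpu_budgets)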
@@ -81,27 +82,3 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_gradient_checkpointing_is_applied(self):
         expected_set = {"SanaTransformer2DModel"}
         super().test_gradient_checkpointing_is_applied(expected_set=expected_set)
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_cpu_offload(self):
-        return super().test_cpu_offload()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_with_safetensors(self):
-        return super().test_disk_offload_with_safetensors()
-
-    @pytest.mark.xfail(
-        condition=torch.device(torch_device).type == "cuda",
-        reason="Test currently fails.",
-        strict=True,
-    )
-    def test_disk_offload_without_safetensors(self):
-        return super().test_disk_offload_without_safetensors()