Skip to content

Commit becbcd6

Browse files
committed
style
1 parent 42d3a6a commit becbcd6

File tree

3 files changed

+7
-8
lines changed

3 files changed

+7
-8
lines changed

src/diffusers/models/transformers/sana_transformer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ class SanaModulatedNorm(nn.Module):
8686
def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
8787
super().__init__()
8888
self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
89-
90-
def forward(self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor) -> torch.Tensor:
89+
90+
def forward(
91+
self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
92+
) -> torch.Tensor:
9193
hidden_states = self.norm(hidden_states)
92-
shift, scale = (
93-
scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)
94-
).chunk(2, dim=1)
94+
shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
9595
hidden_states = hidden_states * (1 + scale) + shift
9696
return hidden_states
9797

@@ -235,7 +235,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
235235
"""
236236

237237
_supports_gradient_checkpointing = True
238-
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
238+
_no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
239239

240240
@register_to_config
241241
def __init__(

tests/models/test_modeling_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
import requests_mock
3030
import torch
3131
import torch.nn as nn
32-
from accelerate.utils.modeling import _get_proper_dtype, dtype_byte_size, compute_module_sizes
32+
from accelerate.utils.modeling import _get_proper_dtype, compute_module_sizes, dtype_byte_size
3333
from huggingface_hub import ModelCard, delete_repo, snapshot_download
3434
from huggingface_hub.utils import is_jinja_available
3535
from parameterized import parameterized

tests/models/transformers/test_models_transformer_sana.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
import unittest
1616

17-
import pytest
1817
import torch
1918

2019
from diffusers import SanaTransformer2DModel

0 commit comments

Comments
 (0)