Commit 39a947c

refactor
1 parent 30c3238 commit 39a947c

4 files changed: +140 -95 lines changed


scripts/convert_dcae_to_diffusers.py

Lines changed: 6 additions & 2 deletions

@@ -19,7 +19,11 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
     state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
     state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
     state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
-    # state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)
+
+
+def remap_proj_conv_(key: str, state_dict: Dict[str, Any]):
+    parent_module, _, _ = key.rpartition(".proj.conv.weight")
+    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
 
 
 AE_KEYS_RENAME_DICT = {

@@ -40,7 +44,6 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
     "conv1.conv": "conv1",
     "conv2.conv": "conv2",
     "conv2.norm": "norm",
-    "proj.conv": "proj_out",
     "proj.norm": "norm_out",
     # encoder
     "encoder.project_in.conv": "encoder.conv_in",

@@ -76,6 +79,7 @@ def remap_qkv_(key: str, state_dict: Dict[str, Any]):
 
 AE_SPECIAL_KEYS_REMAP = {
     "qkv.conv.weight": remap_qkv_,
+    "proj.conv.weight": remap_proj_conv_,
 }
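
For context, AE_SPECIAL_KEYS_REMAP maps a substring of an original checkpoint key to a callback that rewrites the state dict in place, so the new "proj.conv.weight" entry routes those 1x1 conv weights through remap_proj_conv_. The loop below is a hypothetical sketch of that pattern, not code from this script; the example key and tensor shape are made up.

# Hypothetical sketch of how a special-keys table like AE_SPECIAL_KEYS_REMAP
# could be applied to a checkpoint; the real iteration lives elsewhere in the script.
from typing import Any, Callable, Dict

import torch


def remap_proj_conv_(key: str, state_dict: Dict[str, Any]) -> None:
    # mirrors the handler added in this commit: 1x1 conv weight -> linear weight
    parent_module, _, _ = key.rpartition(".proj.conv.weight")
    state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()


SPECIAL_KEYS_REMAP: Dict[str, Callable[[str, Dict[str, Any]], None]] = {
    "proj.conv.weight": remap_proj_conv_,
}

state_dict = {"decoder.attn.proj.conv.weight": torch.randn(32, 32, 1, 1)}  # illustrative key
for key in list(state_dict.keys()):
    for pattern, handler in SPECIAL_KEYS_REMAP.items():
        if pattern in key:
            handler(key, state_dict)

print(list(state_dict.keys()))  # ['decoder.attn.to_out.weight'], now shaped (32, 32)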

src/diffusers/models/attention_processor.py

Lines changed: 122 additions & 88 deletions

@@ -752,7 +752,7 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse
 
 
-class MultiscaleAttentionProjection(nn.Module):
+class SanaMultiscaleAttentionProjection(nn.Module):
     def __init__(
         self,
         in_channels: int,

@@ -761,25 +761,24 @@ def __init__(
     ) -> None:
         super().__init__()
 
+        channels = 3 * in_channels
         self.proj_in = nn.Conv2d(
-            3 * in_channels,
-            3 * in_channels,
+            channels,
+            channels,
             kernel_size,
             padding=kernel_size // 2,
             groups=3 * in_channels,
             bias=False,
         )
-        self.proj_out = nn.Conv2d(
-            3 * in_channels, 3 * in_channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False
-        )
+        self.proj_out = nn.Conv2d(channels, channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.proj_in(hidden_states)
         hidden_states = self.proj_out(hidden_states)
         return hidden_states
 
 
-class MultiscaleLinearAttention(nn.Module):
+class SanaMultiscaleLinearAttention(nn.Module):
     r"""Lightweight multi-scale linear attention"""
 
     def __init__(

@@ -792,6 +791,7 @@ def __init__(
         norm_type: str = "batch_norm",
         kernel_sizes: Tuple[int, ...] = (5,),
         eps: float = 1e-15,
+        residual_connection: bool = False,
     ):
         super().__init__()
 

@@ -801,6 +801,7 @@ def __init__(
         self.eps = eps
         self.attention_head_dim = attention_head_dim
         self.norm_type = norm_type
+        self.residual_connection = residual_connection
 
         num_attention_heads = (
             int(in_channels // attention_head_dim * heads_ratio)

@@ -809,102 +810,32 @@ def __init__(
         )
         inner_dim = num_attention_heads * attention_head_dim
 
-        # self.to_qkv = nn.Conv2d(in_channels, 3 * inner_dim, 1, 1, 0, bias=False)
         self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
         self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
         self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
 
         self.to_qkv_multiscale = nn.ModuleList()
         for kernel_size in kernel_sizes:
-            self.to_qkv_multiscale.append(MultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size))
+            self.to_qkv_multiscale.append(
+                SanaMultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size)
+            )
 
-        self.kernel_nonlinearity = nn.ReLU()
-        self.proj_out = nn.Conv2d(inner_dim * (1 + len(kernel_sizes)), out_channels, 1, 1, 0, bias=False)
+        self.nonlinearity = nn.ReLU()
+        self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
         self.norm_out = get_normalization(norm_type, num_features=out_channels)
 
-    def linear_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-        batch_size, _, height, width = qkv.shape
-
-        qkv = qkv.float()
-        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-
-        query, key, value = (
-            qkv[:, :, 0 : self.attention_head_dim],
-            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-            qkv[:, :, 2 * self.attention_head_dim :],
-        )
-
-        # lightweight linear attention
-        query = self.kernel_nonlinearity(query)
-        key = self.kernel_nonlinearity(key)
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
-
-        key_T = key.transpose(-1, -2)
-        scores = torch.matmul(value, key_T)
-        output = torch.matmul(scores, query)
-
-        output = output.float()
-        output = output[:, :, :-1] / (output[:, :, -1:] + self.eps)
-        output = torch.reshape(output, (batch_size, -1, height, width))
-
-        return output
-
-    def quadratic_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-        batch_size, _, height, width = list(qkv.size())
-
-        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-        query, key, value = (
-            qkv[:, :, 0 : self.attention_head_dim],
-            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-            qkv[:, :, 2 * self.attention_head_dim :],
-        )
-
-        query = self.kernel_nonlinearity(query)
-        key = self.kernel_nonlinearity(key)
-
-        scores = torch.matmul(key.transpose(-1, -2), query)
-
-        original_dtype = scores.dtype
-        scores = scores.float()
-        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        scores = scores.to(original_dtype)
-
-        output = torch.matmul(value, scores)
-        output = torch.reshape(output, (batch_size, -1, height, width))
-
-        return output
+        self.processor = SanaMultiscaleLinearAttnProcessor2_0()
+        self.processor_quadratic = SanaMultiscaleQuadraticAttnProcessor2_0()
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        residual = hidden_states
-
-        # qkv = self.to_qkv(hidden_states)
-        hidden_states = hidden_states.movedim(1, 3)
-        query = self.to_q(hidden_states)
-        key = self.to_k(hidden_states)
-        value = self.to_v(hidden_states)
-        qkv = torch.cat([query, key, value], dim=3)
-        qkv = qkv.movedim(3, 1)
-
-        multi_scale_qkv = [qkv]
-        for block in self.to_qkv_multiscale:
-            multi_scale_qkv.append(block(qkv))
+        height, width = hidden_states.shape[-2:]
 
-        qkv = torch.cat(multi_scale_qkv, dim=1)
-
-        height, width = qkv.shape[-2:]
         if height * width > self.attention_head_dim:
-            hidden_states = self.linear_attention(qkv).to(qkv.dtype)
-        else:
-            hidden_states = self.quadratic_attention(qkv)
-
-        hidden_states = self.proj_out(hidden_states)
-
-        if self.norm_type == "rms_norm":
-            hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+            hidden_states = self.processor(self, hidden_states)
         else:
-            hidden_states = self.norm_out(hidden_states)
+            hidden_states = self.processor_quadratic(self, hidden_states)
 
-        return hidden_states + residual
+        return hidden_states
 
 
 class AttnProcessor:

@@ -5160,6 +5091,109 @@ def __call__(
         return hidden_states
 
 
+class SanaMultiscaleLinearAttnProcessor2_0:
+    r"""
+    Processor for implementing multiscale linear attention.
+    """
+
+    def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+        residual = hidden_states
+
+        batch_size, _, height, width = hidden_states.shape
+        original_dtype = hidden_states.dtype
+
+        hidden_states = hidden_states.movedim(1, -1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        hidden_states = torch.cat([query, key, value], dim=3)
+        hidden_states = hidden_states.movedim(-1, 1)
+
+        multiscale_hidden_states = [hidden_states]
+        for block in attn.to_qkv_multiscale:
+            multiscale_hidden_states.append(block(hidden_states))
+
+        hidden_states = torch.cat(multiscale_hidden_states, dim=1)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+        query, key, value = hidden_states.chunk(3, dim=2)
+        query = attn.nonlinearity(query)
+        key = attn.nonlinearity(key)
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
+
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)
+
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + attn.eps)
+        hidden_states = hidden_states.to(dtype=original_dtype)
+
+        hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+        hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if attn.norm_type == "rms_norm":
+            hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        else:
+            hidden_states = attn.norm_out(hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
+class SanaMultiscaleQuadraticAttnProcessor2_0:
+    r"""
+    Processor for implementing multiscale quadratic attention.
+    """
+
+    def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
+        residual = hidden_states
+
+        batch_size, _, height, width = list(hidden_states.size())
+        original_dtype = hidden_states.dtype
+
+        hidden_states = hidden_states.movedim(1, -1)
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+        hidden_states = torch.cat([query, key, value], dim=3)
+        hidden_states = hidden_states.movedim(-1, 1)
+
+        multi_scale_qkv = [hidden_states]
+        for block in attn.to_qkv_multiscale:
+            multi_scale_qkv.append(block(hidden_states))
+
+        hidden_states = torch.cat(multi_scale_qkv, dim=1)
+
+        hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
+
+        query, key, value = hidden_states.chunk(3, dim=2)
+        query = attn.nonlinearity(query)
+        key = attn.nonlinearity(key)
+
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + attn.eps)
+        scores = scores.to(dtype=original_dtype)
+        hidden_states = torch.matmul(value, scores)
+
+        hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
+        hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+
+        if attn.norm_type == "rms_norm":
+            hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        else:
+            hidden_states = attn.norm_out(hidden_states)
+
+        if attn.residual_connection:
+            hidden_states = hidden_states + residual
+
+        return hidden_states
+
+
 class LoRAAttnProcessor:
     def __init__(self):
         pass
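
The net effect of this file's changes: the two attention code paths move out of the module into processor classes, the qkv and output projections become nn.Linear layers operating channel-last, and forward only decides which processor to call by comparing the number of spatial positions with attention_head_dim. A minimal usage sketch of the new dispatch; the channel sizes and heads_ratio=1.0 are arbitrary choices for illustration, not values from the commit:

# Rough sketch of the new dispatch path; sizes are arbitrary assumptions.
import torch

from diffusers.models.attention_processor import SanaMultiscaleLinearAttention

attn = SanaMultiscaleLinearAttention(
    in_channels=64,
    out_channels=64,
    heads_ratio=1.0,
    attention_head_dim=8,
    norm_type="rms_norm",
    residual_connection=True,
)

# 32x32 feature map: height * width (1024) > attention_head_dim (8),
# so forward routes through SanaMultiscaleLinearAttnProcessor2_0.
out = attn(torch.randn(1, 64, 32, 32))  # -> (1, 64, 32, 32)

# 2x2 feature map: height * width (4) <= attention_head_dim (8),
# so forward falls back to SanaMultiscaleQuadraticAttnProcessor2_0.
out_small = attn(torch.randn(1, 64, 2, 2))  # -> (1, 64, 2, 2)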

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 5 additions & 4 deletions

@@ -13,15 +13,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import Tuple, Union
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ..activations import get_activation
-from ..attention_processor import MultiscaleLinearAttention
+from ..attention_processor import SanaMultiscaleLinearAttention
 from ..modeling_utils import ModelMixin
 from ..normalization import RMSNorm, get_normalization
 

@@ -82,7 +82,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
         else:
             hidden_states = self.norm(hidden_states)
-
+
         return hidden_states + residual
 
 

@@ -97,13 +97,14 @@ def __init__(
     ) -> None:
         super().__init__()
 
-        self.attn = MultiscaleLinearAttention(
+        self.attn = SanaMultiscaleLinearAttention(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
            attention_head_dim=dim,
            norm_type=norm_type,
            kernel_sizes=qkv_multiscales,
+            residual_connection=True,
        )
 
         self.conv_out = GLUMBConv(
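
Because the attention module no longer adds the residual unconditionally, the autoencoder opts back into the old behavior with residual_connection=True. A small equivalence sketch; heads_ratio and the shapes are illustrative choices and this is not a test from the repository:

# Sketch: with shared weights, residual_connection=True should match the
# residual-free module plus its input, i.e. the pre-refactor behavior.
import torch

from diffusers.models.attention_processor import SanaMultiscaleLinearAttention

kwargs = dict(in_channels=32, out_channels=32, heads_ratio=1.0, attention_head_dim=8, norm_type="rms_norm")
with_res = SanaMultiscaleLinearAttention(residual_connection=True, **kwargs)
no_res = SanaMultiscaleLinearAttention(residual_connection=False, **kwargs)
no_res.load_state_dict(with_res.state_dict())  # share weights so the comparison is fair

x = torch.randn(1, 32, 16, 16)
with torch.no_grad():
    assert torch.allclose(with_res(x), no_res(x) + x)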

src/diffusers/models/normalization.py

Lines changed: 7 additions & 1 deletion

@@ -574,7 +574,13 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
 
 
-def get_normalization(norm_type: str = "batch_norm", num_features: Optional[int] = None, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True) -> nn.Module:
+def get_normalization(
+    norm_type: str = "batch_norm",
+    num_features: Optional[int] = None,
+    eps: float = 1e-5,
+    elementwise_affine: bool = True,
+    bias: bool = True,
+) -> nn.Module:
     if norm_type == "rms_norm":
         norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
     elif norm_type == "layer_norm":
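
The get_normalization change is formatting only: the one-line signature is split across lines. For reference, a small usage sketch of the helper; the feature size and input shape are arbitrary:

# Usage sketch of get_normalization; only the rms_norm branch shown in this hunk is exercised.
import torch

from diffusers.models.normalization import get_normalization

norm = get_normalization("rms_norm", num_features=64, eps=1e-5)  # -> diffusers RMSNorm
x = torch.randn(2, 16, 64)  # RMSNorm normalizes over the trailing feature dimension
print(norm(x).shape)  # torch.Size([2, 16, 64])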
