Commit 30c3238

move mla to attention processor file; split qkv conv to linears

1 parent: 4a224ce

4 files changed (+180, -166 lines)

scripts/convert_dcae_to_diffusers.py (7 additions, 7 deletions)

@@ -13,13 +13,13 @@ def remove_keys_(key: str, state_dict: Dict[str, Any]):
 
 
 def remap_qkv_(key: str, state_dict: Dict[str, Any]):
-    # qkv = state_dict.pop(key)
-    # q, k, v = torch.chunk(qkv, 3, dim=0)
-    # parent_module, _, _ = key.rpartition(".qkv.conv.weight")
-    # state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
-    # state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
-    # state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
-    state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)
+    qkv = state_dict.pop(key)
+    q, k, v = torch.chunk(qkv, 3, dim=0)
+    parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+    state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+    state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+    state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+    # state_dict[key.replace("qkv.conv", "to_qkv")] = state_dict.pop(key)
 
 
 AE_KEYS_RENAME_DICT = {
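For reference, a minimal sketch of what the new remap_qkv_ does to a single checkpoint entry. The key name and tensor shapes are illustrative, not taken from an actual DC-AE checkpoint:

```python
import torch

# Hypothetical entry: a fused 1x1 qkv conv weight of shape [3 * inner_dim, in_channels, 1, 1]
# is chunked into q/k/v and squeezed to [inner_dim, in_channels] so it can be loaded into
# the new to_q/to_k/to_v nn.Linear projections.
in_channels = inner_dim = 64
key = "encoder.stages.0.attn.qkv.conv.weight"  # hypothetical key
state_dict = {key: torch.randn(3 * inner_dim, in_channels, 1, 1)}

qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)
parent_module, _, _ = key.rpartition(".qkv.conv.weight")

state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()  # -> [64, 64]
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()

print(sorted(state_dict))
# ['encoder.stages.0.attn.to_k.weight', 'encoder.stages.0.attn.to_q.weight', 'encoder.stages.0.attn.to_v.weight']
print(state_dict[f"{parent_module}.to_q.weight"].shape)  # torch.Size([64, 64])
```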

src/diffusers/models/attention_processor.py (155 additions, 0 deletions)

@@ -752,6 +752,161 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse
 
 
+class MultiscaleAttentionProjection(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        num_attention_heads: int,
+        kernel_size: int,
+    ) -> None:
+        super().__init__()
+
+        self.proj_in = nn.Conv2d(
+            3 * in_channels,
+            3 * in_channels,
+            kernel_size,
+            padding=kernel_size // 2,
+            groups=3 * in_channels,
+            bias=False,
+        )
+        self.proj_out = nn.Conv2d(
+            3 * in_channels, 3 * in_channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states
+
+
+class MultiscaleLinearAttention(nn.Module):
+    r"""Lightweight multi-scale linear attention"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        num_attention_heads: Optional[int] = None,
+        heads_ratio: float = 1.0,
+        attention_head_dim: int = 8,
+        norm_type: str = "batch_norm",
+        kernel_sizes: Tuple[int, ...] = (5,),
+        eps: float = 1e-15,
+    ):
+        super().__init__()
+
+        # To prevent circular import
+        from .normalization import get_normalization
+
+        self.eps = eps
+        self.attention_head_dim = attention_head_dim
+        self.norm_type = norm_type
+
+        num_attention_heads = (
+            int(in_channels // attention_head_dim * heads_ratio)
+            if num_attention_heads is None
+            else num_attention_heads
+        )
+        inner_dim = num_attention_heads * attention_head_dim
+
+        # self.to_qkv = nn.Conv2d(in_channels, 3 * inner_dim, 1, 1, 0, bias=False)
+        self.to_q = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_k = nn.Linear(in_channels, inner_dim, bias=False)
+        self.to_v = nn.Linear(in_channels, inner_dim, bias=False)
+
+        self.to_qkv_multiscale = nn.ModuleList()
+        for kernel_size in kernel_sizes:
+            self.to_qkv_multiscale.append(MultiscaleAttentionProjection(inner_dim, num_attention_heads, kernel_size))
+
+        self.kernel_nonlinearity = nn.ReLU()
+        self.proj_out = nn.Conv2d(inner_dim * (1 + len(kernel_sizes)), out_channels, 1, 1, 0, bias=False)
+        self.norm_out = get_normalization(norm_type, num_features=out_channels)
+
+    def linear_attention(self, qkv: torch.Tensor) -> torch.Tensor:
+        batch_size, _, height, width = qkv.shape
+
+        qkv = qkv.float()
+        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
+
+        query, key, value = (
+            qkv[:, :, 0 : self.attention_head_dim],
+            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
+            qkv[:, :, 2 * self.attention_head_dim :],
+        )
+
+        # lightweight linear attention
+        query = self.kernel_nonlinearity(query)
+        key = self.kernel_nonlinearity(key)
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
+
+        key_T = key.transpose(-1, -2)
+        scores = torch.matmul(value, key_T)
+        output = torch.matmul(scores, query)
+
+        output = output.float()
+        output = output[:, :, :-1] / (output[:, :, -1:] + self.eps)
+        output = torch.reshape(output, (batch_size, -1, height, width))
+
+        return output
+
+    def quadratic_attention(self, qkv: torch.Tensor) -> torch.Tensor:
+        batch_size, _, height, width = list(qkv.size())
+
+        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
+        query, key, value = (
+            qkv[:, :, 0 : self.attention_head_dim],
+            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
+            qkv[:, :, 2 * self.attention_head_dim :],
+        )
+
+        query = self.kernel_nonlinearity(query)
+        key = self.kernel_nonlinearity(key)
+
+        scores = torch.matmul(key.transpose(-1, -2), query)
+
+        original_dtype = scores.dtype
+        scores = scores.float()
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        scores = scores.to(original_dtype)
+
+        output = torch.matmul(value, scores)
+        output = torch.reshape(output, (batch_size, -1, height, width))
+
+        return output
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        residual = hidden_states
+
+        # qkv = self.to_qkv(hidden_states)
+        hidden_states = hidden_states.movedim(1, 3)
+        query = self.to_q(hidden_states)
+        key = self.to_k(hidden_states)
+        value = self.to_v(hidden_states)
+        qkv = torch.cat([query, key, value], dim=3)
+        qkv = qkv.movedim(3, 1)
+
+        multi_scale_qkv = [qkv]
+        for block in self.to_qkv_multiscale:
+            multi_scale_qkv.append(block(qkv))
+
+        qkv = torch.cat(multi_scale_qkv, dim=1)
+
+        height, width = qkv.shape[-2:]
+        if height * width > self.attention_head_dim:
+            hidden_states = self.linear_attention(qkv).to(qkv.dtype)
+        else:
+            hidden_states = self.quadratic_attention(qkv)
+
+        hidden_states = self.proj_out(hidden_states)
+
+        if self.norm_type == "rms_norm":
+            hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        else:
+            hidden_states = self.norm_out(hidden_states)
+
+        return hidden_states + residual
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.
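A minimal usage sketch of the new block, assuming MultiscaleLinearAttention is importable from diffusers.models.attention_processor at this commit; the channel sizes and spatial resolution are illustrative:

```python
import torch

from diffusers.models.attention_processor import MultiscaleLinearAttention

attn = MultiscaleLinearAttention(
    in_channels=64,
    out_channels=64,
    attention_head_dim=8,   # with num_attention_heads=None this gives 64 // 8 = 8 heads
    norm_type="batch_norm",
    kernel_sizes=(5,),
)

hidden_states = torch.randn(2, 64, 32, 32)  # channels-first feature map
output = attn(hidden_states)
print(output.shape)  # torch.Size([2, 64, 32, 32]); the residual path keeps the input shape
```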

src/diffusers/models/autoencoders/autoencoder_dc.py (6 additions, 159 deletions)

@@ -21,20 +21,9 @@
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ..activations import get_activation
+from ..attention_processor import MultiscaleLinearAttention
 from ..modeling_utils import ModelMixin
-from ..normalization import RMSNorm
-
-
-def get_norm_layer(name: Optional[str] = "batch_norm", num_features: Optional[int] = None) -> Optional[nn.Module]:
-    if name is None:
-        norm = None
-    elif name == "rms_norm":
-        norm = RMSNorm(num_features, eps=1e-5, elementwise_affine=True, bias=True)
-    elif name == "batch_norm":
-        norm = nn.BatchNorm2d(num_features=num_features)
-    else:
-        raise ValueError(f"norm {name} is not supported")
-    return norm
+from ..normalization import RMSNorm, get_normalization
 
 
 class GLUMBConv(nn.Module):

@@ -81,7 +70,7 @@ def __init__(
         self.nonlinearity = get_activation(act_fn) if act_fn is not None else nn.Identity()
         self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
         self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1, bias=False)
-        self.norm = get_norm_layer(norm_type, out_channels)
+        self.norm = get_normalization(norm_type, out_channels)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         residual = hidden_states

@@ -93,149 +82,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
             hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
         else:
             hidden_states = self.norm(hidden_states)
-        return hidden_states + residual
-
-
-class MLAProjection(nn.Module):
-    def __init__(
-        self,
-        in_channels: int,
-        num_attention_heads: int,
-        kernel_size: int,
-    ) -> None:
-        super().__init__()
-
-        self.proj_in = nn.Conv2d(
-            3 * in_channels,
-            3 * in_channels,
-            kernel_size,
-            padding=kernel_size // 2,
-            groups=3 * in_channels,
-            bias=False,
-        )
-        self.proj_out = nn.Conv2d(
-            3 * in_channels, 3 * in_channels, 1, 1, 0, groups=3 * num_attention_heads, bias=False
-        )
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        hidden_states = self.proj_in(hidden_states)
-        hidden_states = self.proj_out(hidden_states)
-        return hidden_states
-
-
-class LiteMLA(nn.Module):
-    r"""Lightweight multi-scale linear attention"""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        num_attention_heads: Optional[int] = None,
-        heads_ratio: float = 1.0,
-        attention_head_dim: int = 8,
-        norm_type: str = "batch_norm",
-        kernel_sizes: Tuple[int, ...] = (5,),
-        eps: float = 1e-15,
-    ):
-        super().__init__()
-
-        self.eps = eps
-        self.attention_head_dim = attention_head_dim
-        self.norm_type = norm_type
-
-        num_attention_heads = (
-            int(in_channels // attention_head_dim * heads_ratio)
-            if num_attention_heads is None
-            else num_attention_heads
-        )
-        inner_dim = num_attention_heads * attention_head_dim
-
-        self.to_qkv = nn.Conv2d(in_channels, 3 * inner_dim, 1, 1, 0, bias=False)
-
-        self.to_qkv_multiscale = nn.ModuleList()
-        for kernel_size in kernel_sizes:
-            self.to_qkv_multiscale.append(MLAProjection(inner_dim, num_attention_heads, kernel_size))
-
-        self.kernel_nonlinearity = nn.ReLU()
-        self.proj_out = nn.Conv2d(inner_dim * (1 + len(kernel_sizes)), out_channels, 1, 1, 0, bias=False)
-        self.norm_out = get_norm_layer(norm_type, num_features=out_channels)
-
-    def linear_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-        batch_size, _, height, width = qkv.shape
-
-        qkv = qkv.float()
-        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-
-        query, key, value = (
-            qkv[:, :, 0 : self.attention_head_dim],
-            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-            qkv[:, :, 2 * self.attention_head_dim :],
-        )
-
-        # lightweight linear attention
-        query = self.kernel_nonlinearity(query)
-        key = self.kernel_nonlinearity(key)
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
-
-        key_T = key.transpose(-1, -2)
-        scores = torch.matmul(value, key_T)
-        output = torch.matmul(scores, query)
-
-        output = output.float()
-        output = output[:, :, :-1] / (output[:, :, -1:] + self.eps)
-        output = torch.reshape(output, (batch_size, -1, height, width))
-
-        return output
-
-    def quadratic_attention(self, qkv: torch.Tensor) -> torch.Tensor:
-        batch_size, _, height, width = list(qkv.size())
-
-        qkv = torch.reshape(qkv, (batch_size, -1, 3 * self.attention_head_dim, height * width))
-        query, key, value = (
-            qkv[:, :, 0 : self.attention_head_dim],
-            qkv[:, :, self.attention_head_dim : 2 * self.attention_head_dim],
-            qkv[:, :, 2 * self.attention_head_dim :],
-        )
-
-        query = self.kernel_nonlinearity(query)
-        key = self.kernel_nonlinearity(key)
-
-        scores = torch.matmul(key.transpose(-1, -2), query)
-
-        original_dtype = scores.dtype
-        scores = scores.float()
-        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        scores = scores.to(original_dtype)
-
-        output = torch.matmul(value, scores)
-        output = torch.reshape(output, (batch_size, -1, height, width))
-
-        return output
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        residual = hidden_states
-
-        qkv = self.to_qkv(hidden_states)
-
-        multi_scale_qkv = [qkv]
-        for block in self.to_qkv_multiscale:
-            multi_scale_qkv.append(block(qkv))
-
-        qkv = torch.cat(multi_scale_qkv, dim=1)
-
-        height, width = qkv.shape[-2:]
-        if height * width > self.attention_head_dim:
-            hidden_states = self.linear_attention(qkv).to(qkv.dtype)
-        else:
-            hidden_states = self.quadratic_attention(qkv)
-
-        hidden_states = self.proj_out(hidden_states)
-
-        if self.norm_type == "rms_norm":
-            hidden_states = self.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
-        else:
-            hidden_states = self.norm_out(hidden_states)
-
+
         return hidden_states + residual
 
 

@@ -247,10 +94,10 @@ def __init__(
         dim: int = 32,
         qkv_multiscales: Tuple[int, ...] = (5,),
         norm_type: str = "batch_norm",
-    ):
+    ) -> None:
        super().__init__()
 
-        self.attn = LiteMLA(
+        self.attn = MultiscaleLinearAttention(
            in_channels=in_channels,
            out_channels=in_channels,
            heads_ratio=heads_ratio,
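A quick sanity check of why splitting the fused 1x1 qkv conv into separate nn.Linear projections can preserve outputs: a bias-free 1x1 Conv2d on an NCHW tensor computes the same map as an nn.Linear applied after moving channels last, once the conv weight is squeezed to 2-D. This is a standalone sketch with illustrative shapes, not code from the commit:

```python
import torch
import torch.nn as nn

# A 1x1, bias-free Conv2d weight has shape [out, in, 1, 1]; squeezing it to [out, in]
# gives the matching nn.Linear weight, which is the relationship the convert script
# and the new to_q/to_k/to_v projections rely on.
torch.manual_seed(0)
in_channels, inner_dim = 64, 64

conv = nn.Conv2d(in_channels, inner_dim, kernel_size=1, bias=False)
linear = nn.Linear(in_channels, inner_dim, bias=False)
with torch.no_grad():
    linear.weight.copy_(conv.weight.squeeze())  # [64, 64, 1, 1] -> [64, 64]

x = torch.randn(2, in_channels, 32, 32)
out_conv = conv(x)                                  # NCHW path (old fused qkv conv, per chunk)
out_linear = linear(x.movedim(1, 3)).movedim(3, 1)  # NHWC path (new linear projections)

print(torch.allclose(out_conv, out_linear, atol=1e-5))  # True
```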

src/diffusers/models/normalization.py (12 additions, 0 deletions)

@@ -572,3 +572,15 @@ def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+
+
+def get_normalization(norm_type: str = "batch_norm", num_features: Optional[int] = None, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True) -> nn.Module:
+    if norm_type == "rms_norm":
+        norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "layer_norm":
+        norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "batch_norm":
+        norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
+    else:
+        raise ValueError(f"{norm_type=} is not supported.")
+    return norm
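A usage sketch of the new helper, assuming get_normalization is importable from diffusers.models.normalization at this commit. Its defaults (eps=1e-5, elementwise_affine=True, bias=True) match the arguments the removed get_norm_layer hard-coded for RMSNorm, and "batch_norm" still resolves to a plain nn.BatchNorm2d:

```python
import torch

from diffusers.models.normalization import get_normalization

rms = get_normalization("rms_norm", num_features=64)   # normalizes over the last (channel) dim
bn = get_normalization("batch_norm", num_features=64)  # nn.BatchNorm2d, expects NCHW input

x = torch.randn(2, 8, 8, 64)           # channels-last layout for RMSNorm
print(rms(x).shape)                    # torch.Size([2, 8, 8, 64])
print(bn(x.movedim(-1, 1)).shape)      # torch.Size([2, 64, 8, 8])
```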
