
Commit d3d9c84

move LiteMLA to attention.py
1 parent e007057 commit d3d9c84

File tree

3 files changed (+185, -173 lines)

src/diffusers/models/attention.py

Lines changed: 159 additions & 2 deletions
@@ -19,10 +19,10 @@
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, get_activation
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX, RMSNorm2d
 
 
 logger = logging.get_logger(__name__)
@@ -1241,3 +1241,160 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         for module in self.net:
             hidden_states = module(hidden_states)
         return hidden_states
+
+
+class DCAELiteMLA(nn.Module):
+    r"""Lightweight multi-scale linear attention used in DC-AE"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        heads: Optional[int] = None,
+        heads_ratio: float = 1.0,
+        dim=8,
+        use_bias=(False, False),
+        norm=(None, "bn2d"),
+        act_func=(None, None),
+        kernel_func="relu",
+        scales: Tuple[int, ...] = (5,),
+        eps=1.0e-15,
+    ):
+        super().__init__()
+        self.eps = eps
+        heads = int(in_channels // dim * heads_ratio) if heads is None else heads
+
+        total_dim = heads * dim
+
+        self.dim = dim
+
+        qkv = [nn.Conv2d(in_channels=in_channels, out_channels=3 * total_dim, kernel_size=1, bias=use_bias[0])]
+        if norm[0] is None:
+            pass
+        elif norm[0] == "rms2d":
+            qkv.append(RMSNorm2d(num_features=3 * total_dim))
+        elif norm[0] == "bn2d":
+            qkv.append(nn.BatchNorm2d(num_features=3 * total_dim))
+        else:
+            raise ValueError(f"norm {norm[0]} is not supported")
+        if act_func[0] is not None:
+            qkv.append(get_activation(act_func[0]))
+        self.qkv = nn.Sequential(*qkv)
+
+        self.aggreg = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(
+                        3 * total_dim,
+                        3 * total_dim,
+                        scale,
+                        padding=scale // 2,
+                        groups=3 * total_dim,
+                        bias=use_bias[0],
+                    ),
+                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
+                )
+                for scale in scales
+            ]
+        )
+        self.kernel_func = get_activation(kernel_func)
+
+        proj = [nn.Conv2d(in_channels=total_dim * (1 + len(scales)), out_channels=out_channels, kernel_size=1, bias=use_bias[1])]
+        if norm[1] is None:
+            pass
+        elif norm[1] == "rms2d":
+            proj.append(RMSNorm2d(num_features=out_channels))
+        elif norm[1] == "bn2d":
+            proj.append(nn.BatchNorm2d(num_features=out_channels))
+        else:
+            raise ValueError(f"norm {norm[1]} is not supported")
+        if act_func[1] is not None:
+            proj.append(get_activation(act_func[1]))
+        self.proj = nn.Sequential(*proj)
+
+    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        if qkv.dtype == torch.float16:
+            qkv = qkv.float()
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        q, k, v = (
+            qkv[:, :, 0 : self.dim],
+            qkv[:, :, self.dim : 2 * self.dim],
+            qkv[:, :, 2 * self.dim :],
+        )
+
+        # lightweight linear attention
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        # linear matmul
+        trans_k = k.transpose(-1, -2)
+
+        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
+        vk = torch.matmul(v, trans_k)
+        out = torch.matmul(vk, q)
+        if out.dtype == torch.bfloat16:
+            out = out.float()
+        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
+
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        q, k, v = (
+            qkv[:, :, 0 : self.dim],
+            qkv[:, :, self.dim : 2 * self.dim],
+            qkv[:, :, 2 * self.dim :],
+        )
+
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
+        original_dtype = att_map.dtype
+        if original_dtype in [torch.float16, torch.bfloat16]:
+            att_map = att_map.float()
+        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
+        att_map = att_map.to(original_dtype)
+        out = torch.matmul(v, att_map)  # b h d n
+
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # generate multi-scale q, k, v
+        qkv = self.qkv(x)
+        multi_scale_qkv = [qkv]
+        for op in self.aggreg:
+            multi_scale_qkv.append(op(qkv))
+        qkv = torch.cat(multi_scale_qkv, dim=1)
+
+        H, W = list(qkv.size())[-2:]
+        if H * W > self.dim:
+            out = self.relu_linear_att(qkv).to(qkv.dtype)
+        else:
+            out = self.relu_quadratic_att(qkv)
+        out = self.proj(out)
+
+        return x + out
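
The relocated module is a drop-in residual block: it maps a (B, C, H, W) feature map to a tensor of the same shape, and the row of ones padded onto v in relu_linear_att accumulates the per-query normalizer so the softmax-free attention stays a pair of matmuls. A minimal usage sketch, not part of the diff; the channel count, spatial size, and eval() call are assumptions, only the constructor defaults above are taken as given:

import torch
from diffusers.models.attention import DCAELiteMLA

# hypothetical smoke test: 64-channel feature map, default multi-scale setup (dim=8, scales=(5,))
attn = DCAELiteMLA(in_channels=64, out_channels=64).eval()  # eval() so the bn2d projection uses running stats
x = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    y = attn(x)
print(y.shape)  # torch.Size([1, 64, 32, 32]) -- residual output, same shape as the input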

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 3 additions & 171 deletions
@@ -27,35 +27,14 @@
 from ..modeling_utils import ModelMixin
 
 from ..activations import get_activation
+from ..normalization import RMSNorm2d
 from ..downsampling import ConvPixelUnshuffleDownsample2D, PixelUnshuffleChannelAveragingDownsample2D
 from ..upsampling import ConvPixelShuffleUpsample2D, ChannelDuplicatingPixelUnshuffleUpsample2D, Upsample2D
+from ..attention import DCAELiteMLA
 
 from .vae import DecoderOutput
 
 
-class RMSNorm2d(nn.Module):
-    def __init__(self, num_features: int, eps: float = 1e-5, elementwise_affine: bool = True, bias: bool = True) -> None:
-        super().__init__()
-        self.num_features = num_features
-        self.eps = eps
-        self.elementwise_affine = elementwise_affine
-        if self.elementwise_affine:
-            self.weight = torch.nn.parameter.Parameter(torch.empty(self.num_features))
-            if bias:
-                self.bias = torch.nn.parameter.Parameter(torch.empty(self.num_features))
-            else:
-                self.register_parameter('bias', None)
-        else:
-            self.register_parameter('weight', None)
-            self.register_parameter('bias', None)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = (x / torch.sqrt(torch.square(x.float()).mean(dim=1, keepdim=True) + self.eps)).to(x.dtype)
-        if self.elementwise_affine:
-            x = x * self.weight.view(1, -1, 1, 1) + self.bias.view(1, -1, 1, 1)
-        return x
-
-
 class ConvLayer(nn.Module):
     def __init__(
         self,
@@ -205,153 +184,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
-class LiteMLA(nn.Module):
-    r"""Lightweight multi-scale linear attention"""
-
-    def __init__(
-        self,
-        in_channels: int,
-        out_channels: int,
-        heads: Optional[int] = None,
-        heads_ratio: float = 1.0,
-        dim=8,
-        use_bias=(False, False),
-        norm=(None, "bn2d"),
-        act_func=(None, None),
-        kernel_func="relu",
-        scales: tuple[int, ...] = (5,),
-        eps=1.0e-15,
-    ):
-        super().__init__()
-        self.eps = eps
-        heads = int(in_channels // dim * heads_ratio) if heads is None else heads
-
-        total_dim = heads * dim
-
-        self.dim = dim
-        self.qkv = ConvLayer(
-            in_channels,
-            3 * total_dim,
-            1,
-            use_bias=use_bias[0],
-            norm=norm[0],
-            act_func=act_func[0],
-        )
-        self.aggreg = nn.ModuleList(
-            [
-                nn.Sequential(
-                    nn.Conv2d(
-                        3 * total_dim,
-                        3 * total_dim,
-                        scale,
-                        padding=scale // 2,
-                        groups=3 * total_dim,
-                        bias=use_bias[0],
-                    ),
-                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
-                )
-                for scale in scales
-            ]
-        )
-        self.kernel_func = get_activation(kernel_func)
-
-        self.proj = ConvLayer(
-            total_dim * (1 + len(scales)),
-            out_channels,
-            1,
-            use_bias=use_bias[1],
-            norm=norm[1],
-            act_func=act_func[1],
-        )
-
-    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
-        B, _, H, W = list(qkv.size())
-
-        if qkv.dtype == torch.float16:
-            qkv = qkv.float()
-
-        qkv = torch.reshape(
-            qkv,
-            (
-                B,
-                -1,
-                3 * self.dim,
-                H * W,
-            ),
-        )
-        q, k, v = (
-            qkv[:, :, 0 : self.dim],
-            qkv[:, :, self.dim : 2 * self.dim],
-            qkv[:, :, 2 * self.dim :],
-        )
-
-        # lightweight linear attention
-        q = self.kernel_func(q)
-        k = self.kernel_func(k)
-
-        # linear matmul
-        trans_k = k.transpose(-1, -2)
-
-        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
-        vk = torch.matmul(v, trans_k)
-        out = torch.matmul(vk, q)
-        if out.dtype == torch.bfloat16:
-            out = out.float()
-        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
-
-        out = torch.reshape(out, (B, -1, H, W))
-        return out
-
-    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
-        B, _, H, W = list(qkv.size())
-
-        qkv = torch.reshape(
-            qkv,
-            (
-                B,
-                -1,
-                3 * self.dim,
-                H * W,
-            ),
-        )
-        q, k, v = (
-            qkv[:, :, 0 : self.dim],
-            qkv[:, :, self.dim : 2 * self.dim],
-            qkv[:, :, 2 * self.dim :],
-        )
-
-        q = self.kernel_func(q)
-        k = self.kernel_func(k)
-
-        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
-        original_dtype = att_map.dtype
-        if original_dtype in [torch.float16, torch.bfloat16]:
-            att_map = att_map.float()
-        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
-        att_map = att_map.to(original_dtype)
-        out = torch.matmul(v, att_map)  # b h d n
-
-        out = torch.reshape(out, (B, -1, H, W))
-        return out
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        # generate multi-scale q, k, v
-        qkv = self.qkv(x)
-        multi_scale_qkv = [qkv]
-        for op in self.aggreg:
-            multi_scale_qkv.append(op(qkv))
-        qkv = torch.cat(multi_scale_qkv, dim=1)
-
-        H, W = list(qkv.size())[-2:]
-        if H * W > self.dim:
-            out = self.relu_linear_att(qkv).to(qkv.dtype)
-        else:
-            out = self.relu_quadratic_att(qkv)
-        out = self.proj(out)
-
-        return x + out
-
-
 class EfficientViTBlock(nn.Module):
     def __init__(
         self,
@@ -367,7 +199,7 @@ def __init__(
     ):
         super().__init__()
         if context_module == "LiteMLA":
-            self.context_module = LiteMLA(
+            self.context_module = DCAELiteMLA(
                 in_channels=in_channels,
                 out_channels=in_channels,
                 heads_ratio=heads_ratio,
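
Since the commit only relocates the block (LiteMLA becomes DCAELiteMLA) and EfficientViTBlock keeps the "LiteMLA" string key, behaviour should be unchanged. The one runtime branch worth noting is `if H * W > self.dim` in forward: the linear kernel is used for large feature maps and the quadratic kernel for tiny ones, and the two compute the same normalized attention (the padded row of ones in relu_linear_att reproduces the per-query normalizer of relu_quadratic_att). A hedged sanity-check sketch, not part of the diff; the shapes, seed, and tolerance are assumptions:

import torch
from diffusers.models.attention import DCAELiteMLA

torch.manual_seed(0)
attn = DCAELiteMLA(in_channels=32, out_channels=32).eval()

# build the concatenated multi-scale qkv tensor exactly as forward() does before it branches
x = torch.randn(1, 32, 4, 4)
qkv = attn.qkv(x)
qkv = torch.cat([qkv] + [op(qkv) for op in attn.aggreg], dim=1)

linear = attn.relu_linear_att(qkv)        # O(N) path, taken when H * W > self.dim
quadratic = attn.relu_quadratic_att(qkv)  # O(N^2) path, taken for very small feature maps
print(torch.allclose(linear, quadratic, atol=1e-5))  # expected: True (same math, different association order)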
