
Commit c67753e

Merge branch 'refs/heads/DC-AE' into DC-AE-Sana
# Conflicts:
#	src/diffusers/models/autoencoders/dc_ae.py

2 parents: 1f08631 + d3d9c84

File tree

5 files changed: +907 -4 lines changed

src/diffusers/models/attention.py

Lines changed: 159 additions & 2 deletions
@@ -19,10 +19,10 @@
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, get_activation
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX, RMSNorm2d
 
 
 logger = logging.get_logger(__name__)
@@ -1241,3 +1241,160 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         for module in self.net:
             hidden_states = module(hidden_states)
         return hidden_states
+
+
+class DCAELiteMLA(nn.Module):
+    r"""Lightweight multi-scale linear attention used in DC-AE"""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        heads: Optional[int] = None,
+        heads_ratio: float = 1.0,
+        dim=8,
+        use_bias=(False, False),
+        norm=(None, "bn2d"),
+        act_func=(None, None),
+        kernel_func="relu",
+        scales: Tuple[int, ...] = (5,),
+        eps=1.0e-15,
+    ):
+        super().__init__()
+        self.eps = eps
+        heads = int(in_channels // dim * heads_ratio) if heads is None else heads
+
+        total_dim = heads * dim
+
+        self.dim = dim
+
+        qkv = [nn.Conv2d(in_channels=in_channels, out_channels=3 * total_dim, kernel_size=1, bias=use_bias[0])]
+        if norm[0] is None:
+            pass
+        elif norm[0] == "rms2d":
+            qkv.append(RMSNorm2d(num_features=3 * total_dim))
+        elif norm[0] == "bn2d":
+            qkv.append(nn.BatchNorm2d(num_features=3 * total_dim))
+        else:
+            raise ValueError(f"norm {norm[0]} is not supported")
+        if act_func[0] is not None:
+            qkv.append(get_activation(act_func[0]))
+        self.qkv = nn.Sequential(*qkv)
+
+        self.aggreg = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(
+                        3 * total_dim,
+                        3 * total_dim,
+                        scale,
+                        padding=scale // 2,
+                        groups=3 * total_dim,
+                        bias=use_bias[0],
+                    ),
+                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
+                )
+                for scale in scales
+            ]
+        )
+        self.kernel_func = get_activation(kernel_func)
+
+        proj = [nn.Conv2d(in_channels=total_dim * (1 + len(scales)), out_channels=out_channels, kernel_size=1, bias=use_bias[1])]
+        if norm[1] is None:
+            pass
+        elif norm[1] == "rms2d":
+            proj.append(RMSNorm2d(num_features=out_channels))
+        elif norm[1] == "bn2d":
+            proj.append(nn.BatchNorm2d(num_features=out_channels))
+        else:
+            raise ValueError(f"norm {norm[1]} is not supported")
+        if act_func[1] is not None:
+            proj.append(get_activation(act_func[1]))
+        self.proj = nn.Sequential(*proj)
+
+    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        if qkv.dtype == torch.float16:
+            qkv = qkv.float()
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        q, k, v = (
+            qkv[:, :, 0 : self.dim],
+            qkv[:, :, self.dim : 2 * self.dim],
+            qkv[:, :, 2 * self.dim :],
+        )
+
+        # lightweight linear attention
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        # linear matmul
+        trans_k = k.transpose(-1, -2)
+
+        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
+        vk = torch.matmul(v, trans_k)
+        out = torch.matmul(vk, q)
+        if out.dtype == torch.bfloat16:
+            out = out.float()
+        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
+
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        q, k, v = (
+            qkv[:, :, 0 : self.dim],
+            qkv[:, :, self.dim : 2 * self.dim],
+            qkv[:, :, 2 * self.dim :],
+        )
+
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
+        original_dtype = att_map.dtype
+        if original_dtype in [torch.float16, torch.bfloat16]:
+            att_map = att_map.float()
+        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
+        att_map = att_map.to(original_dtype)
+        out = torch.matmul(v, att_map)  # b h d n
+
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # generate multi-scale q, k, v
+        qkv = self.qkv(x)
+        multi_scale_qkv = [qkv]
+        for op in self.aggreg:
+            multi_scale_qkv.append(op(qkv))
+        qkv = torch.cat(multi_scale_qkv, dim=1)
+
+        H, W = list(qkv.size())[-2:]
+        if H * W > self.dim:
+            out = self.relu_linear_att(qkv).to(qkv.dtype)
+        else:
+            out = self.relu_quadratic_att(qkv)
+        out = self.proj(out)
+
+        return x + out
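
As a quick orientation for reviewers, here is a minimal usage sketch (not part of the commit). It assumes this branch of diffusers is installed so that DCAELiteMLA is importable from diffusers.models.attention; the channel and spatial sizes are illustrative only, and the direct calls to relu_linear_att/relu_quadratic_att at the end are just to show that the two paths agree.

# Minimal sketch, assuming DCAELiteMLA is importable from diffusers.models.attention
# after this merge. Sizes below are illustrative, not taken from the commit.
import torch

from diffusers.models.attention import DCAELiteMLA

# With the defaults (dim=8, heads_ratio=1.0), heads = in_channels // dim. out_channels
# must equal in_channels here because forward() returns x + out (a residual connection).
block = DCAELiteMLA(in_channels=128, out_channels=128)
block.eval()  # put the default BatchNorm2d in the projection into eval mode for a deterministic pass

x = torch.randn(1, 128, 32, 32)
with torch.no_grad():
    y = block(x)
print(y.shape)  # torch.Size([1, 128, 32, 32]) -- spatial size and channels are preserved

# Since H * W = 1024 > dim = 8, forward() takes the relu_linear_att path. It computes
#   out_n = sum_m v_m * (relu(k_m) . relu(q_n)) / (sum_m relu(k_m) . relu(q_n) + eps)
# as (V @ K^T) @ Q, so the cost is linear in the token count H * W; the ones-padding of v
# yields the denominator in the same matmul. relu_quadratic_att computes the same quantity
# via the explicit N x N map, so the two paths should match up to floating-point error:
with torch.no_grad():
    qkv = block.qkv(x)
    qkv = torch.cat([qkv] + [op(qkv) for op in block.aggreg], dim=1)
    linear_out = block.relu_linear_att(qkv)
    quadratic_out = block.relu_quadratic_att(qkv)
print((linear_out - quadratic_out).abs().max())  # close to zero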
