@@ -19,10 +19,10 @@
 
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
-from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU
+from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, SwiGLU, get_activation
 from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
-from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, SD35AdaLayerNormZeroX
+from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm, RMSNorm2d, SD35AdaLayerNormZeroX
 
 
 logger = logging.get_logger(__name__)
@@ -1241,3 +1241,160 @@ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         for module in self.net:
             hidden_states = module(hidden_states)
         return hidden_states
+
+
+class DCAELiteMLA(nn.Module):
+    r"""Lightweight multi-scale linear attention (LiteMLA), as used in DC-AE (Deep Compression Autoencoder)."""
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        heads: Optional[int] = None,
+        heads_ratio: float = 1.0,
+        dim: int = 8,
+        use_bias: Tuple[bool, bool] = (False, False),
+        norm: Tuple[Optional[str], Optional[str]] = (None, "bn2d"),
+        act_func: Tuple[Optional[str], Optional[str]] = (None, None),
+        kernel_func: str = "relu",
+        scales: Tuple[int, ...] = (5,),
+        eps: float = 1.0e-15,
+    ):
+        super().__init__()
+        self.eps = eps
+        heads = int(in_channels // dim * heads_ratio) if heads is None else heads
+
+        total_dim = heads * dim
+
+        self.dim = dim
+
+        qkv = [nn.Conv2d(in_channels=in_channels, out_channels=3 * total_dim, kernel_size=1, bias=use_bias[0])]
+        if norm[0] is None:
+            pass
+        elif norm[0] == "rms2d":
+            qkv.append(RMSNorm2d(num_features=3 * total_dim))
+        elif norm[0] == "bn2d":
+            qkv.append(nn.BatchNorm2d(num_features=3 * total_dim))
+        else:
+            raise ValueError(f"norm {norm[0]} is not supported")
+        if act_func[0] is not None:
+            qkv.append(get_activation(act_func[0]))
+        self.qkv = nn.Sequential(*qkv)
+
+        self.aggreg = nn.ModuleList(
+            [
+                nn.Sequential(
+                    nn.Conv2d(
+                        3 * total_dim,
+                        3 * total_dim,
+                        scale,
+                        padding=scale // 2,
+                        groups=3 * total_dim,
+                        bias=use_bias[0],
+                    ),
+                    nn.Conv2d(3 * total_dim, 3 * total_dim, 1, groups=3 * heads, bias=use_bias[0]),
+                )
+                for scale in scales
+            ]
+        )
+        self.kernel_func = get_activation(kernel_func)
+
+        proj = [nn.Conv2d(in_channels=total_dim * (1 + len(scales)), out_channels=out_channels, kernel_size=1, bias=use_bias[1])]
+        if norm[1] is None:
+            pass
+        elif norm[1] == "rms2d":
+            proj.append(RMSNorm2d(num_features=out_channels))
+        elif norm[1] == "bn2d":
+            proj.append(nn.BatchNorm2d(num_features=out_channels))
+        else:
+            raise ValueError(f"norm {norm[1]} is not supported")
+        if act_func[1] is not None:
+            proj.append(get_activation(act_func[1]))
+        self.proj = nn.Sequential(*proj)
+
+    def relu_linear_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        if qkv.dtype == torch.float16:
+            qkv = qkv.float()
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        q, k, v = (
+            qkv[:, :, 0 : self.dim],
+            qkv[:, :, self.dim : 2 * self.dim],
+            qkv[:, :, 2 * self.dim :],
+        )
+
+        # ReLU linear attention: apply the kernel function to queries and keys
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        # pad v with a row of ones so the two matmuls below yield both (V K^T) Q and its per-token normalizer
+        trans_k = k.transpose(-1, -2)
+
+        v = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)
+        vk = torch.matmul(v, trans_k)
+        out = torch.matmul(vk, q)
+        if out.dtype == torch.bfloat16:
+            out = out.float()
+        out = out[:, :, :-1] / (out[:, :, -1:] + self.eps)
+
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def relu_quadratic_att(self, qkv: torch.Tensor) -> torch.Tensor:
+        B, _, H, W = list(qkv.size())
+
+        qkv = torch.reshape(
+            qkv,
+            (
+                B,
+                -1,
+                3 * self.dim,
+                H * W,
+            ),
+        )
+        q, k, v = (
+            qkv[:, :, 0 : self.dim],
+            qkv[:, :, self.dim : 2 * self.dim],
+            qkv[:, :, 2 * self.dim :],
+        )
+
+        q = self.kernel_func(q)
+        k = self.kernel_func(k)
+
+        att_map = torch.matmul(k.transpose(-1, -2), q)  # b h n n
+        original_dtype = att_map.dtype
+        if original_dtype in [torch.float16, torch.bfloat16]:
+            att_map = att_map.float()
+        att_map = att_map / (torch.sum(att_map, dim=2, keepdim=True) + self.eps)  # b h n n
+        att_map = att_map.to(original_dtype)
+        out = torch.matmul(v, att_map)  # b h d n
+
+        out = torch.reshape(out, (B, -1, H, W))
+        return out
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # generate multi-scale q, k, v
+        qkv = self.qkv(x)
+        multi_scale_qkv = [qkv]
+        for op in self.aggreg:
+            multi_scale_qkv.append(op(qkv))
+        qkv = torch.cat(multi_scale_qkv, dim=1)
+
+        H, W = list(qkv.size())[-2:]
+        if H * W > self.dim:  # linear attention is cheaper once the token count exceeds the head dim
+            out = self.relu_linear_att(qkv).to(qkv.dtype)
+        else:
+            out = self.relu_quadratic_att(qkv)
+        out = self.proj(out)
+
+        return x + out  # residual connection; assumes in_channels == out_channels
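
Note for reviewers (not part of the diff): the padding in `relu_linear_att` is the standard trick for ReLU linear attention. Padding `v` with a row of ones makes the same pair of matmuls return both the numerator `(V K^T) Q` and the per-token normalizer `1^T K^T Q`, so the `(N x N)` attention map is never materialized. A self-contained sketch (shapes chosen arbitrarily) that checks this against the explicit form:

```python
import torch
import torch.nn.functional as F

B, heads, dim, N = 2, 4, 8, 32
q = F.relu(torch.randn(B, heads, dim, N))
k = F.relu(torch.randn(B, heads, dim, N))
v = torch.randn(B, heads, dim, N)
eps = 1.0e-15

# padded form, as in relu_linear_att: the extra row of ones carries the normalizer
v_pad = F.pad(v, (0, 0, 0, 1), mode="constant", value=1)  # (B, heads, dim + 1, N)
out = torch.matmul(torch.matmul(v_pad, k.transpose(-1, -2)), q)
out = out[:, :, :-1] / (out[:, :, -1:] + eps)

# explicit form: numerator (V K^T) Q and denominator 1^T K^T Q
numerator = torch.matmul(torch.matmul(v, k.transpose(-1, -2)), q)
denominator = torch.matmul(k.transpose(-1, -2), q).sum(dim=-2, keepdim=True)
reference = numerator / (denominator + eps)

assert torch.allclose(out, reference, atol=1e-5)
```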
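And a minimal smoke test for the new block, assuming it stays importable from `diffusers.models.attention` as in this diff; the channel count, head dim, and resolution below are illustrative only:

```python
import torch

from diffusers.models.attention import DCAELiteMLA  # import path assumed from this diff

# defaults: BatchNorm2d after the output projection, ReLU kernel function
block = DCAELiteMLA(in_channels=512, out_channels=512, dim=32).eval()
x = torch.randn(1, 512, 16, 16)  # (batch, channels, height, width) feature map
with torch.no_grad():
    out = block(x)
print(out.shape)  # torch.Size([1, 512, 16, 16]); output is x + attention, so shapes must match
```

Since `forward` returns `x + out`, the block only composes when `in_channels == out_channels`; it may be worth asserting that in `__init__`.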