|
1 | | -# Copyright (C) 2024 Intel Corporation |
| 1 | +# Copyright (C) 2024-2025 Intel Corporation |
2 | 2 | # SPDX-License-Identifier: Apache-2.0 |
3 | 3 | # |
4 | 4 | """Implementation of common transformer layers.""" |
|
10 | 10 | from typing import Callable |
11 | 11 |
|
12 | 12 | import torch |
| 13 | +import torch.nn.functional as f |
13 | 14 | from otx.algo.common.utils.utils import get_clones |
14 | 15 | from otx.algo.modules.transformer import deformable_attention_core_func |
15 | 16 | from torch import Tensor, nn |
@@ -306,6 +307,151 @@ def forward( |
306 | 307 | return self.output_proj(output) |
307 | 308 |
|
308 | 309 |
|
| 310 | +class MSDeformableAttentionV2(nn.Module): |
| 311 | + """Multi-Scale Deformable Attention Module V2. |
| 312 | +
|
| 313 | + Note: |
| 314 | + Unlike the vanilla MSDeformableAttention, this module uses a distinct |
| 315 | + number of sampling points for the features at each scale. |
| 316 | + Refer to RTDETRv2. |
| 317 | +
|
| 318 | + Args: |
| 319 | + embed_dim (int): The number of expected features in the input. |
| 321 | + num_heads (int): The number of heads in the multi-head attention module. |
| 322 | + num_levels (int): The number of feature levels in MSDeformableAttention. |
| 323 | + num_points_list (list[int]): Number of sampling points for each feature level. Defaults to [3, 6, 3]. |
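| | + |
| | + Example: |
| | + A minimal usage sketch; the shapes and values below are illustrative assumptions, |
| | + not taken from the actual RT-DETR pipeline. |
| | + |
| | + >>> attn = MSDeformableAttentionV2(embed_dim=256, num_heads=8, num_levels=3) |
| | + >>> spatial_shapes = [[32, 32], [16, 16], [8, 8]] |
| | + >>> query = torch.rand(2, 100, 256) |
| | + >>> # normalized (cx, cy, w, h) boxes with a broadcastable level dim |
| | + >>> reference_points = torch.rand(2, 100, 1, 4) |
| | + >>> # one value tensor per level: [bs, n_head, head_dim, H_l * W_l] |
| | + >>> value = [torch.rand(2, 8, 32, h * w) for h, w in spatial_shapes] |
| | + >>> output = attn(query, reference_points, value, spatial_shapes)  # [2, 100, 256] |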
| 323 | + """ |
| 324 | + |
| 325 | + def __init__( |
| 326 | + self, |
| 327 | + embed_dim: int = 256, |
| 328 | + num_heads: int = 8, |
| 329 | + num_levels: int = 4, |
| 330 | + num_points_list: list[int] = [3, 6, 3], # noqa: B006 |
| 331 | + ) -> None: |
| 332 | + super().__init__() |
| 333 | + self.embed_dim = embed_dim |
| 334 | + self.num_heads = num_heads |
| 335 | + self.num_levels = num_levels |
| 336 | + self.num_points_list = num_points_list |
| 337 | + |
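| | + # per-point scale (1 / n for every point of a level with n points), applied when reference points are boxes |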
| 338 | + num_points_scale = [1 / n for n in num_points_list for _ in range(n)] |
| 339 | + self.register_buffer( |
| 340 | + "num_points_scale", |
| 341 | + torch.tensor(num_points_scale, dtype=torch.float32), |
| 342 | + ) |
| 343 | + |
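| | + # a 2-D offset and a scalar attention weight are predicted for every (head, point) pair |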
| 344 | + self.total_points = num_heads * sum(num_points_list) |
| 345 | + self.head_dim = embed_dim // num_heads |
| 346 | + |
| 347 | + self.sampling_offsets = nn.Linear(embed_dim, self.total_points * 2) |
| 348 | + self.attention_weights = nn.Linear(embed_dim, self.total_points) |
| 349 | + |
| 350 | + self._reset_parameters() |
| 351 | + |
| 352 | + def _reset_parameters(self) -> None: |
| 353 | + """Reset parameters of the model.""" |
| 354 | + init.constant_(self.sampling_offsets.weight, 0) |
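| | + # bias init: every head starts with a distinct offset direction, and the offset |
| | + # magnitude grows linearly (1..n) across each level's group of points |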
| 355 | + thetas = torch.arange(self.num_heads, dtype=torch.float32) * (2.0 * math.pi / self.num_heads) |
| 356 | + grid_init = torch.stack([thetas.cos(), thetas.sin()], -1) |
| 357 | + grid_init = grid_init / grid_init.abs().max(-1, keepdim=True).values # noqa: PD011 |
| 358 | + grid_init = grid_init.reshape(self.num_heads, 1, 2).tile([1, sum(self.num_points_list), 1]) |
| 359 | + scaling = torch.concat([torch.arange(1, n + 1) for n in self.num_points_list]).reshape(1, -1, 1) |
| 360 | + grid_init *= scaling |
| 361 | + self.sampling_offsets.bias.data[...] = grid_init.flatten() |
| 362 | + |
| 363 | + # attention_weights |
| 364 | + init.constant_(self.attention_weights.weight, 0) |
| 365 | + init.constant_(self.attention_weights.bias, 0) |
| 366 | + |
| 367 | + def forward( |
| 368 | + self, |
| 369 | + query: Tensor, |
| 370 | + reference_points: Tensor, |
| 371 | + value: list[Tensor], |
| 372 | + value_spatial_shapes: list[list[int]], |
| 373 | + ) -> Tensor: |
| 374 | + """Forward function of MSDeformableAttention. |
| 375 | +
|
| 376 | + Args: |
| 377 | + query (Tensor): [bs, query_length, C] |
| 378 | + reference_points (Tensor): [bs, query_length, n_levels, 2], range in [0, 1], top-left (0,0), |
| 379 | + bottom-right (1, 1), including padding area |
| 380 | + value (list[Tensor]): per-level value tensors, each [bs, n_head, head_dim, H_l * W_l] |
| 381 | + value_spatial_shapes (list[list[int]]): [n_levels, 2], [(H_0, W_0), (H_1, W_1), ..., (H_{L-1}, W_{L-1})] |
| 382 | +
|
| 383 | + Returns: |
| 384 | + output (Tensor): [bs, Length_{query}, C] |
| 385 | + """ |
| 386 | + bs, len_q = query.shape[:2] |
| 387 | + _, n_head, c, _ = value[0].shape |
| 388 | + num_points_list = self.num_points_list |
| 389 | + |
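| | + # predict a 2-D sampling offset and an attention weight per head for every sampling point |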
| 390 | + sampling_offsets = self.sampling_offsets(query).reshape( |
| 391 | + bs, |
| 392 | + len_q, |
| 393 | + self.num_heads, |
| 394 | + sum(self.num_points_list), |
| 395 | + 2, |
| 396 | + ) |
| 397 | + |
| 398 | + attention_weights = self.attention_weights(query).reshape( |
| 399 | + bs, |
| 400 | + len_q, |
| 401 | + self.num_heads, |
| 402 | + sum(self.num_points_list), |
| 403 | + ) |
| 404 | + attention_weights = f.softmax(attention_weights, dim=-1) |
| 405 | + |
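| | + # reference points with 2 coords are normalized centers; 4 coords are treated as normalized (cx, cy, w, h) boxes |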
| 406 | + if reference_points.shape[-1] == 2: |
| 407 | + offset_normalizer = torch.tensor(value_spatial_shapes) |
| 408 | + offset_normalizer = offset_normalizer.flip([1]).reshape(1, 1, 1, self.num_levels, 1, 2) |
| 409 | + sampling_locations = ( |
| 410 | + reference_points.reshape( |
| 411 | + bs, |
| 412 | + len_q, |
| 413 | + 1, |
| 414 | + self.num_levels, |
| 415 | + 1, |
| 416 | + 2, |
| 417 | + ) |
| 418 | + + sampling_offsets / offset_normalizer |
| 419 | + ) |
| 420 | + elif reference_points.shape[-1] == 4: |
| 421 | + num_points_scale = self.num_points_scale.to(query).unsqueeze(-1) |
| 422 | + offset = sampling_offsets * num_points_scale * reference_points[:, :, None, :, 2:] * 0.5 |
| 423 | + sampling_locations = reference_points[:, :, None, :, :2] + offset |
| 424 | + else: |
| 425 | + msg = (f"Last dim of reference_points must be 2 or 4, but get {reference_points.shape[-1]} instead.",) |
| 426 | + raise ValueError(msg) |
| 427 | + |
| 428 | + # map sampling locations from the [0, 1] range to the [-1, 1] range expected by grid_sample |
| 429 | + sampling_grids = 2 * sampling_locations - 1 |
| 430 | + |
| 431 | + sampling_grids = sampling_grids.permute(0, 2, 1, 3, 4).flatten(0, 1) |
| 432 | + sampling_locations_list = sampling_grids.split(num_points_list, dim=-2) |
| 433 | + |
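| | + # sample each level's value map at its share of the sampling grid |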
| 434 | + sampling_value_list = [] |
| 435 | + for level, (h, w) in enumerate(value_spatial_shapes): |
| 436 | + value_l = value[level].reshape(bs * n_head, c, h, w) |
| 437 | + sampling_grid_l = sampling_locations_list[level] |
| 438 | + sampling_value_l = f.grid_sample( |
| 439 | + value_l, |
| 440 | + sampling_grid_l, |
| 441 | + mode="bilinear", |
| 442 | + padding_mode="zeros", |
| 443 | + align_corners=False, |
| 444 | + ) |
| 445 | + |
| 446 | + sampling_value_list.append(sampling_value_l) |
| 447 | + |
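| | + # weight the sampled values by the attention weights and merge the heads back into embed_dim |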
| 448 | + attn_weights = attention_weights.permute(0, 2, 1, 3).reshape(bs * n_head, 1, len_q, sum(num_points_list)) |
| 449 | + weighted_sample_locs = torch.concat(sampling_value_list, dim=-1) * attn_weights |
| 450 | + output = weighted_sample_locs.sum(-1).reshape(bs, n_head * c, len_q) |
| 451 | + |
| 452 | + return output.permute(0, 2, 1) |
| 453 | + |
| 454 | + |
309 | 455 | class VisualEncoderLayer(nn.Module): |
310 | 456 | """VisualEncoderLayer module consisting of MSDeformableAttention and feed-forward network. |
311 | 457 |
|
|