
Commit 5d6f8d3

Decouple DinoV2 for semantic segmentation (#4136)
* dinov2 decoupled. Perf tests
* added dino
* remove dinov2 backbone
* fix linter
* remove unit test
* fix integration tests
* revert perf test back
1 parent 5707bc5 commit 5d6f8d3

File tree

10 files changed: +203 additions, −234 deletions


src/otx/algo/classification/backbones/vision_transformer.py

Lines changed: 158 additions & 9 deletions
@@ -5,6 +5,7 @@
 """Copy from mmpretrain/models/backbones/vision_transformer.py."""
 from __future__ import annotations
 
+import math
 from functools import partial
 from typing import TYPE_CHECKING, Any, Callable, Literal
 
@@ -46,6 +47,7 @@
     "vit-huge",
     "dinov2-s",
     "dinov2-small",
+    "dinov2-small-seg",
    "dinov2-b",
     "dinov2-base",
     "dinov2-l",
@@ -87,6 +89,7 @@ class VisionTransformer(BaseModule):
         norm_layer: Normalization layer.
         act_layer: MLP activation layer.
         block_fn: Transformer block layer.
+        interpolate_offset: work-around offset to apply when interpolating positional embeddings
         lora: Enable LoRA training.
     """
 
@@ -147,6 +150,17 @@ class VisionTransformer(BaseModule):
                 "num_heads": 6,
                 "reg_tokens": 4,
                 "no_embed_class": True,
+            },
+        ),
+        **dict.fromkeys(
+            ["dinov2-small-seg"],  # segmentation
+            {
+                "patch_size": 14,
+                "embed_dim": 384,
+                "depth": 12,
+                "num_heads": 6,
+                "reg_tokens": 0,
+                "no_embed_class": False,
                 "init_values": 1e-5,
             },
         ),
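
For orientation, here is a minimal usage sketch (not part of the commit) of how the new preset would be selected. The import path follows the file path above and the `arch`/`img_size` parameters come from the constructor shown below; the 518-px input size is an assumption chosen to divide evenly by the 14-px patch.

```python
# Hypothetical usage of the new arch key; only `arch` and `img_size` are taken from the diff.
from otx.algo.classification.backbones.vision_transformer import VisionTransformer

backbone = VisionTransformer(
    arch="dinov2-small-seg",  # resolves to patch_size=14, embed_dim=384, depth=12, num_heads=6
    img_size=518,             # assumed input size; 518 / 14 = 37 patches per side
)
```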
@@ -193,9 +207,9 @@ class VisionTransformer(BaseModule):
 
     def __init__(  # noqa: PLR0913
         self,
-        arch: VIT_ARCH_TYPE = "vit-base",
+        arch: VIT_ARCH_TYPE | str = "vit-base",
         img_size: int | tuple[int, int] = 224,
-        patch_size: int | tuple[int, int] | None = None,
+        patch_size: int | None = None,
         in_chans: int = 3,
         num_classes: int = 1000,
         embed_dim: int | None = None,
@@ -221,6 +235,7 @@ def __init__(  # noqa: PLR0913
         mlp_layer: nn.Module | None = None,
         act_layer: LayerType | None = None,
         norm_layer: LayerType | None = None,
+        interpolate_offset: float = 0.1,
         lora: bool = False,
     ) -> None:
         super().__init__()
@@ -231,7 +246,7 @@ def __init__(  # noqa: PLR0913
         arch_settings: dict[str, Any] = self.arch_zoo[arch]
 
         self.img_size: int | tuple[int, int] = img_size
-        self.patch_size: int | tuple[int, int] = patch_size or arch_settings.get("patch_size", 16)
+        self.patch_size: int = patch_size or arch_settings.get("patch_size", 16)
         self.embed_dim = embed_dim or arch_settings.get("embed_dim", 768)
         depth = depth or arch_settings.get("depth", 12)
         num_heads = num_heads or arch_settings.get("num_heads", 12)
@@ -251,6 +266,7 @@ def __init__(  # noqa: PLR0913
         self.no_embed_class = no_embed_class  # don't embed prefix positions (includes reg)
         self.dynamic_img_size = dynamic_img_size
         self.grad_checkpointing = False
+        self.interpolate_offset = interpolate_offset
 
         embed_args = {}
         if dynamic_img_size:
@@ -353,15 +369,17 @@ def resize_positional_embeddings(pos_embed: torch.Tensor, new_shape: tuple[int,
             # convert dinov2 pretrained weights
             state_dict = torch.load(checkpoint_path)
             state_dict.pop("mask_token", None)
-            state_dict["reg_token"] = state_dict.pop("register_tokens")
+            if "reg_token" in state_dict:
+                state_dict["reg_token"] = state_dict.pop("register_tokens")
             state_dict["cls_token"] = state_dict.pop("cls_token") + state_dict["pos_embed"][:, 0]
 
             img_size = (self.img_size, self.img_size) if isinstance(self.img_size, int) else self.img_size
-            patch_size = (self.patch_size, self.patch_size) if isinstance(self.patch_size, int) else self.patch_size
-            state_dict["pos_embed"] = resize_positional_embeddings(
-                state_dict.pop("pos_embed")[:, 1:],
-                (img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
-            )
+            patch_size = (self.patch_size, self.patch_size)
+            if state_dict["pos_embed"].shape != self.pos_embed.shape:
+                state_dict["pos_embed"] = resize_positional_embeddings(
+                    state_dict.pop("pos_embed")[:, 1:],
+                    (img_size[0] // patch_size[0], img_size[1] // patch_size[1]),
+                )
             self.load_state_dict(state_dict, strict=False)
         else:
             msg = f"Unsupported `checkpoint_extension` {checkpoint_ext}, please choose from 'npz' or 'pth'."
@@ -401,6 +419,137 @@ def _pos_embed(self, x: torch.Tensor) -> torch.Tensor:
 
         return self.pos_drop(x)
 
+    def interpolate_pos_encoding(self, x: torch.Tensor, w: int, h: int) -> torch.Tensor:
+        """Interpolates the positional encoding to match the input dimensions.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            w (int): Width of the input image.
+            h (int): Height of the input image.
+
+        Returns:
+            torch.Tensor: Tensor with interpolated positional encoding.
+        """
+        previous_dtype = x.dtype
+        npatch = x.shape[1]
+        n = self.pos_embed.shape[1]
+        if npatch == n and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        m = int(math.sqrt(n))  # Recover the number of patches in each dimension
+        if m * m != n:
+            msg = f"Expected m * m to equal n, but got m={m}, n={n}"
+            raise ValueError(msg)
+        kwargs = {}
+        if self.interpolate_offset:
+            # fix float error by introducing small offset
+            sx = float(w0 + self.interpolate_offset) / m
+            sy = float(h0 + self.interpolate_offset) / m
+            kwargs["scale_factor"] = (sx, sy)
+        else:
+            # Simply specify an output size instead of a scale factor
+            kwargs["size"] = (w0, h0)
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),
+            mode="bicubic",
+            **kwargs,
+        )
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
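
The offset branch reproduces DINOv2's float-rounding work-around: rather than asking for an exact output size, it nudges the scale factor so the bicubic resampling lands on the intended patch grid. A standalone sketch of the same computation on dummy tensors; the 37×37 source grid and the 630×630 target input are assumptions matching the ViT-S/14 preset:

```python
import torch
from torch import nn

m, dim = 37, 384                              # assumed source grid (518 / 14 = 37) and embed dim
patch_pos_embed = torch.randn(1, m * m, dim)  # dummy patch positional embeddings, cls token excluded

w0 = h0 = 630 // 14                           # target grid for an assumed 630x630 input -> 45x45
offset = 0.1                                  # same default as the new `interpolate_offset` argument
scale = ((w0 + offset) / m, (h0 + offset) / m)

resampled = nn.functional.interpolate(
    patch_pos_embed.reshape(1, m, m, dim).permute(0, 3, 1, 2),  # (1, 384, 37, 37)
    scale_factor=scale,
    mode="bicubic",
)
print(resampled.shape)  # torch.Size([1, 384, 45, 45])
```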
+
+    def prepare_tokens_with_masks(self, x: torch.Tensor, masks: torch.Tensor | None = None) -> torch.Tensor:
+        """Prepare tokens with optional masks.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            masks (torch.Tensor | None): Optional masks tensor.
+
+        Returns:
+            torch.Tensor: Tensor with prepared tokens.
+        """
+        _, _, w, h = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        if self.reg_token is not None:
+            x = torch.cat(
+                (
+                    x[:, :1],
+                    self.reg_token.expand(x.shape[0], -1, -1),
+                    x[:, 1:],
+                ),
+                dim=1,
+            )
+
+        return x
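
After preparation the sequence is ordered as [cls] + register tokens + patch tokens, which is what the slicing in `get_intermediate_layers` below relies on. Illustrative token-count arithmetic for the `dinov2-small-seg` preset (no register tokens) at an assumed 518×518 input:

```python
# Expected sequence length after `prepare_tokens_with_masks`; input size is an assumption.
img, patch, reg_tokens = 518, 14, 0
patch_tokens = (img // patch) ** 2             # 37 * 37 = 1369
total_tokens = 1 + reg_tokens + patch_tokens   # cls first, then registers, then patches -> 1370
print(total_tokens)
```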
+
+    def _get_intermediate_layers_not_chunked(self, x: torch.Tensor, n: int = 1) -> list[torch.Tensor]:
+        """Get intermediate layers without chunking.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            n (int): Number of last blocks to take. If it's a list, take the specified blocks.
+
+        Returns:
+            list[torch.Tensor]: List of intermediate layer outputs.
+        """
+        x = self.prepare_tokens_with_masks(x)
+        # If n is an int, take the n last blocks. If it's a list, take them
+        output, total_block_len = [], len(self.blocks)
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in blocks_to_take:
+                output.append(x)
+        if len(output) != len(blocks_to_take):
+            msg = f"only {len(output)} / {len(blocks_to_take)} blocks found"
+            raise RuntimeError(msg)
+        return output
+
+    def get_intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: int = 1,  # Layers or n last layers to take
+        reshape: bool = False,
+        return_class_token: bool = False,
+        norm: bool = True,
+    ) -> tuple:
+        """Get intermediate layers of the VisionTransformer.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+            n (int): Number of last blocks to take. If it's a list, take the specified blocks.
+            reshape (bool): Whether to reshape the output feature maps.
+            return_class_token (bool): Whether to return the class token.
+            norm (bool): Whether to apply normalization to the outputs.
+
+        Returns:
+            tuple: A tuple containing the intermediate layer outputs.
+        """
+        outputs = self._get_intermediate_layers_not_chunked(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0] for out in outputs]
+        outputs = [out[:, 1 + self.num_reg_tokens :] for out in outputs]
+        if reshape:
+            b, _, w, h = x.shape
+            outputs = [
+                out.reshape(b, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
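
`get_intermediate_layers` is the entry point a segmentation head would call to obtain spatial feature maps instead of flat token sequences. A hedged usage sketch, again assuming the `dinov2-small-seg` preset and a 518×518 input, with weights left randomly initialised:

```python
import torch
from otx.algo.classification.backbones.vision_transformer import VisionTransformer

backbone = VisionTransformer(arch="dinov2-small-seg", img_size=518)  # input size is an assumption
images = torch.randn(2, 3, 518, 518)

feats = backbone.get_intermediate_layers(
    images,
    n=4,           # the 4 last blocks; a list of block indices also works
    reshape=True,  # patch tokens come back as (B, C, H/14, W/14) maps
    norm=True,
)
print([f.shape for f in feats])  # expected: 4 x torch.Size([2, 384, 37, 37])
```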
+
     def forward(
         self,
         x: torch.Tensor,

src/otx/algo/segmentation/backbones/__init__.py

Lines changed: 1 addition & 2 deletions
@@ -3,8 +3,7 @@
 #
 """Backbone modules for OTX segmentation model."""
 
-from .dinov2 import DinoVisionTransformer
 from .litehrnet import LiteHRNetBackbone
 from .mscan import MSCAN
 
-__all__ = ["LiteHRNetBackbone", "DinoVisionTransformer", "MSCAN"]
+__all__ = ["LiteHRNetBackbone", "MSCAN"]

src/otx/algo/segmentation/backbones/dinov2.py

Lines changed: 0 additions & 98 deletions
This file was deleted.

0 commit comments
