
Commit 8ab11b9

add deimv2
1 parent 2a96ae2 commit 8ab11b9

File tree

10 files changed (+3043, -1 lines)


library/src/otx/backend/native/models/common/backbones/dinov3.py

Lines changed: 542 additions & 0 deletions (large diff not rendered)

Lines changed: 222 additions & 0 deletions
@@ -0,0 +1,222 @@
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

"""
Modified from DEIMv2: Real-Time Object Detection Meets DINOv3
Copyright (c) 2025 The DEIMv2 Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from DINOv3 (https://github.com/facebookresearch/dinov3)

Copyright (c) Meta Platforms, Inc. and affiliates.

This software may be used and distributed in accordance with
the terms of the DINOv3 License Agreement.
"""

from __future__ import annotations

import os
from typing import ClassVar

import torch
import torch.nn as nn
import torch.nn.functional as F

from otx.backend.native.models.common.backbones.dinov3 import DinoVisionTransformer
from otx.backend.native.models.detection.backbones.vit_tiny import VisionTransformer

class SpatialPriorModulev2(nn.Module):
    """Lightweight convolutional stem producing 1/8, 1/16 and 1/32 spatial-prior features."""

    def __init__(self, inplanes=16):
        super().__init__()

        # 1/4
        self.stem = nn.Sequential(
            nn.Conv2d(3, inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(inplanes),
            nn.GELU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        )
        # 1/8
        self.conv2 = nn.Sequential(
            nn.Conv2d(inplanes, 2 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(2 * inplanes),
        )
        # 1/16
        self.conv3 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(2 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(4 * inplanes),
        )
        # 1/32
        self.conv4 = nn.Sequential(
            nn.GELU(),
            nn.Conv2d(4 * inplanes, 4 * inplanes, kernel_size=3, stride=2, padding=1, bias=False),
            nn.SyncBatchNorm(4 * inplanes),
        )

    def forward(self, x):
        c1 = self.stem(x)  # 1/4
        c2 = self.conv2(c1)  # 1/8
        c3 = self.conv3(c2)  # 1/16
        c4 = self.conv4(c3)  # 1/32

        return c2, c3, c4


class DINOv3STAsModule(nn.Module):
    """Backbone that fuses multi-scale DINOv3/ViT features with lightweight spatial priors."""

    def __init__(
        self,
        name,
        weights_path=None,
        interaction_indexes=None,
        finetune=True,
        embed_dim=192,
        num_heads=3,
        patch_size=16,
        use_sta=True,
        conv_inplane=16,
        hidden_dim=None,
    ):
        super().__init__()
        interaction_indexes = interaction_indexes if interaction_indexes is not None else []
        if "dinov3" in name:
            self.dinov3 = DinoVisionTransformer(name=name)
            if weights_path is not None and os.path.exists(weights_path):
                print(f"Loading ckpt from {weights_path}...")
                self.dinov3.load_state_dict(torch.load(weights_path))
            else:
                print("Training DINOv3 from scratch...")
        else:
            self.dinov3 = VisionTransformer(embed_dim=embed_dim, num_heads=num_heads, return_layers=interaction_indexes)
            if weights_path is not None and os.path.exists(weights_path):
                print(f"Loading ckpt from {weights_path}...")
                self.dinov3._model.load_state_dict(torch.load(weights_path))
            else:
                print("Training ViT-Tiny from scratch...")

        embed_dim = self.dinov3.embed_dim
        self.interaction_indexes = interaction_indexes
        self.patch_size = patch_size

        if not finetune:
            self.dinov3.eval()
            self.dinov3.requires_grad_(False)

        # init the feature pyramid
        self.use_sta = use_sta
        if use_sta:
            print(f"Using Lite Spatial Prior Module with inplanes={conv_inplane}")
            self.sta = SpatialPriorModulev2(inplanes=conv_inplane)
        else:
            conv_inplane = 0

        # linear projection
        hidden_dim = hidden_dim if hidden_dim is not None else embed_dim
        self.convs = nn.ModuleList(
            [
                nn.Conv2d(embed_dim + conv_inplane * 2, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.Conv2d(embed_dim + conv_inplane * 4, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
                nn.Conv2d(embed_dim + conv_inplane * 4, hidden_dim, kernel_size=1, stride=1, padding=0, bias=False),
            ]
        )
        # norm
        self.norms = nn.ModuleList(
            [
                nn.SyncBatchNorm(hidden_dim),
                nn.SyncBatchNorm(hidden_dim),
                nn.SyncBatchNorm(hidden_dim),
            ]
        )

    def forward(self, x):
        # Code for matching with oss
        H_c, W_c = x.shape[2] // 16, x.shape[3] // 16
        H_toks, W_toks = x.shape[2] // self.patch_size, x.shape[3] // self.patch_size
        bs, C, h, w = x.shape

        if len(self.interaction_indexes) > 0 and not isinstance(self.dinov3, VisionTransformer):
            all_layers = self.dinov3.get_intermediate_layers(
                x, n=self.interaction_indexes, return_class_token=True
            )
        else:
            all_layers = self.dinov3(x)

        if len(all_layers) == 1:  # repeat the same layer for all three scales
            all_layers = [all_layers[0], all_layers[0], all_layers[0]]

        sem_feats = []
        num_scales = len(all_layers) - 2
        for i, sem_feat in enumerate(all_layers):
            feat, _ = sem_feat
            sem_feat = feat.transpose(1, 2).view(bs, -1, H_c, W_c).contiguous()  # [B, D, H, W]
            resize_H, resize_W = int(H_c * 2 ** (num_scales - i)), int(W_c * 2 ** (num_scales - i))
            sem_feat = F.interpolate(sem_feat, size=[resize_H, resize_W], mode="bilinear", align_corners=False)
            sem_feats.append(sem_feat)

        # fusion with spatial priors
        fused_feats = []
        if self.use_sta:
            detail_feats = self.sta(x)
            for sem_feat, detail_feat in zip(sem_feats, detail_feats):
                fused_feats.append(torch.cat([sem_feat, detail_feat], dim=1))
        else:
            fused_feats = sem_feats

        c2 = self.norms[0](self.convs[0](fused_feats[0]))
        c3 = self.norms[1](self.convs[1](fused_feats[1]))
        c4 = self.norms[2](self.convs[2](fused_feats[2]))

        return c2, c3, c4


class DINOv3STAs(nn.Module):
    """DINOv3STAs backbone."""

    backbone_cfg: ClassVar = {
        "deimv2_x": {
            "name": "dinov3_vits16plus",
            "weights_path": None,
            "interaction_indexes": [5, 8, 11],
            "finetune": True,
            "conv_inplane": 64,
            "hidden_dim": 256,
        },
        "deimv2_l": {
            "name": "dinov3_vits16",
            "weights_path": None,
            "interaction_indexes": [5, 8, 11],
            "finetune": True,
            "conv_inplane": 32,
            "hidden_dim": 224,
        },
        "deimv2_m": {
            "name": "vit_tinyplus",
            "embed_dim": 256,
            "weights_path": None,
            "interaction_indexes": [3, 7, 11],
            "num_heads": 4,
        },
        "deimv2_s": {
            "name": "vit_tiny",
            "embed_dim": 192,
            "weights_path": None,
            "interaction_indexes": [3, 7, 11],
            "num_heads": 3,
        },
    }

    def __new__(cls, model_name: str) -> DINOv3STAsModule:
        """Create DINOv3STAs backbone.

        Args:
            model_name (str): Model name.

        Returns:
            DINOv3STAsModule: DINOv3STAs backbone.
        """
        return DINOv3STAsModule(**cls.backbone_cfg[model_name])
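
For context, a minimal usage sketch of the new backbone (not part of this commit). The import path is hypothetical, since this file's location in the tree is not shown above, and the dummy forward pass runs in eval mode on CUDA because the SyncBatchNorm layers expect GPU tensors and, during training, a distributed process group:

import torch

# Hypothetical import path for the module added in this commit.
from otx.backend.native.models.detection.backbones.dinov3_sta import DINOv3STAs

# __new__ returns a DINOv3STAsModule configured from backbone_cfg["deimv2_s"];
# a CUDA device is assumed here.
backbone = DINOv3STAs("deimv2_s").cuda().eval()

x = torch.randn(1, 3, 640, 640, device="cuda")  # dummy image batch
with torch.no_grad():
    c2, c3, c4 = backbone(x)  # feature maps at strides 8, 16 and 32

# With "deimv2_s" no hidden_dim is set, so it falls back to embed_dim (192 channels per map).
print(c2.shape, c3.shape, c4.shape)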
