
Commit 2a96ae2

continue exps
1 parent fadcdde commit 2a96ae2


17 files changed: +1981 −40 lines


library/src/otx/backend/native/callbacks/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -4,5 +4,6 @@
 """Module for OTX custom callbacks."""
 
 from .batchsize_finder import BatchSizeFinder
+from .ema import EMAWeightAveraging
 
-__all__ = ["BatchSizeFinder"]
+__all__ = ["BatchSizeFinder", "EMAWeightAveraging"]

library/src/otx/backend/native/callbacks/ema.py

Lines changed: 416 additions & 0 deletions
Large diffs are not rendered by default.
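The new `ema.py` (416 added lines) is collapsed in this view, so the actual `EMAWeightAveraging` implementation is not shown. Purely as a hedged illustration of the general technique such a callback wraps, namely keeping a shadow copy of the weights updated as an exponential moving average, the core update step usually looks like the sketch below. This is plain PyTorch, not the OTX code, and the decay value is illustrative.

```python
import torch


@torch.no_grad()
def ema_update(ema_params, online_params, decay: float = 0.999) -> None:
    """In-place EMA step: shadow <- decay * shadow + (1 - decay) * online."""
    for ema_p, p in zip(ema_params, online_params):
        ema_p.mul_(decay).add_(p, alpha=1.0 - decay)


# Typical use: call once per optimizer step, then validate/export with the shadow weights.
# ema_update(ema_model.parameters(), model.parameters(), decay=0.999)
```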

library/src/otx/backend/native/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
     TVModel,
     VisionTransformer,
 )
-from .detection import ATSS, RTDETR, SSD, YOLOX, DEIMDFine, DFine, RTMDet
+from .detection import ATSS, RTDETR, SSD, YOLOX, DEIMDFine, DEIMV2, DFine, RTMDet
 from .instance_segmentation import MaskRCNN, MaskRCNNTV, RTMDetInst
 from .keypoint_detection import RTMPose
 from .segmentation import DinoV2Seg, LiteHRNet, SegNext
@@ -22,6 +22,7 @@
     "YOLOX",
     "DEIMDFine",
     "DFine",
+    "DEIMV2",
     "DinoV2Seg",
     "EfficientNet",
     "LiteHRNet",

library/src/otx/backend/native/models/classification/utils/swiglu_ffn.py

Lines changed: 29 additions & 0 deletions
@@ -14,6 +14,7 @@
 
 from otx.backend.native.models.modules.drop import build_dropout
 from otx.backend.native.models.modules.norm import build_norm_layer
+from otx.backend.native.models.common.layers.transformer_layers import ListForwardMixin
 
 
 class SwiGLUFFN(nn.Module):
@@ -100,3 +101,31 @@ def __init__(
             out_dims=out_dims,
             bias=bias,
         )
+
+
+class SwiGLUFFNV2(nn.Module, ListForwardMixin):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: int | None = None,
+        out_features: int | None = None,
+        act_layer: Callable[..., nn.Module] | None = None,
+        drop: float = 0.0,
+        bias: bool = True,
+        align_to: int = 8,
+        device=None,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        d = int(hidden_features * 2 / 3)
+        swiglu_hidden_features = d + (-d % align_to)
+        self.w1 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
+        self.w2 = nn.Linear(in_features, swiglu_hidden_features, bias=bias, device=device)
+        self.w3 = nn.Linear(swiglu_hidden_features, out_features, bias=bias, device=device)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x1 = self.w1(x)
+        x2 = self.w2(x)
+        hidden = nn.functional.silu(x1) * x2
+        return self.w3(hidden)
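A brief usage sketch (not part of the commit) exercising the `SwiGLUFFNV2` added above; the dimensions are made up and the import path simply mirrors the file in this diff. The point of interest is how the gate width is derived: two thirds of `hidden_features`, rounded up to the next multiple of `align_to`, so the two gate projections `w1`/`w2` stay hardware friendly.

```python
import torch

from otx.backend.native.models.classification.utils.swiglu_ffn import SwiGLUFFNV2

ffn = SwiGLUFFNV2(in_features=384, hidden_features=1536, align_to=8)
# Gate width: int(1536 * 2 / 3) = 1024, already a multiple of 8, so w1/w2 map 384 -> 1024
# and w3 maps 1024 back to out_features (which defaults to in_features = 384).
x = torch.randn(2, 197, 384)  # (batch, tokens, channels)
y = ffn(x)                    # SiLU(w1(x)) * w2(x), then w3
assert y.shape == (2, 197, 384)
```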

library/src/otx/backend/native/models/common/layers/position_embed.py

Lines changed: 111 additions & 0 deletions
@@ -6,9 +6,12 @@
 from __future__ import annotations
 
 import math
+from typing import Literal
 
 import torch
 from torch import nn
+import numpy as np
+from torch import Tensor
 
 
 class PositionEmbeddingSine(nn.Module):
@@ -105,3 +108,111 @@ def gen_sineembed_for_position(pos_tensor: torch.Tensor) -> torch.Tensor:
         msg = f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}"
         raise ValueError(msg)
     return pos
+
+
+class RopePositionEmbedding(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        *,
+        num_heads: int,
+        base: float | None = 100.0,
+        min_period: float | None = None,
+        max_period: float | None = None,
+        normalize_coords: Literal["min", "max", "separate"] = "separate",
+        shift_coords: float | None = None,
+        jitter_coords: float | None = None,
+        rescale_coords: float | None = None,
+        dtype: torch.dtype | None = None,
+        device: torch.device | None = None,
+    ):
+        super().__init__()
+        assert embed_dim % (4 * num_heads) == 0
+        both_periods = min_period is not None and max_period is not None
+        if (base is None and not both_periods) or (base is not None and both_periods):
+            raise ValueError("Either `base` or `min_period`+`max_period` must be provided.")
+
+        D_head = embed_dim // num_heads
+        self.base = base
+        self.min_period = min_period
+        self.max_period = max_period
+        self.D_head = D_head
+        self.normalize_coords = normalize_coords
+        self.shift_coords = shift_coords
+        self.jitter_coords = jitter_coords
+        self.rescale_coords = rescale_coords
+
+        # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher
+        self.dtype = dtype  # Don't rely on self.periods.dtype
+        self.register_buffer(
+            "periods",
+            torch.empty(D_head // 4, device=device, dtype=dtype),
+            persistent=True,
+        )
+        self._init_weights()
+
+    def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]:
+        device = self.periods.device
+        dtype = self.dtype
+        dd = {"device": device, "dtype": dtype}
+
+        # Prepare coords in range [-1, +1]
+        if self.normalize_coords == "max":
+            max_HW = max(H, W)
+            coords_h = torch.arange(0.5, H, **dd) / max_HW  # [H]
+            coords_w = torch.arange(0.5, W, **dd) / max_HW  # [W]
+        elif self.normalize_coords == "min":
+            min_HW = min(H, W)
+            coords_h = torch.arange(0.5, H, **dd) / min_HW  # [H]
+            coords_w = torch.arange(0.5, W, **dd) / min_HW  # [W]
+        elif self.normalize_coords == "separate":
+            coords_h = torch.arange(0.5, H, **dd) / H  # [H]
+            coords_w = torch.arange(0.5, W, **dd) / W  # [W]
+        else:
+            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}")
+        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # [H, W, 2]
+        coords = coords.flatten(0, 1)  # [HW, 2]
+        coords = 2.0 * coords - 1.0  # Shift range [0, 1] to [-1, +1]
+
+        # Shift coords by adding a uniform value in [-shift, shift]
+        if self.training and self.shift_coords is not None:
+            shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords)
+            coords += shift_hw[None, :]
+
+        # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
+        if self.training and self.jitter_coords is not None:
+            jitter_max = np.log(self.jitter_coords)
+            jitter_min = -jitter_max
+            jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp()
+            coords *= jitter_hw[None, :]
+
+        # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
+        if self.training and self.rescale_coords is not None:
+            rescale_max = np.log(self.rescale_coords)
+            rescale_min = -rescale_max
+            rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp()
+            coords *= rescale_hw
+
+        # Prepare angles and sin/cos
+        angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]  # [HW, 2, D//4]
+        angles = angles.flatten(1, 2)  # [HW, D//2]
+        angles = angles.tile(2)  # [HW, D]
+        cos = torch.cos(angles)  # [HW, D]
+        sin = torch.sin(angles)  # [HW, D]
+
+        return (sin, cos)  # 2 * [HW, D]
+
+    def _init_weights(self):
+        device = self.periods.device
+        dtype = self.dtype
+        if self.base is not None:
+            periods = self.base ** (
+                2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2)
+            )  # [D//4]
+        else:
+            base = self.max_period / self.min_period
+            exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)  # [D//4] range [0, 1]
+            periods = base**exponents  # range [1, max_period / min_period]
+            periods = periods / base  # range [min_period / max_period, 1]
+            periods = periods * self.max_period  # range [min_period, max_period]
+        self.periods.data = periods
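A hedged usage sketch (not from this commit) of how the `(sin, cos)` pair returned by `RopePositionEmbedding` would typically be applied to per-head query/key tensors. The `rotate_half` helper below is an assumption consistent with the `angles.tile(2)` layout, where channel `i` is rotated together with channel `i + D_head/2`; it is not part of the diff, and the sizes are illustrative.

```python
import torch

from otx.backend.native.models.common.layers.position_embed import RopePositionEmbedding


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Pair channel i with channel i + D/2: (x1, x2) -> (-x2, x1)."""
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


rope = RopePositionEmbedding(embed_dim=384, num_heads=6)  # D_head = 64, divisible by 4
sin, cos = rope(H=16, W=16)                               # each of shape [HW, D_head] = [256, 64]
q = torch.randn(1, 6, 256, 64)                            # (batch, heads, HW, D_head)
q_rotated = q * cos + rotate_half(q) * sin                # broadcasts over batch and heads
```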

0 commit comments
