6 | 6 | from __future__ import annotations |
7 | 7 |
8 | 8 | import math |
| 9 | +from typing import Literal |
9 | 10 |
10 | 11 | import torch |
11 | 12 | from torch import nn |
| 13 | +import numpy as np |
| 14 | +from torch import Tensor |
12 | 15 |
13 | 16 |
14 | 17 | class PositionEmbeddingSine(nn.Module): |
@@ -105,3 +108,111 @@ def gen_sineembed_for_position(pos_tensor: torch.Tensor) -> torch.Tensor: |
105 | 108 |         msg = f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}" |
106 | 109 |         raise ValueError(msg) |
107 | 110 |     return pos |
| 111 | + |
| 112 | + |
| 113 | +class RopePositionEmbedding(nn.Module): |
| 114 | +    def __init__( |
| 115 | +        self, |
| 116 | +        embed_dim: int, |
| 117 | +        *, |
| 118 | +        num_heads: int, |
| 119 | +        base: float | None = 100.0, |
| 120 | +        min_period: float | None = None, |
| 121 | +        max_period: float | None = None, |
| 122 | +        normalize_coords: Literal["min", "max", "separate"] = "separate", |
| 123 | +        shift_coords: float | None = None, |
| 124 | +        jitter_coords: float | None = None, |
| 125 | +        rescale_coords: float | None = None, |
| 126 | +        dtype: torch.dtype | None = None, |
| 127 | +        device: torch.device | None = None, |
| 128 | +    ): |
| 129 | +        super().__init__() |
| 131 | +        assert embed_dim % (4 * num_heads) == 0, "embed_dim must be divisible by 4 * num_heads" |
| 132 | +        both_periods = min_period is not None and max_period is not None |
| 133 | +        if (base is None and not both_periods) or (base is not None and both_periods): |
| 134 | +            raise ValueError("Exactly one of `base` or `min_period`+`max_period` must be provided.") |
| 135 | + |
| 136 | +        D_head = embed_dim // num_heads |
| 137 | +        self.base = base |
| 138 | +        self.min_period = min_period |
| 139 | +        self.max_period = max_period |
| 140 | +        self.D_head = D_head |
| 141 | +        self.normalize_coords = normalize_coords |
| 142 | +        self.shift_coords = shift_coords |
| 143 | +        self.jitter_coords = jitter_coords |
| 144 | +        self.rescale_coords = rescale_coords |
| 145 | + |
| 146 | +        # Needs persistent=True because we do teacher.load_state_dict(student.state_dict()) to initialize the teacher |
| 147 | +        self.dtype = dtype  # Don't rely on self.periods.dtype |
| 148 | +        self.register_buffer( |
| 149 | +            "periods", |
| 150 | +            torch.empty(D_head // 4, device=device, dtype=dtype), |
| 151 | +            persistent=True, |
| 152 | +        ) |
| 153 | +        self._init_weights() |
| 154 | + |
| 155 | +    def forward(self, *, H: int, W: int) -> tuple[Tensor, Tensor]: |
| 156 | +        device = self.periods.device |
| 157 | +        dtype = self.dtype |
| 158 | +        dd = {"device": device, "dtype": dtype} |
| 159 | + |
| 160 | +        # Prepare coords in range [-1, +1] |
| 161 | +        if self.normalize_coords == "max": |
| 162 | +            max_HW = max(H, W) |
| 163 | +            coords_h = torch.arange(0.5, H, **dd) / max_HW  # [H] |
| 164 | +            coords_w = torch.arange(0.5, W, **dd) / max_HW  # [W] |
| 165 | +        elif self.normalize_coords == "min": |
| 166 | +            min_HW = min(H, W) |
| 167 | +            coords_h = torch.arange(0.5, H, **dd) / min_HW  # [H] |
| 168 | +            coords_w = torch.arange(0.5, W, **dd) / min_HW  # [W] |
| 169 | +        elif self.normalize_coords == "separate": |
| 170 | +            coords_h = torch.arange(0.5, H, **dd) / H  # [H] |
| 171 | +            coords_w = torch.arange(0.5, W, **dd) / W  # [W] |
| 172 | +        else: |
| 173 | +            raise ValueError(f"Unknown normalize_coords: {self.normalize_coords}") |
| 174 | +        coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)  # [H, W, 2] |
| 175 | +        coords = coords.flatten(0, 1)  # [HW, 2] |
| 176 | +        coords = 2.0 * coords - 1.0  # Shift range [0, 1] to [-1, +1] |
| 177 | + |
| 178 | +        # Shift coords by adding a uniform value in [-shift, shift] |
| 179 | +        if self.training and self.shift_coords is not None: |
| 180 | +            shift_hw = torch.empty(2, **dd).uniform_(-self.shift_coords, self.shift_coords) |
| 181 | +            coords += shift_hw[None, :] |
| 182 | + |
| 183 | +        # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter] |
| 184 | +        if self.training and self.jitter_coords is not None: |
| 185 | +            jitter_max = np.log(self.jitter_coords) |
| 186 | +            jitter_min = -jitter_max |
| 187 | +            jitter_hw = torch.empty(2, **dd).uniform_(jitter_min, jitter_max).exp() |
| 188 | +            coords *= jitter_hw[None, :] |
| 189 | + |
| 190 | +        # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale] |
| 191 | +        if self.training and self.rescale_coords is not None: |
| 192 | +            rescale_max = np.log(self.rescale_coords) |
| 193 | +            rescale_min = -rescale_max |
| 194 | +            rescale_hw = torch.empty(1, **dd).uniform_(rescale_min, rescale_max).exp() |
| 195 | +            coords *= rescale_hw |
| 196 | + |
| 197 | +        # Prepare angles and sin/cos |
| 198 | +        angles = 2 * math.pi * coords[:, :, None] / self.periods[None, None, :]  # [HW, 2, D//4] |
| 199 | +        angles = angles.flatten(1, 2)  # [HW, D//2] |
| 200 | +        angles = angles.tile(2)  # [HW, D] |
| 201 | +        cos = torch.cos(angles)  # [HW, D] |
| 202 | +        sin = torch.sin(angles)  # [HW, D] |
| 203 | + |
| 204 | +        return (sin, cos)  # 2 * [HW, D] |
| 204 | + |
| 206 | +    def _init_weights(self): |
| 207 | +        device = self.periods.device |
| 208 | +        dtype = self.dtype |
| 209 | +        if self.base is not None: |
| 210 | +            periods = self.base ** ( |
| 211 | +                2 * torch.arange(self.D_head // 4, device=device, dtype=dtype) / (self.D_head // 2) |
| 212 | +            )  # [D//4] |
| 213 | +        else: |
| 214 | +            base = self.max_period / self.min_period |
| 215 | +            exponents = torch.linspace(0, 1, self.D_head // 4, device=device, dtype=dtype)  # [D//4] range [0, 1] |
| 216 | +            periods = base**exponents  # range [1, max_period / min_period] |
| 217 | +            periods = periods / base  # range [min_period / max_period, 1] |
| 218 | +            periods = periods * self.max_period  # range [min_period, max_period] |
| 219 | +        self.periods.data = periods |
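For orientation, here is a minimal usage sketch of the class added above. The embed_dim, num_heads, and grid size are illustrative values, not ones taken from this patch; shapes follow the comments in forward().

# Minimal usage sketch (illustrative values; embed_dim must be divisible by 4 * num_heads).
rope = RopePositionEmbedding(embed_dim=768, num_heads=12, base=100.0)
rope.eval()  # the shift/jitter/rescale augmentations only apply in training mode, and only when set
sin, cos = rope(H=16, W=16)
# sin and cos are each [H*W, embed_dim // num_heads] = [256, 64]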
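The (sin, cos) pair is meant to be applied rotary-style to queries and keys inside attention. The consumer is not part of this diff, but given the tile(2) layout in forward() (the first and second halves of the head dimension carry the same angles), the conventional pairing would look like the sketch below; the helper names are placeholders, not something this patch adds.

def rope_rotate_half(x: Tensor) -> Tensor:
    # Pair feature i with feature i + D/2: (x1, x2) -> (-x2, x1).
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rope(x: Tensor, sin: Tensor, cos: Tensor) -> Tensor:
    # x: [..., HW, D_head] queries or keys; sin/cos: [HW, D_head] as returned by forward().
    return x * cos + rope_rotate_half(x) * sin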