Skip to content

Commit 81ec5a1

Browse files
committed
Unfold not needed with stride=kernel_size, removed unnecessary permutes for a speedup
1 parent dd57311 commit 81ec5a1

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

timm/models/csatv2.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -211,7 +211,6 @@ def __init__(
211211
dd = dict(device=device, dtype=dtype)
212212
super().__init__()
213213
self.k = kernel_size
214-
self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
215214
self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
216215
self.permutation = _zigzag_permutation(kernel_size, kernel_size)
217216

@@ -222,19 +221,21 @@ def __init__(
222221

223222
self.register_buffer('mean', torch.tensor(_DCT_MEAN, device=device), persistent=False)
224223
self.register_buffer('var', torch.tensor(_DCT_VAR, device=device), persistent=False)
225-
self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406], device=device), persistent=False)
226-
self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225], device=device), persistent=False)
224+
# Shape (3, 1, 1) for BCHW broadcasting
225+
self.register_buffer('imagenet_mean', torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1), persistent=False)
226+
self.register_buffer('imagenet_std', torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1), persistent=False)
227227

228228
def _denormalize(self, x: torch.Tensor) -> torch.Tensor:
229229
"""Convert from ImageNet normalized to [0, 255] range."""
230230
return x.mul(self.imagenet_std).add_(self.imagenet_mean) * 255
231231

232232
def _rgb_to_ycbcr(self, x: torch.Tensor) -> torch.Tensor:
233-
"""Convert RGB to YCbCr color space."""
234-
y = (x[:, :, :, 0] * 0.299) + (x[:, :, :, 1] * 0.587) + (x[:, :, :, 2] * 0.114)
235-
cb = 0.564 * (x[:, :, :, 2] - y) + 128
236-
cr = 0.713 * (x[:, :, :, 0] - y) + 128
237-
return torch.stack([y, cb, cr], dim=-1)
233+
"""Convert RGB to YCbCr color space (BCHW input/output)."""
234+
r, g, b = x[:, 0], x[:, 1], x[:, 2]
235+
y = r * 0.299 + g * 0.587 + b * 0.114
236+
cb = 0.564 * (b - y) + 128
237+
cr = 0.713 * (r - y) + 128
238+
return torch.stack([y, cb, cr], dim=1)
238239

239240
def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor:
240241
"""Normalize DCT coefficients using precomputed statistics."""
@@ -243,12 +244,11 @@ def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor:
243244

244245
def forward(self, x: torch.Tensor) -> torch.Tensor:
245246
b, c, h, w = x.shape
246-
x = x.permute(0, 2, 3, 1)
247247
x = self._denormalize(x)
248248
x = self._rgb_to_ycbcr(x)
249-
x = x.permute(0, 3, 1, 2)
250-
x = self.unfold(x).transpose(-1, -2)
251-
x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
249+
# Extract non-overlapping k x k patches
250+
x = x.reshape(b, c, h // self.k, self.k, w // self.k, self.k) # (B, C, H//k, k, W//k, k)
251+
x = x.permute(0, 2, 4, 1, 3, 5) # (B, H//k, W//k, C, k, k)
252252
x = self.transform(x)
253253
x = x.reshape(-1, c, self.k * self.k)
254254
x = x[:, :, self.permutation]
@@ -275,14 +275,14 @@ def __init__(
275275
dd = dict(device=device, dtype=dtype)
276276
super().__init__()
277277
self.k = kernel_size
278-
self.unfold = nn.Unfold(kernel_size=(kernel_size, kernel_size), stride=(kernel_size, kernel_size))
279278
self.transform = Dct2d(kernel_size, kernel_type, orthonormal, **dd)
280279
self.permutation = _zigzag_permutation(kernel_size, kernel_size)
281280

282281
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
283282
b, c, h, w = x.shape
284-
x = self.unfold(x).transpose(-1, -2)
285-
x = x.reshape(b, h // self.k, w // self.k, c, self.k, self.k)
283+
# Extract non-overlapping k x k patches
284+
x = x.reshape(b, c, h // self.k, self.k, w // self.k, self.k) # (B, C, H//k, k, W//k, k)
285+
x = x.permute(0, 2, 4, 1, 3, 5) # (B, H//k, W//k, C, k, k)
286286
x = self.transform(x)
287287
x = x.reshape(-1, c, self.k * self.k)
288288
x = x[:, :, self.permutation]

0 commit comments

Comments
 (0)