Commit 03c57a4
Fix bug in timm.layers.drop.drop_block_2d when H != W.
There are two bugs in the `valid_block` code for `drop_block_2d`:

- a (W, H) grid being reshaped as (H, W)

The current code uses (W, H) to generate the meshgrid, but then uses a `.reshape((1, 1, H, W))` to unsqueeze the block map.

The simplest fix to the first bug is a one-line change:

```python
h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
```

This is a longer patch that attempts to make the code testable.

Note: The current code behaves oddly when the block_size or clipped_block_size is even; I've added tests exposing the behavior, but have not changed it.

When you trigger the reshape bug, you get wild results:

```
$ python scratch.py
{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': False}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
          [ True,  True, False, False,  True],
          [ True, False, False,  True,  True],
          [False, False, False, False, False]]]])

{'H': 4, 'W': 5, 'block_size': 3, 'fix_reshape': True}
grid.shape=torch.Size([1, 1, 4, 5])
tensor([[[[False, False, False, False, False],
          [False,  True,  True,  True, False],
          [False,  True,  True,  True, False],
          [False, False, False, False, False]]]])
```

Here's a tiny excerpt script showing the problem; it generated the above output.

```python
import torch
from typing import Tuple


def ndgrid(*tensors) -> Tuple[torch.Tensor, ...]:
    """Generate an N-D grid in dimension order.

    The ndgrid function is like meshgrid except that the order of the first
    two input arguments is switched. That is, the statement

        [X1, X2, X3] = ndgrid(x1, x2, x3)

    produces the same result as

        [X2, X1, X3] = meshgrid(x2, x1, x3)

    This naming is based on MATLAB; the purpose is to avoid confusion caused by
    torch changing torch.meshgrid behaviour from matching ndgrid ('ij') indexing
    to the numpy meshgrid default of ('xy').
    """
    try:
        return torch.meshgrid(*tensors, indexing='ij')
    except TypeError:
        # Old PyTorch < 1.10 will follow this path as it does not have the
        # indexing arg; the old behaviour of meshgrid was 'ij'.
        return torch.meshgrid(*tensors)


def valid_block(H, W, block_size, fix_reshape=False):
    clipped_block_size = min(block_size, H, W)

    if fix_reshape:
        # This should match the .reshape() dimension order below.
        h_i, w_i = ndgrid(torch.arange(H), torch.arange(W))
    else:
        # The original produces crazy stride patterns, due to .reshape() offset winding.
        # This is only visible when H != W.
        w_i, h_i = ndgrid(torch.arange(W), torch.arange(H))

    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))

    valid_block = torch.reshape(valid_block, (1, 1, H, W))
    return valid_block


def main():
    common_args = dict(H=4, W=5, block_size=3)
    for fix in [False, True]:
        args = dict(**common_args, fix_reshape=fix)
        grid = valid_block(**args)
        print(args)
        print(f"{grid.shape=}")
        print(grid)
        print()


if __name__ == "__main__":
    main()
```
1 parent 954613a commit 03c57a4
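As an aside (not part of the commit message), the core mechanism of the reshape bug can be shown in a couple of lines: `.reshape` rewinds the flat row-major storage rather than transposing, so a (W, H) mask reshaped to (H, W) has its entries land at the wrong (h, w) positions whenever H != W. A minimal sketch, using an assumed toy example:

```python
import torch

# Minimal illustration (assumed example, not from the commit):
# reshape rewinds memory order; it is not a transpose.
H, W = 2, 3
grid_wh = torch.arange(W * H).reshape(W, H)  # shape (W, H) == (3, 2)

print(grid_wh.reshape(H, W))  # tensor([[0, 1, 2], [3, 4, 5]]) -- flat order rewound into (H, W)
print(grid_wh.t())            # tensor([[0, 2, 4], [1, 3, 5]]) -- the transpose the code actually needed
```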

File tree

3 files changed: +106 -6 lines changed


tests/layers/__init__.py

Whitespace-only changes.

tests/layers/test_drop.py

Lines changed: 47 additions & 0 deletions
```python
import importlib
import os
import unittest

import torch

from timm.layers import drop

torch_backend = os.environ.get('TORCH_BACKEND')
if torch_backend is not None:
    importlib.import_module(torch_backend)
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')


class ClipMaskTests(unittest.TestCase):
    def test_clip_mask_2d_odd(self):
        mask = drop.clip_mask_2d(h=5, w=7, kernel=3, device=torch_device)
        assert mask.device == torch.device(torch_device)
        assert mask.tolist() == \
            [
                [False, False, False, False, False, False, False],
                [False, True, True, True, True, True, False],
                [False, True, True, True, True, True, False],
                [False, True, True, True, True, True, False],
                [False, False, False, False, False, False, False],
            ]

    def test_clip_mask_2d_even(self):
        mask = drop.clip_mask_2d(h=5, w=7, kernel=2, device=torch_device)
        assert mask.device == torch.device(torch_device)
        # TODO: This is a surprising result; should even kernels be forbidden?
        assert mask.tolist() == \
            [
                [False, False, False, False, False, False, False],
                [False, True, True, True, True, True, True],
                [False, True, True, True, True, True, True],
                [False, True, True, True, True, True, True],
                [False, True, True, True, True, True, True],
            ]

    def test_clip_mask_2d_kernel_too_big(self):
        try:
            drop.clip_mask_2d(h=4, w=7, kernel=5, device=torch_device)
            raise RuntimeError("Expected throw")

        except AssertionError as e:
            assert "kernel=5 > min(h=4, w=7)" in e.args[0]
```
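For reference, one way these tests might be run locally (the exact command is an assumption about a standard pytest setup and is not specified by the commit; `TORCH_DEVICE` is the environment variable read at the top of the test file):

```
$ TORCH_DEVICE=cpu python -m pytest tests/layers/test_drop.py
```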

timm/layers/drop.py

Lines changed: 59 additions & 6 deletions
```diff
@@ -21,6 +21,43 @@
 from .grid import ndgrid
 
 
+def clip_mask_2d(
+        h: int,
+        w: int,
+        kernel: int,
+        device,
+):
+    """Build a clip mask.
+
+    Returns a mask of all points which permit a (kernel, kernel) sized
+    block to sit entirely within the (h, w) index space.
+
+    Requires `kernel <= min(h, w)`.
+
+    TODO: Should even kernels be forbidden?
+    Even kernels behave oddly, but are not forbidden for historical reasons.
+
+    Args:
+        h: the height.
+        w: the width.
+        kernel_size: the size of the kernel.
+        device: the target device.
+        check_kernel: when true, assert that the kernel_size is odd.
+
+    Returns:
+        a (h, w) bool mask tensor.
+    """
+    assert kernel <= min(h, w), f"{kernel=} > min({h=}, {w=})"
+
+    h_i, w_i = ndgrid(torch.arange(h, device=device), torch.arange(w, device=device))
+    return (
+        (h_i >= kernel // 2) &
+        (h_i < h - (kernel - 1) // 2) &
+        (w_i >= kernel // 2) &
+        (w_i < w - (kernel - 1) // 2)
+    ).reshape(h, w)
+
+
 def drop_block_2d(
         x,
         drop_prob: float = 0.1,
@@ -30,23 +67,39 @@ def drop_block_2d(
         inplace: bool = False,
         batchwise: bool = False
 ):
-    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
     runs with success, but needs further validation and possibly optimization for lower runtime impact.
+
+    Args:
+        drop_prob: the probability of dropping any given block.
+        block_size: the size of the dropped blocks; should be odd.
+        gamma_scale: adjustment scale for the drop_prob.
+        with_noise: should normal noise be added to the dropped region?
+        inplace: if the drop should be applied in-place on the input tensor.
+        batchwise: should the entire batch use the same drop mask?
+
+    Returns:
+        If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device.
     """
     B, C, H, W = x.shape
     total_size = W * H
-    clipped_block_size = min(block_size, min(W, H))
+
+    # TODO: This behaves oddly when clipped_block_size < block_size, or block_size % 2 == 0.
+    clipped_block_size = min(block_size, W, H)
+
     # seed_drop_rate, the gamma parameter
     gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
         (W - block_size + 1) * (H - block_size + 1))
 
     # Forces the block to be inside the feature map.
-    w_i, h_i = ndgrid(torch.arange(W, device=x.device), torch.arange(H, device=x.device))
-    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
-                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
-    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+    valid_block = clip_mask_2d(
+        h=H,
+        w=W,
+        kernel=clipped_block_size,
+        device=x.device,
+    ).reshape((1, 1, H, W)).to(dtype=x.dtype)
 
     if batchwise:
         # one mask for whole batch, quite a bit faster
```
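To round out the picture, here is a small sanity-check sketch (an assumed usage example, not part of the commit) exercising the patched path on a non-square feature map; it assumes the patch above is applied so that `clip_mask_2d` is importable from `timm.layers.drop`, and the tensor sizes are illustrative:

```python
import torch
from timm.layers import drop

# Assumed usage sketch: drop_block_2d on a feature map with H != W.
x = torch.randn(2, 8, 4, 5)  # (B, C, H, W) with H=4, W=5

# With the fix, the valid-block mask is oriented correctly, so dropped
# blocks stay inside the map; the output shape always matches the input.
y = drop.drop_block_2d(x, drop_prob=0.5, block_size=3)
print(y.shape)  # torch.Size([2, 8, 4, 5])

# The new helper can also be inspected on its own; 1s mark the centres where
# a 3x3 block fits entirely inside the 4x5 map (cf. the fixed output above).
mask = drop.clip_mask_2d(h=4, w=5, kernel=3, device='cpu')
print(mask.int())
```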
