Skip to content

Commit fa047b8

Browse files
committed
Update based on review discussion
1 parent 0c80c3f commit fa047b8

File tree

2 files changed

+74
-31
lines changed

2 files changed

+74
-31
lines changed

tests/layers/test_drop.py

Lines changed: 43 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@
1111
importlib.import_module(torch_backend)
1212
torch_device = os.environ.get('TORCH_DEVICE', 'cpu')
1313

14-
class ClipMaskTests(unittest.TestCase):
15-
def test_clip_mask_2d_odd(self):
16-
mask = drop.clip_mask_2d(h=5, w=7, kernel=3, device=torch_device)
14+
class Conv2dKernelMidpointMask2d(unittest.TestCase):
15+
def test_conv2d_kernel_midpoint_mask_odd_bool(self):
16+
mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(3, 3), device=torch_device)
1717
print(mask)
1818
assert mask.device == torch.device(torch_device)
1919
assert mask.tolist() == \
@@ -25,8 +25,44 @@ def test_clip_mask_2d_odd(self):
2525
[False, False, False, False, False, False, False],
2626
]
2727

28-
def test_clip_mask_2d_even(self):
29-
mask = drop.clip_mask_2d(h=5, w=7, kernel=2, device=torch_device)
28+
def test_conv2d_kernel_midpoint_mask_odd_float(self):
29+
mask = drop.conv2d_kernel_midpoint_mask(
30+
shape=(5, 7),
31+
kernel=(3, 3),
32+
device=torch_device,
33+
dtype=torch.float32,
34+
)
35+
print(mask)
36+
assert mask.device == torch.device(torch_device)
37+
assert mask.tolist() == \
38+
[
39+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
40+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
41+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
42+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0],
43+
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
44+
]
45+
46+
def test_conv2d_kernel_midpoint_mask_odd_int(self):
47+
mask = drop.conv2d_kernel_midpoint_mask(
48+
shape=(5, 7),
49+
kernel=(3, 3),
50+
device=torch_device,
51+
dtype=torch.int32,
52+
)
53+
print(mask)
54+
assert mask.device == torch.device(torch_device)
55+
assert mask.tolist() == \
56+
[
57+
[0, 0, 0, 0, 0, 0, 0],
58+
[0, 1, 1, 1, 1, 1, 0],
59+
[0, 1, 1, 1, 1, 1, 0],
60+
[0, 1, 1, 1, 1, 1, 0],
61+
[0, 0, 0, 0, 0, 0, 0],
62+
]
63+
64+
def test_conv2d_kernel_midpoint_mask_even(self):
65+
mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(2, 2), device=torch_device)
3066
print(mask)
3167
assert mask.device == torch.device(torch_device)
3268
# TODO: This is a surprising result; should even kernels be forbidden?
@@ -41,9 +77,9 @@ def test_clip_mask_2d_even(self):
4177

4278
def test_clip_mask_2d_kernel_too_big(self):
4379
try:
44-
drop.clip_mask_2d(h=4, w=7, kernel=5, device=torch_device)
80+
drop.conv2d_kernel_midpoint_mask(shape=(4, 7), kernel=(5, 5), device=torch_device)
4581
raise RuntimeError("Expected throw")
4682

4783
except AssertionError as e:
48-
assert "kernel=5 > min(h=4, w=7)" in e.args[0]
84+
assert "kernel=(5, 5) ! <= shape=(4, 7)" in e.args[0]
4985

timm/layers/drop.py

Lines changed: 31 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,38 +19,45 @@
1919
import torch.nn.functional as F
2020

2121

22-
def clip_mask_2d(
23-
h: int,
24-
w: int,
25-
kernel: int,
22+
def conv2d_kernel_midpoint_mask(
23+
shape: (int, int),
24+
kernel: (int, int),
2625
device,
26+
dtype = torch.bool,
2727
):
28-
"""Build a clip mask.
28+
"""Build a mask of kernel midpoints.
2929
30-
Returns a mask of all points which permit a (kernel, kernel) sized
31-
block to sit entirely within the (h, w) index space.
30+
This predicts the midpoints at which conv2d (and related kernel functions)
31+
would place a kernel.
3232
33-
Requires `kernel <= min(h, w)`.
33+
The *midpoint* of a kernel is computed as ``size / 2``:
34+
* the midpoint of odd kernels is the middle: `mid(3) == 1`
35+
* the midpoint of even kernels is the first point in the second half: `mid(4) == 2`
3436
35-
TODO: Should even kernels be forbidden?
36-
Even kernels behave oddly, but are not forbidden for historical reasons.
37+
Requires each kernel dimension to be `<=` the matching shape dimension.
3738
3839
Args:
39-
h: the height.
40-
w: the width.
41-
kernel_size: the size of the kernel.
40+
shape: the (h, w) shape of the tensor.
41+
kernel: the (kh, kw) shape of the kernel.
4242
device: the target device.
4343
check_kernel: when true, assert that the kernel_size is odd.
4444
4545
Returns:
4646
a (h, w) bool mask tensor.
4747
"""
48-
assert kernel <= min(h, w), f"{kernel=} > min({h=}, {w=})"
48+
h, w = shape
49+
kh, kw = kernel
50+
assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}"
51+
52+
mask = torch.zeros((h, w), dtype=dtype, device=device)
53+
54+
h_start = kh // 2
55+
h_end = (kh - 1) // 2
56+
57+
w_start = kw // 2
58+
w_end = (kw - 1) // 2
4959

50-
mask = torch.zeros((h, w), dtype=torch.bool, device=device)
51-
start = kernel // 2
52-
end = ((kernel - 1) // 2)
53-
mask[start:h-end, start:w-end] = True
60+
mask[h_start:h - h_end, w_start:w - w_end] = 1
5461
return mask
5562

5663

@@ -82,20 +89,20 @@ def drop_block_2d(
8289
B, C, H, W = x.shape
8390
total_size = W * H
8491

85-
# TODO: This behaves oddly when clipped_block_size < block_size, or block_size % 2 == 0.
92+
# TODO: This behaves oddly when clipped_block_size < block_size.
8693
clipped_block_size = min(block_size, W, H)
8794

8895
# seed_drop_rate, the gamma parameter
8996
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
9097
(W - block_size + 1) * (H - block_size + 1))
9198

9299
# Forces the block to be inside the feature map.
93-
valid_block = clip_mask_2d(
94-
h=H,
95-
w=W,
96-
kernel=clipped_block_size,
100+
valid_block = conv2d_kernel_midpoint_mask(
101+
shape=(H, W),
102+
kernel=(clipped_block_size, clipped_block_size),
97103
device=x.device,
98-
).reshape((1, 1, H, W)).to(dtype=x.dtype)
104+
dtype=x.dtype,
105+
).unsqueeze().unsqueeze()
99106

100107
if batchwise:
101108
# one mask for whole batch, quite a bit faster

0 commit comments

Comments
 (0)