Commit 81714c1

fix a bug in drop_filter vs keep_filter; even more tests
1 parent 2ffb37c commit 81714c1

2 files changed: +176 -92 lines changed


tests/layers/test_drop.py

Lines changed: 97 additions & 30 deletions
@@ -13,7 +13,12 @@
 
 class Conv2dKernelMidpointMask2d(unittest.TestCase):
     def test_conv2d_kernel_midpoint_mask_odd_bool(self):
-        mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(3, 3), device=torch_device)
+        mask = drop.conv2d_kernel_midpoint_mask(
+            shape=(5, 7),
+            kernel=(3, 3),
+            device=torch_device,
+            dtype=torch.bool,
+        )
         print(mask)
         assert mask.device == torch.device(torch_device)
         assert mask.tolist() == \
@@ -25,32 +30,6 @@ def test_conv2d_kernel_midpoint_mask_odd_bool(self):
                 [False, False, False, False, False, False, False],
             ]
 
-    def test_conv2d_kernel_midpoint_mask_odd_float_inplace(self):
-        mask = torch.tensor(
-            [
-                [2.0, 1.0, 1.0, 1.0, 1.0, 7.0, 1.0],
-                [1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 8.0],
-                [9.0, 1.0, 4.0, 1.0, 1.0, 1.0, 1.0],
-                [1.0, 1.0, 1.0, 5.0, 1.0, 1.0, 1.0],
-                [1.0, 1.0, 1.0, 1.0, 6.0, 1.0, 1.0],
-            ],
-            device=torch_device,
-        )
-        drop.conv2d_kernel_midpoint_mask(
-            kernel=(3, 3),
-            inplace_mask=mask,
-        )
-        print(mask)
-        assert mask.device == torch.device(torch_device)
-        assert mask.tolist() == \
-            [
-                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
-                [0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 0.0],
-                [0.0, 1.0, 4.0, 1.0, 1.0, 1.0, 0.0],
-                [0.0, 1.0, 1.0, 5.0, 1.0, 1.0, 0.0],
-                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
-            ]
-
     def test_conv2d_kernel_midpoint_mask_odd_float(self):
         mask = drop.conv2d_kernel_midpoint_mask(
             shape=(5, 7),
@@ -88,10 +67,14 @@ def test_conv2d_kernel_midpoint_mask_odd_int(self):
             ]
 
     def test_conv2d_kernel_midpoint_mask_even(self):
-        mask = drop.conv2d_kernel_midpoint_mask(shape=(5, 7), kernel=(2, 2), device=torch_device)
+        mask = drop.conv2d_kernel_midpoint_mask(
+            shape=(5, 7),
+            kernel=(2, 2),
+            device=torch_device,
+            dtype=torch.bool,
+        )
         print(mask)
         assert mask.device == torch.device(torch_device)
-        # TODO: This is a suprising result; should even kernels be forbidden?
         assert mask.tolist() == \
             [
                 [False, False, False, False, False, False, False],
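
Note on the even-kernel expectation above: with the slice bounds used by the new conv2d_kernel_midpoint_mask (see timm/layers/drop.py below), an even kernel dimension clears kh // 2 leading rows but only (kh - 1) // 2 trailing rows, so the valid-midpoint region is asymmetric. A minimal sketch of that arithmetic, separate from the committed code:

import torch

# Illustrative sketch only (not part of this commit): the midpoint region kept for
# kernel=(2, 2) on shape=(5, 7). kh // 2 == 1 but (kh - 1) // 2 == 0, so the first
# row/column are cleared while the last ones are kept -- the asymmetry the removed
# TODO called surprising.
h, w = 5, 7
kh, kw = 2, 2
mask = torch.zeros((h, w), dtype=torch.bool)
mask[kh // 2: h - ((kh - 1) // 2), kw // 2: w - ((kw - 1) // 2)] = True
print(mask[0].tolist())  # all False, matching the first expected row above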
@@ -103,9 +86,93 @@ def test_conv2d_kernel_midpoint_mask_even(self):
 
     def test_clip_mask_2d_kernel_too_big(self):
         try:
-            drop.conv2d_kernel_midpoint_mask(shape=(4, 7), kernel=(5, 5), device=torch_device)
+            drop.conv2d_kernel_midpoint_mask(
+                shape=(4, 7),
+                kernel=(5, 5),
+                device=torch_device,
+                dtype=torch.bool,
+            )
            raise RuntimeError("Expected throw")
 
        except AssertionError as e:
            assert "kernel=(5, 5) ! <= shape=(4, 7)" in e.args[0]
 
+
+class DropBlock2dDropFilterTest(unittest.TestCase):
+    def test_drop_filter(self):
+        selection = torch.tensor(
+            [
+                [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0],
+                [0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0],
+            ],
+            device=torch_device,
+        ).unsqueeze(0).unsqueeze(0)
+
+        result = drop.drop_block_2d_drop_filter_(
+            selection=selection,
+            kernel=(2, 3),
+            messy=False
+        ).squeeze()
+        print(result)
+        assert result.device == torch.device(torch_device)
+        assert result.tolist() == \
+            [
+                [1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
+                [1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
+                [0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            ]
+
+    def test_drop_filter_messy(self):
+        selection = torch.tensor(
+            [
+                [0, 0, 0, 1, 0, 0, 0],
+                [0, 1, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 1, 0],
+                [0, 0, 0, 0, 0, 0, 0],
+                [0, 0, 0, 0, 0, 0, 1],
+            ],
+            device=torch_device,
+            dtype=torch.int32,
+        ).unsqueeze(0).unsqueeze(0)
+
+        result = drop.drop_block_2d_drop_filter_(
+            selection=selection,
+            kernel=(2, 3),
+            messy=True
+        ).squeeze()
+        print(result)
+        assert result.device == torch.device(torch_device)
+        assert result.tolist() == \
+            [
+                [1, 1, 1, 1, 1, 0, 0],
+                [1, 1, 1, 0, 1, 1, 1],
+                [0, 0, 0, 0, 1, 1, 1],
+                [0, 0, 0, 0, 0, 1, 1],
+                [0, 0, 0, 0, 0, 1, 1],
+            ]
+
+class DropBlock2dTest(unittest.TestCase):
+    def test_drop_block_2d(self):
+        tensor = torch.ones((1, 1, 200, 300), device=torch_device)
+
+        drop_prob=0.1
+        keep_prob = 1.0 - drop_prob
+
+        result = drop.drop_block_2d(
+            tensor,
+            drop_prob=drop_prob,
+            with_noise=True,
+        ).squeeze()
+
+        numel = float(result.numel())
+        unchanged = float(len(result[result == 1.0]))
+        keep_ratio = unchanged / numel
+
+        assert abs(keep_ratio - keep_prob) < 0.05, \
+            f"abs({keep_ratio=} - {keep_prob=}) ! < 0.05"
+
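
The DropBlock2dTest added above asserts a statistical property rather than exact values: with `with_noise=True`, dropped positions are overwritten with Gaussian noise, which is almost surely not exactly 1.0, so the fraction of entries still equal to 1.0 estimates `keep_prob`. A back-of-envelope sketch of the same bookkeeping using a plain elementwise Bernoulli mask (illustrative only, not the block-structured path this commit changes):

import torch

# Elementwise (non-block) analogue of the keep-ratio bookkeeping in
# test_drop_block_2d: ~10% of entries become noise, so ~90% stay at exactly 1.0.
x = torch.ones(200, 300)
keep = (torch.rand_like(x) > 0.1).to(x.dtype)
y = x * keep + torch.randn_like(x) * (1.0 - keep)
keep_ratio = (y == 1.0).float().mean().item()
print(keep_ratio)  # ~0.9, the statistic the test checks to within 0.05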

timm/layers/drop.py

Lines changed: 79 additions & 62 deletions
@@ -21,12 +21,11 @@
 
 
 def conv2d_kernel_midpoint_mask(
-        kernel: Tuple[int, int],
         *,
-        inplace_mask = None,
-        shape: Optional[Tuple[int, int]] = None,
-        device = None,
-        dtype = None,
+        shape: Tuple[int, int],
+        kernel: Tuple[int, int],
+        device,
+        dtype,
 ):
     """Build a mask of kernel midpoints.
 
@@ -39,54 +38,69 @@ def conv2d_kernel_midpoint_mask(
 
     Requires `kernel <= min(h, w)`.
 
-    When an `inplace_mask` is not provided, a new mask of `1`s is allocated,
-    and then the `0` locations are cleared.
-
-    When an `inplace_mask` is provided, the `0` locations are cleared on the mask,
-    and no other changes are made. `shape`, `dtype`, and `device` must match, if
-    they are provided.
+    A new mask of `1`s is allocated, and then the `0` locations are cleared.
 
     Args:
         kernel: the (kh, hw) shape of the kernel.
-        inplace_mask: if supplied, updates will apply to the inplace_mask,
-            and device and dtype will be ignored. Only clears 'false' locations.
         shape: the (h, w) shape of the tensor.
         device: the target device.
         dtype: the target dtype.
 
     Returns:
         a (h, w) bool mask tensor.
     """
-    if inplace_mask is not None:
-        mask = inplace_mask
+    h, w = shape
+    kh, kw = kernel
+    assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}"
 
-        if shape:
-            assert shape == mask.shape[-2], f"{shape=} !~= {mask.shape=}"
+    mask = torch.zeros(shape, dtype=dtype, device=device)
 
-        shape = mask.shape
+    mask[kh//2: h - ((kh - 1) // 2), kw//2: w - ((kw - 1) // 2)] = 1.0
 
-        if device:
-            device = torch.device(device)
-            assert device == mask.device, f"{device=} != {mask.device=}"
+    return mask
 
-        if dtype:
-            dtype = torch.dtype(dtype)
-            assert dtype == inplace_mask.dtype, f"{dtype=} != {mask.dtype=}"
 
-    else:
-        mask = torch.ones(shape, dtype=dtype, device=device)
+def drop_block_2d_drop_filter_(
+        *,
+        selection,
+        kernel: Tuple[int, int],
+        messy: bool
+):
+    """Convert drop block gamma noise to a drop filter.
+
+    This is a deterministic internal component of drop_block_2d.
+
+    Args:
+        selection: 4D (B, C, H, W) input selection noise;
+            `1.0` at the midpoints of selected blocks to drop,
+            `0.0` everywhere else. Expected to be gamma noise.
+        kernel: the shape of the 2d kernel.
+        messy: permit partial blocks at the edges, faster.
+
+    Returns:
+        A drop filter, `1.0` at points to drop, `0.0` at points to keep.
+    """
+
+    if not messy:
+        selection = selection * conv2d_kernel_midpoint_mask(
+            shape=selection.shape[-2:],
+            kernel=kernel,
+            dtype=selection.dtype,
+            device=selection.device,
+        )
 
-    h, w = shape
     kh, kw = kernel
-    assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}"
 
-    # Set to 0, rather than set to 1, so we can clear the inplace mask.
-    mask[:kh // 2, :] = 0
-    mask[h - (kh - 1) // 2:, :] = 0
-    mask[:, :kw // 2] = 0
-    mask[:, w - (kw - 1) // 2:] = 0
+    drop_filter = F.max_pool2d(
+        selection,
+        kernel_size=kernel,
+        stride=1,
+        padding=[kh // 2, kw // 2],
+    )
+    if (kh % 2 == 0) or (kw % 2 == 0):
+        drop_filter = drop_filter[..., (kh%2==0):, (kw%2==0):]
 
-    return mask
+    return drop_filter
 
 
 def drop_block_2d(
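
Taken together, the two functions above split DropBlock into a midpoint mask (which keeps gamma noise away from positions where a full block would not fit) and a max-pool that dilates each selected midpoint into a kernel-sized block. A minimal usage sketch mirroring the new tests; the import path is an assumption, matching how the test module refers to `drop`:

import torch

from timm.layers import drop  # import path assumed; the tests use the module as `drop`

# One selected midpoint at (1, 1) on a 5x7 map, kernel (2, 3), strict (non-messy) mode.
selection = torch.zeros(1, 1, 5, 7)
selection[0, 0, 1, 1] = 1.0

drop_filter = drop.drop_block_2d_drop_filter_(
    selection=selection,
    kernel=(2, 3),
    messy=False,
)
print(drop_filter.squeeze())  # a single 2x3 block of 1.0s, as in test_drop_filter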
@@ -117,57 +131,60 @@ def drop_block_2d(
         If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device.
     """
     B, C, H, W = x.shape
-    total_size = W * H
 
     # TODO: This behaves oddly when clipped_block_size < block_size.
-    clipped_block_size = min(block_size, H, W)
+    kh = kw = block_size
+
+    kernel = [min(kh, H), min(kw, W)]
+    kh, kw = kernel
 
     gamma = (
-        float(gamma_scale * drop_prob * total_size)
-        / float(clipped_block_size ** 2)
-        / float((H - block_size + 1) * (W - block_size + 1))
+        float(gamma_scale * drop_prob * H * W)
+        / float(kh * kw)
+        / float((H - kh + 1) * (W - kw + 1))
     )
 
     # batchwise => one mask for whole batch, quite a bit faster
     mask_shape = (1 if batchwise else B, C, H, W)
 
-    block_mask = torch.empty(
+    selection = torch.empty(
         mask_shape,
         dtype=x.dtype,
         device=x.device
     ).bernoulli_(gamma)
 
-    if not messy:
-        conv2d_kernel_midpoint_mask(
-            kernel=(clipped_block_size, clipped_block_size),
-            inplace_mask=block_mask,
-        )
-
-    block_mask = F.max_pool2d(
-        block_mask,
-        kernel_size=clipped_block_size,
-        stride=1,
-        padding=clipped_block_size // 2)
+    drop_filter = drop_block_2d_drop_filter_(
+        selection=selection,
+        kernel=kernel,
+        messy=messy,
+    )
+    keep_filter = 1.0 - drop_filter
 
     if inplace:
-        x.mul_(block_mask)
+        x.mul_(keep_filter)
     else:
-        x = x * block_mask
-
-        # From this point on, we do inplace ops on X.
+        x = x * keep_filter
 
     if with_noise:
-        noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device)
         # x += (noise * (1 - block_mask))
-        block_mask.neg_().add_(1)
-        noise.mul_(block_mask)
-        x.add_(noise)
+        noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device)
+
+        if inplace:
+            noise.mul_(drop_filter)
+            x.add_(noise)
+        else:
+            x = x + noise * drop_filter
 
     else:
         # x *= (size(block_mask) / sum(block_mask))
-        total = block_mask.to(dtype=torch.float32).sum()
-        normalize_scale = block_mask.numel() / total.add(1e-7).to(x.dtype)
-        x.mul_(normalize_scale)
+        count = keep_filter.numel()
+        total = keep_filter.to(dtype=torch.float32).sum()
+        normalize_scale = count / total.add(1e-7).to(x.dtype)
+
+        if inplace:
+            x.mul_(normalize_scale)
+        else:
+            x = x * normalize_scale
 
     return x
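
The rewritten gamma keeps the expected dropped fraction close to `drop_prob`: each of the `(H - kh + 1) * (W - kw + 1)` positions where a full block fits is selected with probability gamma, and each selection drops roughly `kh * kw` elements, so the expected dropped fraction is `gamma * kh * kw * (H - kh + 1) * (W - kw + 1) / (H * W) = gamma_scale * drop_prob`, ignoring block overlap and edge clipping. A small numeric check under those simplifying assumptions (the block_size value here is illustrative, not taken from this diff):

# Back-of-envelope check of the gamma expression above (illustrative values).
H, W = 200, 300
kh = kw = 7              # assumed block_size for illustration
drop_prob, gamma_scale = 0.1, 1.0

gamma = (
    float(gamma_scale * drop_prob * H * W)
    / float(kh * kw)
    / float((H - kh + 1) * (W - kw + 1))
)

# Expected dropped fraction, ignoring overlap between blocks and edge effects:
expected_drop = gamma * kh * kw * (H - kh + 1) * (W - kw + 1) / (H * W)
print(f"{gamma=:.3e} {expected_drop=:.3f}")  # expected_drop == drop_prob by construction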
