Commit 2ffb37c

Unify drop_block_2d / drop_block_fast_2d; add some actual tests
1 parent fa047b8 commit 2ffb37c

File tree: 4 files changed, +220 -84 lines

tests/layers/test_drop.py

Lines changed: 26 additions & 0 deletions

@@ -25,6 +25,32 @@ def test_conv2d_kernel_midpoint_mask_odd_bool(self):
             [False, False, False, False, False, False, False],
         ]
 
+    def test_conv2d_kernel_midpoint_mask_odd_float_inplace(self):
+        mask = torch.tensor(
+            [
+                [2.0, 1.0, 1.0, 1.0, 1.0, 7.0, 1.0],
+                [1.0, 3.0, 1.0, 1.0, 1.0, 1.0, 8.0],
+                [9.0, 1.0, 4.0, 1.0, 1.0, 1.0, 1.0],
+                [1.0, 1.0, 1.0, 5.0, 1.0, 1.0, 1.0],
+                [1.0, 1.0, 1.0, 1.0, 6.0, 1.0, 1.0],
+            ],
+            device=torch_device,
+        )
+        drop.conv2d_kernel_midpoint_mask(
+            kernel=(3, 3),
+            inplace_mask=mask,
+        )
+        print(mask)
+        assert mask.device == torch.device(torch_device)
+        assert mask.tolist() == \
+            [
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+                [0.0, 3.0, 1.0, 1.0, 1.0, 1.0, 0.0],
+                [0.0, 1.0, 4.0, 1.0, 1.0, 1.0, 0.0],
+                [0.0, 1.0, 1.0, 5.0, 1.0, 1.0, 0.0],
+                [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+            ]
+
     def test_conv2d_kernel_midpoint_mask_odd_float(self):
         mask = drop.conv2d_kernel_midpoint_mask(
             shape=(5, 7),
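
For orientation (an illustration, not part of the commit): the rule these tests
encode is that a (kh, kw) kernel centered at (i, j) fits entirely inside an
(h, w) map iff `kh // 2 <= i < h - (kh - 1) // 2` and `kw // 2 <= j < w - (kw - 1) // 2`.
A minimal standalone sketch, with `midpoint_fits` as a hypothetical helper name:

    # Sketch only: the fits-inside rule the tests above exercise.
    def midpoint_fits(i: int, j: int, h: int, w: int, kh: int, kw: int) -> bool:
        return (kh // 2 <= i < h - (kh - 1) // 2) and (kw // 2 <= j < w - (kw - 1) // 2)

    # For a (5, 7) mask and a 3x3 kernel, exactly the border ring fails the
    # rule, which is why the expected outputs above clear row 0, row 4,
    # column 0, and column 6.
    keep = [[1.0 if midpoint_fits(i, j, 5, 7, 3, 3) else 0.0 for j in range(7)]
            for i in range(5)]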

tests/test_models.py

Lines changed: 2 additions & 2 deletions

@@ -204,7 +204,7 @@ def test_model_forward(model_name, batch_size):
 @pytest.mark.parametrize('model_name', list_models(exclude_filters=EXCLUDE_FILTERS, name_matches_cfg=True))
 @pytest.mark.parametrize('batch_size', [2])
 def test_model_backward(model_name, batch_size):
-    """Run a single forward pass with each model"""
+    """Run a single forward and backward pass with each model"""
     input_size = _get_input_size(model_name=model_name, target=TARGET_BWD_SIZE)
     if max(input_size) > MAX_BWD_SIZE:
         pytest.skip("Fixed input size model > limit.")
@@ -594,7 +594,7 @@ def _create_fx_model(model, train=False):
     return fx_model
 
 
-EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*']
+EXCLUDE_FX_FILTERS = ['vit_gi*', 'hiera*', '*dropblock*']
 # not enough memory to run fx on more models than other tests
 if 'GITHUB_ACTIONS' in os.environ:
     EXCLUDE_FX_FILTERS += [

timm/layers/drop.py

Lines changed: 125 additions & 77 deletions

@@ -14,16 +14,19 @@
 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 def conv2d_kernel_midpoint_mask(
-        shape: (int, int),
-        kernel: (int, int),
-        device,
-        dtype = torch.bool,
+        kernel: Tuple[int, int],
+        *,
+        inplace_mask = None,
+        shape: Optional[Tuple[int, int]] = None,
+        device = None,
+        dtype = None,
 ):
     """Build a mask of kernel midpoints.
 
@@ -36,28 +39,53 @@ def conv2d_kernel_midpoint_mask(
 
     Requires `kernel <= min(h, w)`.
 
+    When an `inplace_mask` is not provided, a new mask of `1`s is allocated,
+    and then the `0` locations are cleared.
+
+    When an `inplace_mask` is provided, the `0` locations are cleared on the mask,
+    and no other changes are made. `shape`, `dtype`, and `device` must match, if
+    they are provided.
+
     Args:
-        shape: the (h, w) shape of the tensor.
         kernel: the (kh, kw) shape of the kernel.
+        inplace_mask: if supplied, updates apply to the inplace_mask, and
+            device and dtype are ignored. Only clears locations to `0`.
+        shape: the (h, w) shape of the tensor.
         device: the target device.
-        check_kernel: when true, assert that the kernel_size is odd.
+        dtype: the target dtype.
 
     Returns:
         a (h, w) mask tensor.
     """
+    if inplace_mask is not None:
+        mask = inplace_mask
+
+        if shape:
+            assert shape == mask.shape[-2:], f"{shape=} !~= {mask.shape=}"
+
+        shape = mask.shape[-2:]
+
+        if device:
+            device = torch.device(device)
+            assert device == mask.device, f"{device=} != {mask.device=}"
+
+        if dtype:
+            assert dtype == mask.dtype, f"{dtype=} != {mask.dtype=}"
+
+    else:
+        mask = torch.ones(shape, dtype=dtype, device=device)
+
     h, w = shape
     kh, kw = kernel
     assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}"
 
-    mask = torch.zeros((h, w), dtype=dtype, device=device)
+    # Clear to 0 (rather than fill the interior with 1) so an inplace mask is
+    # only ever cleared. `...` keeps this correct for (H, W) and (B, C, H, W).
+    mask[..., :kh // 2, :] = 0
+    mask[..., h - (kh - 1) // 2:, :] = 0
+    mask[..., :kw // 2] = 0
+    mask[..., w - (kw - 1) // 2:] = 0
 
-    h_start = kh // 2
-    h_end = (kh - 1) // 2
-
-    w_start = kw // 2
-    w_end = (kw - 1) // 2
-
-    mask[h_start:h - h_end, w_start:w - w_end] = 1
     return mask
 
 
@@ -68,7 +96,8 @@ def drop_block_2d(
         gamma_scale: float = 1.0,
         with_noise: bool = False,
         inplace: bool = False,
-        batchwise: bool = False
+        batchwise: bool = False,
+        messy: bool = False,
 ):
     """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
@@ -82,6 +111,7 @@ def drop_block_2d(
         with_noise: should normal noise be added to the dropped region?
        inplace: if the drop should be applied in-place on the input tensor.
         batchwise: should the entire batch use the same drop mask?
+        messy: allow partial blocks at the edges; faster.
 
     Returns:
         If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device.
@@ -90,44 +120,55 @@
     total_size = W * H
 
     # TODO: This behaves oddly when clipped_block_size < block_size.
-    clipped_block_size = min(block_size, W, H)
+    clipped_block_size = min(block_size, H, W)
+
+    gamma = (
+        float(gamma_scale * drop_prob * total_size)
+        / float(clipped_block_size ** 2)
+        / float((H - block_size + 1) * (W - block_size + 1))
+    )
 
-    # seed_drop_rate, the gamma parameter
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (W - block_size + 1) * (H - block_size + 1))
+    # batchwise => one mask for whole batch, quite a bit faster
+    mask_shape = (1 if batchwise else B, C, H, W)
 
-    # Forces the block to be inside the feature map.
-    valid_block = conv2d_kernel_midpoint_mask(
-        shape=(H, W),
-        kernel=(clipped_block_size, clipped_block_size),
-        device=x.device,
+    # A 1 marks the midpoint of a block to drop.
+    block_mask = torch.empty(
+        mask_shape,
         dtype=x.dtype,
-    ).unsqueeze().unsqueeze()
+        device=x.device
+    ).bernoulli_(gamma)
 
-    if batchwise:
-        # one mask for whole batch, quite a bit faster
-        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
-    else:
-        uniform_noise = torch.rand_like(x)
-    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
-    block_mask = -F.max_pool2d(
-        -block_mask,
-        kernel_size=clipped_block_size,  # block_size,
+    if not messy:
+        # Forces the block to be inside the feature map.
+        conv2d_kernel_midpoint_mask(
+            kernel=(clipped_block_size, clipped_block_size),
+            inplace_mask=block_mask,
+        )
+
+    block_mask = F.max_pool2d(
+        block_mask,
+        kernel_size=clipped_block_size,
         stride=1,
         padding=clipped_block_size // 2)
+    # bernoulli_ seeded the *dropped* block midpoints; invert so 1 = keep, 0 = drop.
+    block_mask.neg_().add_(1)
 
+    if inplace:
+        x.mul_(block_mask)
+    else:
+        x = x * block_mask
+
+    # From this point on, we do inplace ops on x.
+
     if with_noise:
-        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
-        if inplace:
-            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
-        else:
-            x = x * block_mask + normal_noise * (1 - block_mask)
+        noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device)
+        # x += (noise * (1 - block_mask))
+        block_mask.neg_().add_(1)
+        noise.mul_(block_mask)
+        x.add_(noise)
+
     else:
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
-        if inplace:
-            x.mul_(block_mask * normalize_scale)
-        else:
-            x = x * block_mask * normalize_scale
+        # x *= (size(block_mask) / sum(block_mask))
+        total = block_mask.to(dtype=torch.float32).sum()
+        normalize_scale = block_mask.numel() / total.add(1e-7).to(x.dtype)
+        x.mul_(normalize_scale)
+
     return x
 
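As a sanity check on the `gamma` expression above (an illustration, not part of
the commit): `gamma` is the per-cell seed rate needed so that the expanded
blocks drop roughly `drop_prob` of the cells.

    # Worked example, assuming a 28x28 feature map:
    H = W = 28; block_size = 7; drop_prob = 0.1; gamma_scale = 1.0
    clipped = min(block_size, H, W)
    gamma = (gamma_scale * drop_prob * H * W) / clipped ** 2 \
        / ((H - block_size + 1) * (W - block_size + 1))
    # gamma ~= 0.0033. Each seed clears a 7x7 block (49 cells), and the formula
    # budgets for the (28 - 6) * (28 - 6) = 484 interior seed positions, so the
    # expected dropped fraction is roughly gamma * 49 * 484 / (28 * 28) = 0.1.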

@@ -144,35 +185,37 @@ def drop_block_fast_2d(
     DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
     block mask at edges.
     """
-    B, C, H, W = x.shape
-    total_size = W * H
-    clipped_block_size = min(block_size, min(W, H))
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (W - block_size + 1) * (H - block_size + 1))
-
-    block_mask = torch.empty_like(x).bernoulli_(gamma)
-    block_mask = F.max_pool2d(
-        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
-
-    if with_noise:
-        normal_noise = torch.empty_like(x).normal_()
-        if inplace:
-            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
-        else:
-            x = x * (1. - block_mask) + normal_noise * block_mask
-    else:
-        block_mask = 1 - block_mask
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
-        if inplace:
-            x.mul_(block_mask * normalize_scale)
-        else:
-            x = x * block_mask * normalize_scale
-    return x
+    return drop_block_2d(
+        x=x,
+        drop_prob=drop_prob,
+        block_size=block_size,
+        gamma_scale=gamma_scale,
+        with_noise=with_noise,
+        inplace=inplace,
+        batchwise=True,
+        messy=True,
+    )
 
 
 class DropBlock2d(nn.Module):
-    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    Args:
+        drop_prob: the probability of dropping any given block.
+        block_size: the size of the dropped blocks; should be odd.
+        gamma_scale: adjustment scale for the drop_prob.
+        with_noise: should normal noise be added to the dropped region?
+        inplace: if the drop should be applied in-place on the input tensor.
+        batchwise: should the entire batch use the same drop mask?
+        messy: allow partial blocks at the edges; faster.
     """
+    drop_prob: float
+    block_size: int
+    gamma_scale: float
+    with_noise: bool
+    inplace: bool
+    batchwise: bool
+    messy: bool
 
     def __init__(
             self,
@@ -182,25 +225,30 @@ def __init__(
             with_noise: bool = False,
             inplace: bool = False,
             batchwise: bool = False,
-            fast: bool = True):
+            messy: bool = True,
+    ):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
         self.block_size = block_size
         self.with_noise = with_noise
         self.inplace = inplace
         self.batchwise = batchwise
-        self.fast = fast  # FIXME finish comparisons of fast vs not
+        self.messy = messy
 
     def forward(self, x):
         if not self.training or not self.drop_prob:
             return x
-        if self.fast:
-            return drop_block_fast_2d(
-                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
-        else:
-            return drop_block_2d(
-                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+
+        return drop_block_2d(
+            x=x,
+            drop_prob=self.drop_prob,
+            block_size=self.block_size,
+            gamma_scale=self.gamma_scale,
+            with_noise=self.with_noise,
+            inplace=self.inplace,
+            batchwise=self.batchwise,
+            messy=self.messy)
 
 
 def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
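
After this change, both the fast and the exact paths go through drop_block_2d.
A minimal usage sketch of the module (illustrative, not part of the commit;
assumes the layer is importable as below):

    import torch
    from timm.layers.drop import DropBlock2d

    # Drop ~10% of activations in 7x7 blocks, keeping blocks inside the map.
    drop = DropBlock2d(drop_prob=0.1, block_size=7, messy=False)
    drop.train()

    x = torch.randn(2, 8, 28, 28)
    y = drop(x)  # dropped blocks zeroed, survivors rescaled by numel/sum(mask)

    drop.eval()
    assert torch.equal(drop(x), x)  # identity when not training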
