Skip to content

Commit b2fe962

Browse files
committed
inplace speedups
1 parent 8521edc commit b2fe962

File tree

1 file changed

+63
-46
lines changed

1 file changed

+63
-46
lines changed

timm/layers/drop.py

Lines changed: 63 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,11 @@
2323

2424
def conv2d_kernel_midpoint_mask(
2525
*,
26-
shape: Tuple[int, int],
2726
kernel: Tuple[int, int],
28-
device,
29-
dtype,
27+
inplace=None,
28+
shape: Optional[Tuple[int, int]] = None,
29+
device=None,
30+
dtype=None,
3031
):
3132
"""Build a mask of kernel midpoints.
3233
@@ -43,30 +44,46 @@ def conv2d_kernel_midpoint_mask(
4344
4445
Args:
4546
kernel: the (kh, kw) shape of the kernel.
47+
inplace: use the provided tensor as the mask; set masked-out values to 0.
4648
shape: the (h, w) shape of the tensor.
4749
device: the target device.
4850
dtype: the target dtype.
4951
5052
Returns:
5153
a (h, w) bool mask tensor.
5254
"""
55+
if inplace is None:
56+
assert shape is not None, f"shape is required when inplace is None."
57+
assert dtype is not None, f"dtype is required when inplace is None."
58+
assert device is not None, f"device is required when inplace is None."
59+
60+
mask = torch.ones(shape, dtype=dtype, device=device)
61+
else:
62+
assert shape is None, f"shape and inplace are incompatible"
63+
assert dtype is None, f"dtype and inplace are incompatible"
64+
assert device is None, f"device and inplace are incompatible"
65+
66+
mask = inplace
67+
shape = inplace.shape[-2:]
68+
device = inplace.device
69+
dtype = inplace.dtype
70+
5371
h, w = shape
5472
kh, kw = kernel
5573
assert kh <= h and kw <= w, f"{kernel=} ! <= {shape=}"
5674

57-
mask = torch.zeros(shape, dtype=dtype, device=device)
58-
59-
mask[
60-
kh // 2 : h - ((kh - 1) // 2),
61-
kw // 2 : w - ((kw - 1) // 2),
62-
] = 1.0
75+
mask[..., 0 : kh // 2, :] = 0
76+
mask[..., :, 0 : kw // 2] = 0
77+
mask[..., h - ((kh - 1) // 2) :, :] = 0
78+
mask[..., :, w - ((kw - 1) // 2) :] = 0
6379

6480
return mask
6581

6682

6783
def drop_block_2d_drop_filter_(
6884
*,
6985
selection,
86+
inplace: bool = False,
7087
kernel: Tuple[int, int],
7188
partial_edge_blocks: bool,
7289
):
@@ -78,19 +95,26 @@ def drop_block_2d_drop_filter_(
7895
selection: 4D (B, C, H, W) input selection noise;
7996
`1.0` at the midpoints of selected blocks to drop,
8097
`0.0` everywhere else. Expected to be gamma noise.
98+
inplace: permit in-place updates to `selection`.
8199
kernel: the shape of the 2d kernel.
82100
partial_edge_blocks: permit partial blocks at the edges, faster.
83101
84102
Returns:
85103
A drop filter, `1.0` at points to drop, `0.0` at points to keep.
86104
"""
87105
if not partial_edge_blocks:
88-
selection = selection * conv2d_kernel_midpoint_mask(
89-
shape=selection.shape[-2:],
90-
kernel=kernel,
91-
dtype=selection.dtype,
92-
device=selection.device,
93-
)
106+
if inplace:
107+
selection = conv2d_kernel_midpoint_mask(
108+
kernel=kernel,
109+
inplace=selection,
110+
)
111+
else:
112+
selection = selection * conv2d_kernel_midpoint_mask(
113+
shape=selection.shape[-2:],
114+
kernel=kernel,
115+
dtype=selection.dtype,
116+
device=selection.device,
117+
)
94118

95119
kh, kw = kernel
96120

@@ -136,62 +160,55 @@ def drop_block_2d(
136160
B, C, H, W = x.shape
137161

138162
# TODO: This behaves oddly when clipped_block_size < block_size.
139-
kh = kw = block_size
140-
141-
kernel = [min(kh, H), min(kw, W)]
163+
# We could expose non-square blocks above this layer.
164+
kernel = [min(block_size, H), min(block_size, W)]
142165
kh, kw = kernel
143166

167+
# batchwise => one mask for whole batch, quite a bit faster
168+
noise_shape = (1 if batchwise else B, C, H, W)
169+
144170
gamma = (
145171
float(gamma_scale * drop_prob * H * W)
146172
/ float(kh * kw)
147173
/ float((H - kh + 1) * (W - kw + 1))
148174
)
149175

150-
# batchwise => one mask for whole batch, quite a bit faster
151-
mask_shape = (1 if batchwise else B, C, H, W)
152-
153-
selection = torch.empty(
154-
mask_shape,
155-
dtype=x.dtype,
156-
device=x.device,
157-
).bernoulli_(gamma)
158-
159176
drop_filter = drop_block_2d_drop_filter_(
160-
selection=selection,
161177
kernel=kernel,
162178
partial_edge_blocks=partial_edge_blocks,
179+
inplace=True,
180+
selection=torch.empty(
181+
noise_shape,
182+
dtype=x.dtype,
183+
device=x.device,
184+
).bernoulli_(gamma),
163185
)
164186
keep_filter = 1.0 - drop_filter
165187

166-
if inplace:
167-
x.mul_(keep_filter)
168-
else:
169-
x = x * keep_filter
170-
171188
if with_noise:
172-
# x += (noise * (1 - block_mask))
173-
noise = torch.randn(
174-
mask_shape,
175-
dtype=x.dtype,
176-
device=x.device,
177-
)
189+
# x += (noise * drop_filter)
190+
drop_noise = torch.randn_like(drop_filter)
191+
drop_noise.mul_(drop_filter)
178192

179193
if inplace:
180-
noise.mul_(drop_filter)
181-
x.add_(noise)
194+
x.mul_(keep_filter)
195+
x.add_(drop_noise)
196+
182197
else:
183-
x = x + noise * drop_filter
198+
x = x * keep_filter + drop_noise
184199

185200
else:
186-
# x *= (size(block_mask) / sum(block_mask))
201+
# x *= (size(keep_filter) / (sum(keep_filter) + eps))
187202
count = keep_filter.numel()
188203
total = keep_filter.to(dtype=torch.float32).sum()
189-
normalize_scale = count / total.add(1e-7).to(x.dtype)
204+
keep_scale = count / total.add(1e-7).to(x.dtype)
205+
206+
keep_filter.mul_(keep_scale)
190207

191208
if inplace:
192-
x.mul_(normalize_scale)
209+
x.mul_(keep_filter)
193210
else:
194-
x = x * normalize_scale
211+
x = x * keep_filter
195212

196213
return x
197214

0 commit comments

Comments
 (0)