Skip to content

Commit dc887ad

Browse files
committed
format
1 parent d77d77a commit dc887ad

File tree

1 file changed

+19
-8
lines changed

1 file changed

+19
-8
lines changed

timm/layers/drop.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,10 @@ def conv2d_kernel_midpoint_mask(
5656

5757
mask = torch.zeros(shape, dtype=dtype, device=device)
5858

59-
mask[kh // 2 : h - ((kh - 1) // 2), kw // 2 : w - ((kw - 1) // 2)] = 1.0
59+
mask[
60+
kh // 2 : h - ((kh - 1) // 2),
61+
kw // 2 : w - ((kw - 1) // 2),
62+
] = 1.0
6063

6164
return mask
6265

@@ -145,9 +148,11 @@ def drop_block_2d(
145148
# batchwise => one mask for whole batch, quite a bit faster
146149
mask_shape = (1 if batchwise else B, C, H, W)
147150

148-
selection = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(
149-
gamma
150-
)
151+
selection = torch.empty(
152+
mask_shape,
153+
dtype=x.dtype,
154+
device=x.device,
155+
).bernoulli_(gamma)
151156

152157
drop_filter = drop_block_2d_drop_filter_(
153158
selection=selection,
@@ -163,7 +168,11 @@ def drop_block_2d(
163168

164169
if with_noise:
165170
# x += (noise * (1 - block_mask))
166-
noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device)
171+
noise = torch.randn(
172+
mask_shape,
173+
dtype=x.dtype,
174+
device=x.device,
175+
)
167176

168177
if inplace:
169178
noise.mul_(drop_filter)
@@ -281,10 +290,12 @@ def drop_path(
281290
if drop_prob == 0.0 or not training:
282291
return x
283292
keep_prob = 1 - drop_prob
284-
shape = (x.shape[0],) + (1,) * (
285-
x.ndim - 1
286-
) # work with diff dim tensors, not just 2D ConvNets
293+
294+
# work with diff dim tensors, not just 2D ConvNets
295+
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
296+
287297
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
298+
288299
if keep_prob > 0.0 and scale_by_keep:
289300
random_tensor.div_(keep_prob)
290301
return x * random_tensor

0 commit comments

Comments
 (0)