Skip to content

Commit d77d77a

Browse files
committed
format
1 parent a7a3186 commit d77d77a

File tree

1 file changed

+50
-48
lines changed

1 file changed

+50
-48
lines changed

timm/layers/drop.py

Lines changed: 50 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
""" DropBlock, DropPath
1+
"""DropBlock, DropPath
22
33
PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
44
@@ -14,18 +14,19 @@
1414
1515
Hacked together by / Copyright 2020 Ross Wightman
1616
"""
17+
1718
from typing import Optional, Tuple
1819
import torch
1920
import torch.nn as nn
2021
import torch.nn.functional as F
2122

2223

2324
def conv2d_kernel_midpoint_mask(
24-
*,
25-
shape: Tuple[int, int],
26-
kernel: Tuple[int, int],
27-
device,
28-
dtype,
25+
*,
26+
shape: Tuple[int, int],
27+
kernel: Tuple[int, int],
28+
device,
29+
dtype,
2930
):
3031
"""Build a mask of kernel midpoints.
3132
@@ -55,16 +56,13 @@ def conv2d_kernel_midpoint_mask(
5556

5657
mask = torch.zeros(shape, dtype=dtype, device=device)
5758

58-
mask[kh//2: h - ((kh - 1) // 2), kw//2: w - ((kw - 1) // 2)] = 1.0
59+
mask[kh // 2 : h - ((kh - 1) // 2), kw // 2 : w - ((kw - 1) // 2)] = 1.0
5960

6061
return mask
6162

6263

6364
def drop_block_2d_drop_filter_(
64-
*,
65-
selection,
66-
kernel: Tuple[int, int],
67-
partial_edge_blocks: bool
65+
*, selection, kernel: Tuple[int, int], partial_edge_blocks: bool
6866
):
6967
"""Convert drop block gamma noise to a drop filter.
7068
@@ -98,20 +96,20 @@ def drop_block_2d_drop_filter_(
9896
padding=[kh // 2, kw // 2],
9997
)
10098
if (kh % 2 == 0) or (kw % 2 == 0):
101-
drop_filter = drop_filter[..., (kh%2==0):, (kw%2==0):]
99+
drop_filter = drop_filter[..., (kh % 2 == 0) :, (kw % 2 == 0) :]
102100

103101
return drop_filter
104102

105103

106104
def drop_block_2d(
107-
x,
108-
drop_prob: float = 0.1,
109-
block_size: int = 7,
110-
gamma_scale: float = 1.0,
111-
with_noise: bool = False,
112-
inplace: bool = False,
113-
batchwise: bool = False,
114-
partial_edge_blocks: bool = False,
105+
x,
106+
drop_prob: float = 0.1,
107+
block_size: int = 7,
108+
gamma_scale: float = 1.0,
109+
with_noise: bool = False,
110+
inplace: bool = False,
111+
batchwise: bool = False,
112+
partial_edge_blocks: bool = False,
115113
):
116114
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
117115
@@ -147,11 +145,9 @@ def drop_block_2d(
147145
# batchwise => one mask for whole batch, quite a bit faster
148146
mask_shape = (1 if batchwise else B, C, H, W)
149147

150-
selection = torch.empty(
151-
mask_shape,
152-
dtype=x.dtype,
153-
device=x.device
154-
).bernoulli_(gamma)
148+
selection = torch.empty(mask_shape, dtype=x.dtype, device=x.device).bernoulli_(
149+
gamma
150+
)
155151

156152
drop_filter = drop_block_2d_drop_filter_(
157153
selection=selection,
@@ -190,14 +186,14 @@ def drop_block_2d(
190186

191187

192188
def drop_block_fast_2d(
193-
x: torch.Tensor,
194-
drop_prob: float = 0.1,
195-
block_size: int = 7,
196-
gamma_scale: float = 1.0,
197-
with_noise: bool = False,
198-
inplace: bool = False,
189+
x: torch.Tensor,
190+
drop_prob: float = 0.1,
191+
block_size: int = 7,
192+
gamma_scale: float = 1.0,
193+
with_noise: bool = False,
194+
inplace: bool = False,
199195
):
200-
""" DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
196+
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
201197
202198
DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
203199
block mask at edges.
@@ -226,6 +222,7 @@ class DropBlock2d(nn.Module):
226222
batchwise: should the entire batch use the same drop mask?
227223
partial_edge_blocks: partial-blocks at the edges, faster.
228224
"""
225+
229226
drop_prob: float
230227
block_size: int
231228
gamma_scale: float
@@ -235,14 +232,14 @@ class DropBlock2d(nn.Module):
235232
partial_edge_blocks: bool
236233

237234
def __init__(
238-
self,
239-
drop_prob: float = 0.1,
240-
block_size: int = 7,
241-
gamma_scale: float = 1.0,
242-
with_noise: bool = False,
243-
inplace: bool = False,
244-
batchwise: bool = False,
245-
partial_edge_blocks: bool = True,
235+
self,
236+
drop_prob: float = 0.1,
237+
block_size: int = 7,
238+
gamma_scale: float = 1.0,
239+
with_noise: bool = False,
240+
inplace: bool = False,
241+
batchwise: bool = False,
242+
partial_edge_blocks: bool = True,
246243
):
247244
super(DropBlock2d, self).__init__()
248245
self.drop_prob = drop_prob
@@ -265,10 +262,13 @@ def forward(self, x):
265262
with_noise=self.with_noise,
266263
inplace=self.inplace,
267264
batchwise=self.batchwise,
268-
partial_edge_blocks=self.partial_edge_blocks)
265+
partial_edge_blocks=self.partial_edge_blocks,
266+
)
269267

270268

271-
def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):
269+
def drop_path(
270+
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
271+
):
272272
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
273273
274274
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
@@ -278,20 +278,22 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
278278
'survival rate' as the argument.
279279
280280
"""
281-
if drop_prob == 0. or not training:
281+
if drop_prob == 0.0 or not training:
282282
return x
283283
keep_prob = 1 - drop_prob
284-
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
284+
shape = (x.shape[0],) + (1,) * (
285+
x.ndim - 1
286+
) # work with diff dim tensors, not just 2D ConvNets
285287
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
286288
if keep_prob > 0.0 and scale_by_keep:
287289
random_tensor.div_(keep_prob)
288290
return x * random_tensor
289291

290292

291293
class DropPath(nn.Module):
292-
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
293-
"""
294-
def __init__(self, drop_prob: float = 0., scale_by_keep: bool = True):
294+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
295+
296+
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
295297
super(DropPath, self).__init__()
296298
self.drop_prob = drop_prob
297299
self.scale_by_keep = scale_by_keep
@@ -300,4 +302,4 @@ def forward(self, x):
300302
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
301303

302304
def extra_repr(self):
303-
return f'drop_prob={round(self.drop_prob,3):0.3f}'
305+
return f"drop_prob={round(self.drop_prob,3):0.3f}"

0 commit comments

Comments
 (0)