Skip to content

Commit 8521edc

Browse files
committed
format
1 parent dc887ad commit 8521edc

File tree

1 file changed

+19
-5
lines changed

1 file changed

+19
-5
lines changed

timm/layers/drop.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,10 @@ def conv2d_kernel_midpoint_mask(
6565

6666

6767
def drop_block_2d_drop_filter_(
68-
*, selection, kernel: Tuple[int, int], partial_edge_blocks: bool
68+
*,
69+
selection,
70+
kernel: Tuple[int, int],
71+
partial_edge_blocks: bool,
6972
):
7073
"""Convert drop block gamma noise to a drop filter.
7174
@@ -81,7 +84,6 @@ def drop_block_2d_drop_filter_(
8184
Returns:
8285
A drop filter, `1.0` at points to drop, `0.0` at points to keep.
8386
"""
84-
8587
if not partial_edge_blocks:
8688
selection = selection * conv2d_kernel_midpoint_mask(
8789
shape=selection.shape[-2:],
@@ -276,7 +278,10 @@ def forward(self, x):
276278

277279

278280
def drop_path(
279-
x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
281+
x,
282+
drop_prob: float = 0.0,
283+
training: bool = False,
284+
scale_by_keep: bool = True,
280285
):
281286
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
282287
@@ -304,13 +309,22 @@ def drop_path(
304309
class DropPath(nn.Module):
305310
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
306311

307-
def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
312+
def __init__(
313+
self,
314+
drop_prob: float = 0.0,
315+
scale_by_keep: bool = True,
316+
):
308317
super(DropPath, self).__init__()
309318
self.drop_prob = drop_prob
310319
self.scale_by_keep = scale_by_keep
311320

312321
def forward(self, x):
313-
return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
322+
return drop_path(
323+
x,
324+
drop_prob=self.drop_prob,
325+
training=self.training,
326+
scale_by_keep=self.scale_by_keep,
327+
)
314328

315329
def extra_repr(self):
316330
return f"drop_prob={round(self.drop_prob,3):0.3f}"

0 commit comments

Comments
 (0)