1- """ DropBlock, DropPath
1+ """DropBlock, DropPath
22
33PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
44
1414
1515Hacked together by / Copyright 2020 Ross Wightman
1616"""
17+
1718from typing import Optional , Tuple
1819import torch
1920import torch .nn as nn
2021import torch .nn .functional as F
2122
2223
2324def conv2d_kernel_midpoint_mask (
24- * ,
25- shape : Tuple [int , int ],
26- kernel : Tuple [int , int ],
27- device ,
28- dtype ,
25+ * ,
26+ shape : Tuple [int , int ],
27+ kernel : Tuple [int , int ],
28+ device ,
29+ dtype ,
2930):
3031 """Build a mask of kernel midpoints.
3132
@@ -55,16 +56,13 @@ def conv2d_kernel_midpoint_mask(
5556
5657 mask = torch .zeros (shape , dtype = dtype , device = device )
5758
58- mask [kh // 2 : h - ((kh - 1 ) // 2 ), kw // 2 : w - ((kw - 1 ) // 2 )] = 1.0
59+ mask [kh // 2 : h - ((kh - 1 ) // 2 ), kw // 2 : w - ((kw - 1 ) // 2 )] = 1.0
5960
6061 return mask
6162
6263
6364def drop_block_2d_drop_filter_ (
64- * ,
65- selection ,
66- kernel : Tuple [int , int ],
67- partial_edge_blocks : bool
65+ * , selection , kernel : Tuple [int , int ], partial_edge_blocks : bool
6866):
6967 """Convert drop block gamma noise to a drop filter.
7068
@@ -98,20 +96,20 @@ def drop_block_2d_drop_filter_(
9896 padding = [kh // 2 , kw // 2 ],
9997 )
10098 if (kh % 2 == 0 ) or (kw % 2 == 0 ):
101- drop_filter = drop_filter [..., (kh % 2 == 0 ) :, (kw % 2 == 0 ) :]
99+ drop_filter = drop_filter [..., (kh % 2 == 0 ) :, (kw % 2 == 0 ) :]
102100
103101 return drop_filter
104102
105103
106104def drop_block_2d (
107- x ,
108- drop_prob : float = 0.1 ,
109- block_size : int = 7 ,
110- gamma_scale : float = 1.0 ,
111- with_noise : bool = False ,
112- inplace : bool = False ,
113- batchwise : bool = False ,
114- partial_edge_blocks : bool = False ,
105+ x ,
106+ drop_prob : float = 0.1 ,
107+ block_size : int = 7 ,
108+ gamma_scale : float = 1.0 ,
109+ with_noise : bool = False ,
110+ inplace : bool = False ,
111+ batchwise : bool = False ,
112+ partial_edge_blocks : bool = False ,
115113):
116114 """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
117115
@@ -147,11 +145,9 @@ def drop_block_2d(
147145 # batchwise => one mask for whole batch, quite a bit faster
148146 mask_shape = (1 if batchwise else B , C , H , W )
149147
150- selection = torch .empty (
151- mask_shape ,
152- dtype = x .dtype ,
153- device = x .device
154- ).bernoulli_ (gamma )
148+ selection = torch .empty (mask_shape , dtype = x .dtype , device = x .device ).bernoulli_ (
149+ gamma
150+ )
155151
156152 drop_filter = drop_block_2d_drop_filter_ (
157153 selection = selection ,
@@ -190,14 +186,14 @@ def drop_block_2d(
190186
191187
192188def drop_block_fast_2d (
193- x : torch .Tensor ,
194- drop_prob : float = 0.1 ,
195- block_size : int = 7 ,
196- gamma_scale : float = 1.0 ,
197- with_noise : bool = False ,
198- inplace : bool = False ,
189+ x : torch .Tensor ,
190+ drop_prob : float = 0.1 ,
191+ block_size : int = 7 ,
192+ gamma_scale : float = 1.0 ,
193+ with_noise : bool = False ,
194+ inplace : bool = False ,
199195):
200- """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
196+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
201197
202198 DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
203199 block mask at edges.
@@ -226,6 +222,7 @@ class DropBlock2d(nn.Module):
226222 batchwise: should the entire batch use the same drop mask?
227223 partial_edge_blocks: partial-blocks at the edges, faster.
228224 """
225+
229226 drop_prob : float
230227 block_size : int
231228 gamma_scale : float
@@ -235,14 +232,14 @@ class DropBlock2d(nn.Module):
235232 partial_edge_blocks : bool
236233
237234 def __init__ (
238- self ,
239- drop_prob : float = 0.1 ,
240- block_size : int = 7 ,
241- gamma_scale : float = 1.0 ,
242- with_noise : bool = False ,
243- inplace : bool = False ,
244- batchwise : bool = False ,
245- partial_edge_blocks : bool = True ,
235+ self ,
236+ drop_prob : float = 0.1 ,
237+ block_size : int = 7 ,
238+ gamma_scale : float = 1.0 ,
239+ with_noise : bool = False ,
240+ inplace : bool = False ,
241+ batchwise : bool = False ,
242+ partial_edge_blocks : bool = True ,
246243 ):
247244 super (DropBlock2d , self ).__init__ ()
248245 self .drop_prob = drop_prob
@@ -265,10 +262,13 @@ def forward(self, x):
265262 with_noise = self .with_noise ,
266263 inplace = self .inplace ,
267264 batchwise = self .batchwise ,
268- partial_edge_blocks = self .partial_edge_blocks )
265+ partial_edge_blocks = self .partial_edge_blocks ,
266+ )
269267
270268
271- def drop_path (x , drop_prob : float = 0. , training : bool = False , scale_by_keep : bool = True ):
269+ def drop_path (
270+ x , drop_prob : float = 0.0 , training : bool = False , scale_by_keep : bool = True
271+ ):
272272 """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
273273
274274 This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
@@ -278,20 +278,22 @@ def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: b
278278 'survival rate' as the argument.
279279
280280 """
281- if drop_prob == 0. or not training :
281+ if drop_prob == 0.0 or not training :
282282 return x
283283 keep_prob = 1 - drop_prob
284- shape = (x .shape [0 ],) + (1 ,) * (x .ndim - 1 ) # work with diff dim tensors, not just 2D ConvNets
284+ shape = (x .shape [0 ],) + (1 ,) * (
285+ x .ndim - 1
286+ ) # work with diff dim tensors, not just 2D ConvNets
285287 random_tensor = x .new_empty (shape ).bernoulli_ (keep_prob )
286288 if keep_prob > 0.0 and scale_by_keep :
287289 random_tensor .div_ (keep_prob )
288290 return x * random_tensor
289291
290292
291293class DropPath (nn .Module ):
292- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
293- """
294- def __init__ (self , drop_prob : float = 0. , scale_by_keep : bool = True ):
294+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
295+
296+ def __init__ (self , drop_prob : float = 0.0 , scale_by_keep : bool = True ):
295297 super (DropPath , self ).__init__ ()
296298 self .drop_prob = drop_prob
297299 self .scale_by_keep = scale_by_keep
@@ -300,4 +302,4 @@ def forward(self, x):
300302 return drop_path (x , self .drop_prob , self .training , self .scale_by_keep )
301303
302304 def extra_repr (self ):
303- return f' drop_prob={ round (self .drop_prob ,3 ):0.3f} '
305+ return f" drop_prob={ round (self .drop_prob ,3 ):0.3f} "
0 commit comments