@@ -56,7 +56,10 @@ def conv2d_kernel_midpoint_mask(
5656
5757 mask = torch .zeros (shape , dtype = dtype , device = device )
5858
59- mask [kh // 2 : h - ((kh - 1 ) // 2 ), kw // 2 : w - ((kw - 1 ) // 2 )] = 1.0
59+ mask [
60+ kh // 2 : h - ((kh - 1 ) // 2 ),
61+ kw // 2 : w - ((kw - 1 ) // 2 ),
62+ ] = 1.0
6063
6164 return mask
6265
@@ -145,9 +148,11 @@ def drop_block_2d(
145148 # batchwise => one mask for whole batch, quite a bit faster
146149 mask_shape = (1 if batchwise else B , C , H , W )
147150
148- selection = torch .empty (mask_shape , dtype = x .dtype , device = x .device ).bernoulli_ (
149- gamma
150- )
151+ selection = torch .empty (
152+ mask_shape ,
153+ dtype = x .dtype ,
154+ device = x .device ,
155+ ).bernoulli_ (gamma )
151156
152157 drop_filter = drop_block_2d_drop_filter_ (
153158 selection = selection ,
@@ -163,7 +168,11 @@ def drop_block_2d(
163168
164169 if with_noise :
165170 # x += (noise * (1 - block_mask))
166- noise = torch .randn (mask_shape , dtype = x .dtype , device = x .device )
171+ noise = torch .randn (
172+ mask_shape ,
173+ dtype = x .dtype ,
174+ device = x .device ,
175+ )
167176
168177 if inplace :
169178 noise .mul_ (drop_filter )
@@ -281,10 +290,12 @@ def drop_path(
281290 if drop_prob == 0.0 or not training :
282291 return x
283292 keep_prob = 1 - drop_prob
284- shape = (x .shape [0 ],) + (1 ,) * (
285- x .ndim - 1
286- ) # work with diff dim tensors, not just 2D ConvNets
293+
294+ # work with diff dim tensors, not just 2D ConvNets
295+ shape = (x .shape [0 ],) + (1 ,) * (x .ndim - 1 )
296+
287297 random_tensor = x .new_empty (shape ).bernoulli_ (keep_prob )
298+
288299 if keep_prob > 0.0 and scale_by_keep :
289300 random_tensor .div_ (keep_prob )
290301 return x * random_tensor
0 commit comments