 
 Hacked together by / Copyright 2020 Ross Wightman
 """
+from typing import Optional, Tuple
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 
 def conv2d_kernel_midpoint_mask(
-    shape: (int, int),
-    kernel: (int, int),
-    device,
-    dtype=torch.bool,
+    kernel: Tuple[int, int],
+    *,
+    inplace_mask=None,
+    shape: Optional[Tuple[int, int]] = None,
+    device=None,
+    dtype=None,
 ):
     """Build a mask of kernel midpoints.
 
@@ -36,28 +39,53 @@ def conv2d_kernel_midpoint_mask(
 
     Requires `kernel <= min(h, w)`.
 
+    When an `inplace_mask` is not provided, a new mask of `1`s is allocated,
+    and the `0` locations are then cleared.
+
+    When an `inplace_mask` is provided, the `0` locations are cleared on that
+    mask and no other changes are made. `shape`, `dtype`, and `device` must
+    match the mask, if they are provided.
+
     Args:
-        shape: the (h, w) shape of the tensor.
         kernel: the (kh, kw) shape of the kernel.
+        inplace_mask: if supplied, updates are applied to this mask, and
+            `shape`, `device`, and `dtype` are only validated against it.
+            Only clears 'false' locations.
+        shape: the (h, w) shape of the tensor.
         device: the target device.
-        check_kernel: when true, assert that the kernel_size is odd.
+        dtype: the target dtype.
 
     Returns:
-        a (h, w) bool mask tensor.
+        the (..., h, w) mask tensor: freshly allocated, or the updated `inplace_mask`.
     """
+    if inplace_mask is not None:
+        mask = inplace_mask
+
+        if shape:
+            assert tuple(shape) == tuple(mask.shape[-2:]), f"{shape=} !~= {mask.shape=}"
+
+        # Only the trailing (h, w) dims matter; the mask may be (B, C, H, W).
+        shape = mask.shape[-2:]
+
+        if device:
+            device = torch.device(device)
+            assert device == mask.device, f"{device=} != {mask.device=}"
+
+        if dtype:
+            assert dtype == mask.dtype, f"{dtype=} != {mask.dtype=}"
+
+    else:
+        assert shape is not None, "`shape` is required when no `inplace_mask` is given"
+        mask = torch.ones(shape, dtype=dtype, device=device)
+
     h, w = shape
     kh, kw = kernel
     assert kh <= h and kw <= w, f"{kernel=} !<= {shape=}"
 
-    mask = torch.zeros((h, w), dtype=dtype, device=device)
+    # Clear to 0 (rather than setting the interior to 1) so the same code can
+    # clear an `inplace_mask`. `...` targets the trailing (h, w) dims.
+    mask[..., :kh // 2, :] = 0
+    mask[..., h - (kh - 1) // 2:, :] = 0
+    mask[..., :, :kw // 2] = 0
+    mask[..., :, w - (kw - 1) // 2:] = 0
 
-    h_start = kh // 2
-    h_end = (kh - 1) // 2
-
-    w_start = kw // 2
-    w_end = (kw - 1) // 2
-
-    mask[h_start:h - h_end, w_start:w - w_end] = 1
     return mask
 
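For intuition, here is a standalone sketch (not part of the patch) of what the midpoint mask computes for a 3x3 kernel over a 5x5 map: only the inner 3x3 cells can host a block midpoint without the block spilling past an edge.

    import torch

    kh = kw = 3
    h = w = 5
    mask = torch.ones((h, w), dtype=torch.bool)
    mask[:kh // 2, :] = 0            # top edge rows
    mask[h - (kh - 1) // 2:, :] = 0  # bottom edge rows
    mask[:, :kw // 2] = 0            # left edge columns
    mask[:, w - (kw - 1) // 2:] = 0  # right edge columns
    assert mask.sum().item() == 9    # the 3x3 interior survives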
@@ -68,7 +96,8 @@ def drop_block_2d(
     gamma_scale: float = 1.0,
     with_noise: bool = False,
     inplace: bool = False,
-    batchwise: bool = False
+    batchwise: bool = False,
+    messy: bool = False,
 ):
     """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
@@ -82,6 +111,7 @@ def drop_block_2d(
         with_noise: should normal noise be added to the dropped region?
         inplace: if the drop should be applied in-place on the input tensor.
         batchwise: should the entire batch use the same drop mask?
+        messy: allow partial blocks at the feature map edges; faster.
 
     Returns:
         If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device.
@@ -90,44 +120,55 @@ def drop_block_2d(
     total_size = W * H
 
     # TODO: This behaves oddly when clipped_block_size < block_size.
-    clipped_block_size = min(block_size, W, H)
+    clipped_block_size = min(block_size, H, W)
+
+    gamma = (
+        float(gamma_scale * drop_prob * total_size)
+        / float(clipped_block_size ** 2)
+        / float((H - block_size + 1) * (W - block_size + 1))
+    )
 
-    # seed_drop_rate, the gamma parameter
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (W - block_size + 1) * (H - block_size + 1))
+    # batchwise => one mask for the whole batch; quite a bit faster.
+    mask_shape = (1 if batchwise else B, C, H, W)
 
-    # Forces the block to be inside the feature map.
-    valid_block = conv2d_kernel_midpoint_mask(
-        shape=(H, W),
-        kernel=(clipped_block_size, clipped_block_size),
-        device=x.device,
+    block_mask = torch.empty(
+        mask_shape,
         dtype=x.dtype,
-    ).unsqueeze().unsqueeze()
+        device=x.device,
+    ).bernoulli_(gamma)
 
-    if batchwise:
-        # one mask for whole batch, quite a bit faster
-        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
-    else:
-        uniform_noise = torch.rand_like(x)
-    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
-    block_mask = -F.max_pool2d(
-        -block_mask,
-        kernel_size=clipped_block_size,  # block_size,
+    if not messy:
+        # Clear drop seeds whose block would extend past the feature map edges.
+        conv2d_kernel_midpoint_mask(
+            kernel=(clipped_block_size, clipped_block_size),
+            inplace_mask=block_mask,
+        )
+
+    block_mask = F.max_pool2d(
+        block_mask,
+        kernel_size=clipped_block_size,
         stride=1,
         padding=clipped_block_size // 2)
 
+    # After pooling, 1 marks dropped cells; invert so 1 marks kept cells.
+    block_mask = 1.0 - block_mask
+
+    if inplace:
+        x.mul_(block_mask)
+    else:
+        x = x * block_mask
+
+    # From this point on, ops on x are in-place; when `inplace` is false,
+    # x is already a fresh copy.
+
     if with_noise:
-        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
-        if inplace:
-            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
-        else:
-            x = x * block_mask + normal_noise * (1 - block_mask)
+        noise = torch.randn(mask_shape, dtype=x.dtype, device=x.device)
+        # x += noise * (1 - block_mask)
+        block_mask.neg_().add_(1)
+        noise.mul_(block_mask)
+        x.add_(noise)
+
     else:
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
-        if inplace:
-            x.mul_(block_mask * normalize_scale)
-        else:
-            x = x * block_mask * normalize_scale
+        # x *= block_mask.numel() / block_mask.sum()
+        total = block_mask.to(dtype=torch.float32).sum()
+        normalize_scale = (block_mask.numel() / total.add(1e-7)).to(x.dtype)
+        x.mul_(normalize_scale)
+
     return x
 
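As a rough sanity check of the seed rate: with drop_prob=0.1, block_size=7, and a 28x28 map, gamma = 0.1 * 784 / 49 / (22 * 22) ≈ 0.0033, so each cell seeds a drop with probability of about 0.33%, and the max-pool then grows each seed into a 7x7 block. A hedged usage sketch (assumes drop_block_2d as refactored above; shapes are illustrative):

    import torch

    x = torch.randn(2, 8, 28, 28)
    y = drop_block_2d(x, drop_prob=0.1, block_size=7, gamma_scale=1.0)
    assert y.shape == x.shape

    # Without noise, kept activations are rescaled by numel(mask) / sum(mask),
    # which preserves the expected activation magnitude.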
@@ -144,35 +185,37 @@ def drop_block_fast_2d(
     DropBlock with an experimental gaussian noise option. Simplified from the
     above, without concern for a valid block mask at the edges.
     """
-    B, C, H, W = x.shape
-    total_size = W * H
-    clipped_block_size = min(block_size, min(W, H))
-    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (W - block_size + 1) * (H - block_size + 1))
-
-    block_mask = torch.empty_like(x).bernoulli_(gamma)
-    block_mask = F.max_pool2d(
-        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
-
-    if with_noise:
-        normal_noise = torch.empty_like(x).normal_()
-        if inplace:
-            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
-        else:
-            x = x * (1. - block_mask) + normal_noise * block_mask
-    else:
-        block_mask = 1 - block_mask
-        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)).to(dtype=x.dtype)
-        if inplace:
-            x.mul_(block_mask * normalize_scale)
-        else:
-            x = x * block_mask * normalize_scale
-    return x
+    return drop_block_2d(
+        x=x,
+        drop_prob=drop_prob,
+        block_size=block_size,
+        gamma_scale=gamma_scale,
+        with_noise=with_noise,
+        inplace=inplace,
+        batchwise=True,
+        messy=True,
+    )
 
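Since the fast variant now just forwards to drop_block_2d with batchwise=True and messy=True, the two entry points should produce identical results under the same RNG state. A quick equivalence sketch (assumes the refactored functions above, including the added return):

    import torch

    torch.manual_seed(0)
    a = drop_block_fast_2d(torch.ones(1, 4, 8, 8), drop_prob=0.2, block_size=3)

    torch.manual_seed(0)
    b = drop_block_2d(torch.ones(1, 4, 8, 8), drop_prob=0.2, block_size=3,
                      batchwise=True, messy=True)

    assert torch.equal(a, b)  # same RNG draws => identical masks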
 class DropBlock2d(nn.Module):
-    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+    """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    Args:
+        drop_prob: the probability of dropping any given block.
+        block_size: the size of the dropped blocks; should be odd.
+        gamma_scale: adjustment scale for the drop_prob.
+        with_noise: should normal noise be added to the dropped region?
+        inplace: if the drop should be applied in-place on the input tensor.
+        batchwise: should the entire batch use the same drop mask?
+        messy: allow partial blocks at the feature map edges; faster.
     """
+    drop_prob: float
+    block_size: int
+    gamma_scale: float
+    with_noise: bool
+    inplace: bool
+    batchwise: bool
+    messy: bool
 
     def __init__(
             self,
@@ -182,25 +225,30 @@ def __init__(
             with_noise: bool = False,
             inplace: bool = False,
             batchwise: bool = False,
-            fast: bool = True):
+            messy: bool = True,
+    ):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
         self.block_size = block_size
         self.with_noise = with_noise
         self.inplace = inplace
         self.batchwise = batchwise
-        self.fast = fast  # FIXME finish comparisons of fast vs not
+        self.messy = messy
 
     def forward(self, x):
         if not self.training or not self.drop_prob:
             return x
-        if self.fast:
-            return drop_block_fast_2d(
-                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace)
-        else:
-            return drop_block_2d(
-                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+
+        return drop_block_2d(
+            x=x,
+            drop_prob=self.drop_prob,
+            block_size=self.block_size,
+            gamma_scale=self.gamma_scale,
+            with_noise=self.with_noise,
+            inplace=self.inplace,
+            batchwise=self.batchwise,
+            messy=self.messy)
 
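A minimal usage sketch of the module wrapper (assumes the class as defined above; DropBlock only fires in training mode):

    import torch

    drop = DropBlock2d(drop_prob=0.1, block_size=7)
    drop.train()
    y = drop(torch.randn(4, 16, 32, 32))  # blocks dropped, survivors rescaled

    drop.eval()
    z = torch.randn(4, 16, 32, 32)
    assert torch.equal(drop(z), z)        # eval mode: identity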
 def drop_path(x, drop_prob: float = 0., training: bool = False, scale_by_keep: bool = True):