2323
2424def conv2d_kernel_midpoint_mask (
2525 * ,
26- shape : Tuple [int , int ],
2726 kernel : Tuple [int , int ],
28- device ,
29- dtype ,
27+ inplace = None ,
28+ shape : Optional [Tuple [int , int ]] = None ,
29+ device = None ,
30+ dtype = None ,
3031):
3132 """Build a mask of kernel midpoints.
3233
@@ -43,30 +44,46 @@ def conv2d_kernel_midpoint_mask(
4344
4445 Args:
4546 kernel: the (kh, kw) shape of the kernel.
47+ inplace: use the provided tensor as the mask; set masked-out values to 0.
4648 shape: the (h, w) shape of the tensor.
4749 device: the target device.
4850 dtype: the target dtype.
4951
5052 Returns:
5153 a (h, w) bool mask tensor.
5254 """
55+ if inplace is None :
56+ assert shape is not None , f"shape is required when inplace is None."
57+ assert dtype is not None , f"dtype is required when inplace is None."
58+ assert device is not None , f"device is required when inplace is None."
59+
60+ mask = torch .ones (shape , dtype = dtype , device = device )
61+ else :
62+ assert shape is None , f"shape and inplace are incompatible"
63+ assert dtype is None , f"dtype and inplace are incompatible"
64+ assert device is None , f"device and inplace are incompatible"
65+
66+ mask = inplace
67+ shape = inplace .shape [- 2 :]
68+ device = inplace .device
69+ dtype = inplace .dtype
70+
5371 h , w = shape
5472 kh , kw = kernel
5573 assert kh <= h and kw <= w , f"{ kernel = } ! <= { shape = } "
5674
57- mask = torch .zeros (shape , dtype = dtype , device = device )
58-
59- mask [
60- kh // 2 : h - ((kh - 1 ) // 2 ),
61- kw // 2 : w - ((kw - 1 ) // 2 ),
62- ] = 1.0
75+ mask [..., 0 : kh // 2 , :] = 0
76+ mask [..., :, 0 : kw // 2 ] = 0
77+ mask [..., h - ((kh - 1 ) // 2 ) :, :] = 0
78+ mask [..., :, w - ((kw - 1 ) // 2 ) :] = 0
6379
6480 return mask
6581
6682
6783def drop_block_2d_drop_filter_ (
6884 * ,
6985 selection ,
86+ inplace : bool = False ,
7087 kernel : Tuple [int , int ],
7188 partial_edge_blocks : bool ,
7289):
@@ -78,19 +95,26 @@ def drop_block_2d_drop_filter_(
7895 selection: 4D (B, C, H, W) input selection noise;
7996 `1.0` at the midpoints of selected blocks to drop,
8097 `0.0` everywhere else. Expected to be gamma noise.
98+ inplace: permit in-place updates to `selection`.
8199 kernel: the shape of the 2d kernel.
82100 partial_edge_blocks: permit partial blocks at the edges, faster.
83101
84102 Returns:
85103 A drop filter, `1.0` at points to drop, `0.0` at points to keep.
86104 """
87105 if not partial_edge_blocks :
88- selection = selection * conv2d_kernel_midpoint_mask (
89- shape = selection .shape [- 2 :],
90- kernel = kernel ,
91- dtype = selection .dtype ,
92- device = selection .device ,
93- )
106+ if inplace :
107+ selection = conv2d_kernel_midpoint_mask (
108+ kernel = kernel ,
109+ inplace = selection ,
110+ )
111+ else :
112+ selection = selection * conv2d_kernel_midpoint_mask (
113+ shape = selection .shape [- 2 :],
114+ kernel = kernel ,
115+ dtype = selection .dtype ,
116+ device = selection .device ,
117+ )
94118
95119 kh , kw = kernel
96120
@@ -136,62 +160,55 @@ def drop_block_2d(
136160 B , C , H , W = x .shape
137161
138162 # TODO: This behaves oddly when clipped_block_size < block_size.
139- kh = kw = block_size
140-
141- kernel = [min (kh , H ), min (kw , W )]
163+ # We could expose non-square blocks above this layer.
164+ kernel = [min (block_size , H ), min (block_size , W )]
142165 kh , kw = kernel
143166
167+ # batchwise => one mask for whole batch, quite a bit faster
168+ noise_shape = (1 if batchwise else B , C , H , W )
169+
144170 gamma = (
145171 float (gamma_scale * drop_prob * H * W )
146172 / float (kh * kw )
147173 / float ((H - kh + 1 ) * (W - kw + 1 ))
148174 )
149175
150- # batchwise => one mask for whole batch, quite a bit faster
151- mask_shape = (1 if batchwise else B , C , H , W )
152-
153- selection = torch .empty (
154- mask_shape ,
155- dtype = x .dtype ,
156- device = x .device ,
157- ).bernoulli_ (gamma )
158-
159176 drop_filter = drop_block_2d_drop_filter_ (
160- selection = selection ,
161177 kernel = kernel ,
162178 partial_edge_blocks = partial_edge_blocks ,
179+ inplace = True ,
180+ selection = torch .empty (
181+ noise_shape ,
182+ dtype = x .dtype ,
183+ device = x .device ,
184+ ).bernoulli_ (gamma ),
163185 )
164186 keep_filter = 1.0 - drop_filter
165187
166- if inplace :
167- x .mul_ (keep_filter )
168- else :
169- x = x * keep_filter
170-
171188 if with_noise :
172- # x += (noise * (1 - block_mask))
173- noise = torch .randn (
174- mask_shape ,
175- dtype = x .dtype ,
176- device = x .device ,
177- )
189+ # x += (noise * drop_filter)
190+ drop_noise = torch .randn_like (drop_filter )
191+ drop_noise .mul_ (drop_filter )
178192
179193 if inplace :
180- noise .mul_ (drop_filter )
181- x .add_ (noise )
194+ x .mul_ (keep_filter )
195+ x .add_ (drop_noise )
196+
182197 else :
183- x = x + noise * drop_filter
198+ x = x * keep_filter + drop_noise
184199
185200 else :
186- # x *= (size(block_mask ) / sum(block_mask ))
201+ # x *= (size(keep_filter ) / ( sum(keep_filter) + eps ))
187202 count = keep_filter .numel ()
188203 total = keep_filter .to (dtype = torch .float32 ).sum ()
189- normalize_scale = count / total .add (1e-7 ).to (x .dtype )
204+ keep_scale = count / total .add (1e-7 ).to (x .dtype )
205+
206+ keep_filter .mul_ (keep_scale )
190207
191208 if inplace :
192- x .mul_ (normalize_scale )
209+ x .mul_ (keep_filter )
193210 else :
194- x = x * normalize_scale
211+ x = x * keep_filter
195212
196213 return x
197214
0 commit comments