Skip to content

Commit f195cc3

Browse files
committed
couple_channels; no_grad
1 parent b2fe962 commit f195cc3

File tree

2 files changed

+54
-36
lines changed

2 files changed

+54
-36
lines changed

timm/layers/drop.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def drop_block_2d(
138138
with_noise: bool = False,
139139
inplace: bool = False,
140140
batchwise: bool = False,
141+
couple_channels: bool = False,
141142
partial_edge_blocks: bool = False,
142143
):
143144
"""DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
@@ -151,8 +152,10 @@ def drop_block_2d(
151152
gamma_scale: adjustment scale for the drop_prob.
152153
with_noise: should normal noise be added to the dropped region?
153154
inplace: if the drop should be applied in-place on the input tensor.
154-
batchwise: should the entire batch use the same drop mask?
155-
partial_edge_blocks: partial-blocks at the edges, faster.
155+
batchwise: when true, the entire batch shares the same drop mask; much faster.
156+
couple_channels: when true, channels share the same drop mask;
157+
much faster, with significant semantic impact.
158+
partial_edge_blocks: partial-blocks at the edges; minor speedup, minor semantic impact.
156159
157160
Returns:
158161
If inplace, the modified `x`; otherwise, the dropped copy of `x`, on the same device.
@@ -165,45 +168,47 @@ def drop_block_2d(
165168
kh, kw = kernel
166169

167170
# batchwise => one mask for whole batch, quite a bit faster
168-
noise_shape = (1 if batchwise else B, C, H, W)
171+
noise_shape = (1 if batchwise else B, 1 if couple_channels else C, H, W)
169172

170173
gamma = (
171174
float(gamma_scale * drop_prob * H * W)
172175
/ float(kh * kw)
173176
/ float((H - kh + 1) * (W - kw + 1))
174177
)
175178

176-
drop_filter = drop_block_2d_drop_filter_(
177-
kernel=kernel,
178-
partial_edge_blocks=partial_edge_blocks,
179-
inplace=True,
180-
selection=torch.empty(
181-
noise_shape,
182-
dtype=x.dtype,
183-
device=x.device,
184-
).bernoulli_(gamma),
185-
)
186-
keep_filter = 1.0 - drop_filter
179+
with torch.no_grad():
180+
drop_filter = drop_block_2d_drop_filter_(
181+
kernel=kernel,
182+
partial_edge_blocks=partial_edge_blocks,
183+
inplace=True,
184+
selection=torch.empty(
185+
noise_shape,
186+
dtype=x.dtype,
187+
device=x.device,
188+
).bernoulli_(gamma),
189+
)
190+
keep_filter = 1.0 - drop_filter
187191

188192
if with_noise:
189193
# x += (noise * drop_filter)
190-
drop_noise = torch.randn_like(drop_filter)
191-
drop_noise.mul_(drop_filter)
194+
with torch.no_grad():
195+
drop_noise = torch.randn_like(drop_filter)
196+
drop_noise.mul_(drop_filter)
192197

193198
if inplace:
194199
x.mul_(keep_filter)
195200
x.add_(drop_noise)
196-
197201
else:
198202
x = x * keep_filter + drop_noise
199203

200204
else:
201205
# x *= (size(keep_filter) / (sum(keep_filter) + eps))
202-
count = keep_filter.numel()
203-
total = keep_filter.to(dtype=torch.float32).sum()
204-
keep_scale = count / total.add(1e-7).to(x.dtype)
206+
with torch.no_grad():
207+
count = keep_filter.numel()
208+
total = keep_filter.to(dtype=torch.float32).sum()
209+
keep_scale = count / total.add(1e-7).to(x.dtype)
205210

206-
keep_filter.mul_(keep_scale)
211+
keep_filter.mul_(keep_scale)
207212

208213
if inplace:
209214
x.mul_(keep_filter)
@@ -247,8 +252,10 @@ class DropBlock2d(nn.Module):
247252
gamma_scale: adjustment scale for the drop_prob.
248253
with_noise: should normal noise be added to the dropped region?
249254
inplace: if the drop should be applied in-place on the input tensor.
250-
batchwise: should the entire batch use the same drop mask?
251-
partial_edge_blocks: partial-blocks at the edges, faster.
255+
batchwise: when true, the entire batch shares the same drop mask; much faster.
256+
couple_channels: when true, channels share the same drop mask;
257+
much faster, with significant semantic impact.
258+
partial_edge_blocks: partial-blocks at the edges; minor speedup, minor semantic impact.
252259
"""
253260

254261
drop_prob: float
@@ -257,6 +264,7 @@ class DropBlock2d(nn.Module):
257264
with_noise: bool
258265
inplace: bool
259266
batchwise: bool
267+
couple_channels: bool
260268
partial_edge_blocks: bool
261269

262270
def __init__(
@@ -266,8 +274,9 @@ def __init__(
266274
gamma_scale: float = 1.0,
267275
with_noise: bool = False,
268276
inplace: bool = False,
269-
batchwise: bool = False,
270-
partial_edge_blocks: bool = True,
277+
batchwise: bool = True,
278+
couple_channels: bool = False,
279+
partial_edge_blocks: bool = False,
271280
):
272281
super(DropBlock2d, self).__init__()
273282
self.drop_prob = drop_prob

timm/models/resnet.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -326,9 +326,10 @@ def make_blocks(
326326
down_kernel_size: int = 1,
327327
avg_down: bool = False,
328328
drop_block_rate: float = 0.,
329-
drop_path_rate: float = 0.,
330-
drop_block_batchwise: bool = False,
329+
drop_block_batchwise: bool = True,
330+
drop_block_couple_channels: bool = False,
331331
drop_block_partial_edge_blocks: bool = True,
332+
drop_path_rate: float = 0.,
332333
**kwargs,
333334
) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Any]]]:
334335
"""Create ResNet stages with specified block configurations.
@@ -343,8 +344,10 @@ def make_blocks(
343344
down_kernel_size: Kernel size for downsample layers.
344345
avg_down: Use average pooling for downsample.
345346
drop_block_rate: DropBlock drop rate.
346-
drop_block_batchwise: Batchwise block dropping, faster.
347-
drop_block_partial_edge_blocks: dropping produces partial blocks on the edge, faster.
347+
drop_block_batchwise: Batchwise block dropping, much faster.
348+
drop_block_couple_channels: Couple channel drops.
349+
drop_block_partial_edge_blocks: Permit partial drop blocks on the edge,
350+
slightly faster.
348351
drop_path_rate: Drop path rate for stochastic depth.
349352
**kwargs: Additional arguments passed to block constructors.
350353
@@ -364,6 +367,7 @@ def make_blocks(
364367
drop_blocks(
365368
drop_prob=drop_block_rate,
366369
batchwise=drop_block_batchwise,
370+
couple_channels=drop_block_couple_channels,
367371
partial_edge_blocks=drop_block_partial_edge_blocks,
368372
))):
369373
stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
@@ -465,10 +469,11 @@ def __init__(
465469
norm_layer: LayerType = nn.BatchNorm2d,
466470
aa_layer: Optional[Type[nn.Module]] = None,
467471
drop_rate: float = 0.0,
468-
drop_path_rate: float = 0.,
469472
drop_block_rate: float = 0.,
470473
drop_block_batchwise: bool = True,
474+
drop_block_couple_channels: bool = False,
471475
drop_block_partial_edge_blocks: bool = True,
476+
drop_path_rate: float = 0.,
472477
zero_init_last: bool = True,
473478
block_args: Optional[Dict[str, Any]] = None,
474479
):
@@ -497,10 +502,11 @@ def __init__(
497502
norm_layer (str, nn.Module): normalization layer
498503
aa_layer (nn.Module): anti-aliasing layer
499504
drop_rate (float): Dropout probability before classifier, for training (default 0.)
500-
drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
501505
drop_block_rate (float): Drop block rate (default 0.)
502-
drop_block_batchwise (bool): Sample blocks batchwise, faster.
506+
drop_block_batchwise (bool): Sample blocks batchwise, significantly faster.
507+
drop_block_couple_channels (bool): couple channels when dropping blocks.
503508
drop_block_partial_edge_blocks (bool): Partial block dropping at the edges, faster.
509+
drop_path_rate (float): Stochastic depth drop-path rate (default 0.)
504510
zero_init_last (bool): zero-init the last weight in residual path (usually last BN affine weight)
505511
block_args (dict): Extra kwargs to pass through to block module
506512
"""
@@ -572,6 +578,7 @@ def __init__(
572578
aa_layer=aa_layer,
573579
drop_block_rate=drop_block_rate,
574580
drop_block_batchwise=drop_block_batchwise,
581+
drop_block_couple_channels=drop_block_couple_channels,
575582
drop_block_partial_edge_blocks=drop_block_partial_edge_blocks,
576583
drop_path_rate=drop_path_rate,
577584
**block_args,
@@ -1459,8 +1466,8 @@ def resnet10t(pretrained: bool = False, **kwargs) -> ResNet:
14591466
return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
14601467

14611468
@register_model
1462-
def resnet10t_dropblock_correct(pretrained: bool = False, **kwargs) -> ResNet:
1463-
"""Constructs a ResNet-10-T model with drop_block_rate=0.05, using the most accurate DropBlock2d features.
1469+
def resnet10t_dropblock_slow(pretrained: bool = False, **kwargs) -> ResNet:
1470+
"""Constructs a ResNet-10-T model with drop_block_rate=0.05, using the slowest DropBlock2d features.
14641471
"""
14651472
model_args = dict(
14661473
block=BasicBlock,
@@ -1469,7 +1476,8 @@ def resnet10t_dropblock_correct(pretrained: bool = False, **kwargs) -> ResNet:
14691476
stem_type='deep_tiered',
14701477
avg_down=True,
14711478
drop_block_rate=0.05,
1472-
drop_block_batchwise=True,
1479+
drop_block_batchwise=False,
1480+
drop_block_couple_channels=False,
14731481
drop_block_partial_edge_blocks=True,
14741482
)
14751483
return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))
@@ -1485,7 +1493,8 @@ def resnet10t_dropblock_fast(pretrained: bool = False, **kwargs) -> ResNet:
14851493
stem_type='deep_tiered',
14861494
avg_down=True,
14871495
drop_block_rate=0.05,
1488-
drop_block_batchwise=False,
1496+
drop_block_batchwise=True,
1497+
drop_block_couple_channels=True,
14891498
drop_block_partial_edge_blocks=False,
14901499
)
14911500
return _create_resnet('resnet10t', pretrained, **dict(model_args, **kwargs))

0 commit comments

Comments
 (0)