
Commit a435ea1

Rename reduce_bn to distribute_bn and add the ability to choose between broadcast and reduce (mean). Add a crop_pct arg to allow selecting the validation crop while training.
1 parent 3bff2b2 commit a435ea1

2 files changed (+19, -11 lines)

timm/utils.py

Lines changed: 8 additions & 3 deletions
@@ -210,12 +210,17 @@ def reduce_tensor(tensor, n):
     return rt


-def reduce_bn(model, world_size):
+def distribute_bn(model, world_size, reduce=False):
     # ensure every node has the same running bn stats
     for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
         if ('running_mean' in bn_name) or ('running_var' in bn_name):
-            torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
-            bn_buf /= float(world_size)
+            if reduce:
+                # average bn stats across whole group
+                torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+                bn_buf /= float(world_size)
+            else:
+                # broadcast bn stats from rank 0 to whole group
+                torch.distributed.broadcast(bn_buf, 0)


 class ModelEma:
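
The new reduce argument selects between two synchronization strategies: with reduce=True every rank ends up with the group mean of each running_mean/running_var buffer, while the default broadcasts rank 0's buffers to all other ranks. A minimal usage sketch at this commit (assumes torch.distributed is already initialized and that model is the same network replicated on every rank):

import torch.distributed as dist
from timm.utils import distribute_bn

world_size = dist.get_world_size()

# default: broadcast rank 0's BN running stats to every rank
distribute_bn(model, world_size)

# reduce=True: average the BN running stats across all ranks instead
distribute_bn(model, world_size, reduce=True)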

train.py

Lines changed: 11 additions & 8 deletions
@@ -55,6 +55,8 @@
                     help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")')
 parser.add_argument('--img-size', type=int, default=None, metavar='N',
                     help='Image patch size (default: None => model default)')
+parser.add_argument('--crop-pct', default=None, type=float,
+                    metavar='N', help='Input image center crop percent (for validation only)')
 parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN',
                     help='Override mean pixel value of dataset')
 parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD',
@@ -121,6 +123,10 @@
                     help='BatchNorm momentum override (if not None)')
 parser.add_argument('--bn-eps', type=float, default=None,
                     help='BatchNorm epsilon override (if not None)')
+parser.add_argument('--sync-bn', action='store_true',
+                    help='Enable NVIDIA Apex or Torch synchronized BatchNorm.')
+parser.add_argument('--dist-bn', type=str, default='',
+                    help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")')
 # Model Exponential Moving Average
 parser.add_argument('--model-ema', action='store_true', default=False,
                     help='Enable tracking moving average of model weights')
@@ -143,10 +149,6 @@
                     help='save images of input bathes every log interval for debugging')
 parser.add_argument('--amp', action='store_true', default=False,
                     help='use NVIDIA amp for mixed precision training')
-parser.add_argument('--sync-bn', action='store_true',
-                    help='enabling apex sync BN.')
-parser.add_argument('--reduce-bn', action='store_true',
-                    help='average BN running stats across all distributed nodes between train and validation.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -349,6 +351,7 @@ def main():
         std=data_config['std'],
         num_workers=args.workers,
         distributed=args.distributed,
+        crop_pct=data_config['crop_pct'],
     )

     if args.mixup > 0.:
@@ -390,16 +393,16 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)

-            if args.distributed and args.reduce_bn:
+            if args.distributed and args.dist_bn and args.dist_bn in ('broadcast', 'reduce'):
                 if args.local_rank == 0:
-                    logging.info("Averaging bn running means and vars")
-                reduce_bn(model, args.world_size)
+                    logging.info("Distributing BatchNorm running means and vars")
+                distribute_bn(model, args.world_size, args.dist_bn == 'reduce')

             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)

             if model_ema is not None and not args.model_ema_force_cpu:
                 if args.distributed and args.reduce_bn:
-                    reduce_bn(model_ema, args.world_size)
+                    distribute_bn(model_ema, args.world_size)

                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
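
A usage sketch of the new train.py flags; the dataset path, model name, and process count below are placeholders, not part of this commit:

python -m torch.distributed.launch --nproc_per_node=2 train.py /path/to/imagenet \
    --model resnet50 --sync-bn --dist-bn reduce --crop-pct 0.875

Here --dist-bn reduce averages the BatchNorm running stats across ranks after each training epoch (pass broadcast to copy rank 0's stats instead), and --crop-pct overrides the model's default center-crop fraction for the validation loader.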
