|
145 | 145 | help='use NVIDIA amp for mixed precision training') |
146 | 146 | parser.add_argument('--sync-bn', action='store_true', |
147 | 147 | help='enabling apex sync BN.') |
| 148 | +parser.add_argument('--reduce-bn', action='store_true', |
| 149 | + help='average BN running stats across all distributed nodes between train and validation.') |
148 | 150 | parser.add_argument('--no-prefetcher', action='store_true', default=False, |
149 | 151 | help='disable fast prefetcher') |
150 | 152 | parser.add_argument('--output', default='', type=str, metavar='PATH', |
@@ -256,7 +258,7 @@ def main(): |
256 | 258 | if args.local_rank == 0: |
257 | 259 | logging.info('Restoring NVIDIA AMP state from checkpoint') |
258 | 260 | amp.load_state_dict(resume_state['amp']) |
259 | | - resume_state = None # clear it |
| 261 | + del resume_state |
260 | 262 |
|
261 | 263 | model_ema = None |
262 | 264 | if args.model_ema: |
@@ -388,9 +390,17 @@ def main(): |
388 | 390 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, |
389 | 391 | use_amp=use_amp, model_ema=model_ema) |
390 | 392 |
|
| 393 | + if args.distributed and args.reduce_bn: |
| 394 | + if args.local_rank == 0: |
| 395 | + logging.info("Averaging bn running means and vars") |
| 396 | + reduce_bn(model, args.world_size) |
| 397 | + |
391 | 398 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args) |
392 | 399 |
|
393 | 400 | if model_ema is not None and not args.model_ema_force_cpu: |
| 401 | + if args.distributed and args.reduce_bn: |
| 402 | + reduce_bn(model_ema, args.world_size) |
| 403 | + |
394 | 404 | ema_eval_metrics = validate( |
395 | 405 | model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') |
396 | 406 | eval_metrics = ema_eval_metrics |
|
0 commit comments