|
55 | 55 | help='Type of global pool, "avg", "max", "avgmax", "avgmaxc" (default: "avg")') |
56 | 56 | parser.add_argument('--img-size', type=int, default=None, metavar='N', |
57 | 57 | help='Image patch size (default: None => model default)') |
| 58 | +parser.add_argument('--crop-pct', default=None, type=float, |
| 59 | + metavar='N', help='Input image center crop percent (for validation only)') |
58 | 60 | parser.add_argument('--mean', type=float, nargs='+', default=None, metavar='MEAN', |
59 | 61 | help='Override mean pixel value of dataset') |
60 | 62 | parser.add_argument('--std', type=float, nargs='+', default=None, metavar='STD', |
|
121 | 123 | help='BatchNorm momentum override (if not None)') |
122 | 124 | parser.add_argument('--bn-eps', type=float, default=None, |
123 | 125 | help='BatchNorm epsilon override (if not None)') |
| 126 | +parser.add_argument('--sync-bn', action='store_true', |
| 127 | + help='Enable NVIDIA Apex or Torch synchronized BatchNorm.') |
| 128 | +parser.add_argument('--dist-bn', type=str, default='', |
| 129 | + help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")') |
124 | 130 | # Model Exponential Moving Average |
125 | 131 | parser.add_argument('--model-ema', action='store_true', default=False, |
126 | 132 | help='Enable tracking moving average of model weights') |
|
143 | 149 | help='save images of input bathes every log interval for debugging') |
144 | 150 | parser.add_argument('--amp', action='store_true', default=False, |
145 | 151 | help='use NVIDIA amp for mixed precision training') |
146 | | -parser.add_argument('--sync-bn', action='store_true', |
147 | | - help='enabling apex sync BN.') |
148 | | -parser.add_argument('--reduce-bn', action='store_true', |
149 | | - help='average BN running stats across all distributed nodes between train and validation.') |
150 | 152 | parser.add_argument('--no-prefetcher', action='store_true', default=False, |
151 | 153 | help='disable fast prefetcher') |
152 | 154 | parser.add_argument('--output', default='', type=str, metavar='PATH', |
@@ -349,6 +351,7 @@ def main(): |
349 | 351 | std=data_config['std'], |
350 | 352 | num_workers=args.workers, |
351 | 353 | distributed=args.distributed, |
| 354 | + crop_pct=data_config['crop_pct'], |
352 | 355 | ) |
353 | 356 |
|
354 | 357 | if args.mixup > 0.: |
@@ -390,16 +393,16 @@ def main(): |
390 | 393 | lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir, |
391 | 394 | use_amp=use_amp, model_ema=model_ema) |
392 | 395 |
|
393 | | - if args.distributed and args.reduce_bn: |
| 396 | +        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
394 | 397 | if args.local_rank == 0: |
395 | | - logging.info("Averaging bn running means and vars") |
396 | | - reduce_bn(model, args.world_size) |
| 398 | + logging.info("Distributing BatchNorm running means and vars") |
| 399 | + distribute_bn(model, args.world_size, args.dist_bn == 'reduce') |
397 | 400 |
|
398 | 401 | eval_metrics = validate(model, loader_eval, validate_loss_fn, args) |
399 | 402 |
|
400 | 403 | if model_ema is not None and not args.model_ema_force_cpu: |
401 | | -        if args.distributed and args.reduce_bn:
| 404 | +        if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
402 | | -            reduce_bn(model_ema, args.world_size)
| 405 | +            distribute_bn(model_ema, args.world_size, args.dist_bn == 'reduce')
403 | 406 |
|
404 | 407 | ema_eval_metrics = validate( |
405 | 408 | model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)') |
|
0 commit comments