
Commit 3bff2b2

Add support for keeping running bn stats the same across distributed training nodes before eval/save
1 parent 0161de0
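For context (not part of the commit): DistributedDataParallel averages gradients across processes, but BatchNorm's running_mean/running_var buffers are updated locally from each rank's own batches, so they drift apart over an epoch. The following is a minimal, self-contained sketch (illustrative names, CPU-only gloo backend) that reproduces the drift on two processes and applies the same sum-then-divide averaging the commit adds:

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn

def _worker(rank, world_size):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29500'
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    bn = nn.BatchNorm1d(4)
    # Each rank sees different data, so running stats diverge;
    # DDP only synchronizes gradients, never these buffers.
    bn(torch.full((8, 4), float(rank)))

    # Same pattern as reduce_bn in the diff: sum across ranks, then average.
    for name, buf in bn.named_buffers():
        if 'running_mean' in name or 'running_var' in name:
            dist.all_reduce(buf, op=dist.ReduceOp.SUM)
            buf /= float(world_size)

    print(f'rank {rank} running_mean: {bn.running_mean.tolist()}')
    dist.destroy_process_group()

if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)

Without the all_reduce, rank 0 would end the step with a running_mean of 0.0 and rank 1 with 0.1 (momentum 0.1); after averaging, both ranks hold 0.05.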

File tree

2 files changed: 26 additions & 4 deletions


timm/utils.py

Lines changed: 15 additions & 3 deletions
@@ -21,11 +21,15 @@
 from torch import distributed as dist
 
 
-def get_state_dict(model):
+def unwrap_model(model):
     if isinstance(model, ModelEma):
-        return get_state_dict(model.ema)
+        return unwrap_model(model.ema)
     else:
-        return model.module.state_dict() if hasattr(model, 'module') else model.state_dict()
+        return model.module if hasattr(model, 'module') else model
+
+
+def get_state_dict(model):
+    return unwrap_model(model).state_dict()
 
 
 class CheckpointSaver:
@@ -206,6 +210,14 @@ def reduce_tensor(tensor, n):
     return rt
 
 
+def reduce_bn(model, world_size):
+    # ensure every node has the same running bn stats
+    for bn_name, bn_buf in unwrap_model(model).named_buffers(recurse=True):
+        if ('running_mean' in bn_name) or ('running_var' in bn_name):
+            torch.distributed.all_reduce(bn_buf, op=dist.ReduceOp.SUM)
+            bn_buf /= float(world_size)
+
+
 class ModelEma:
     """ Model Exponential Moving Average
     Keep a moving average of everything in the model state_dict (parameters and buffers).

train.py

Lines changed: 11 additions & 1 deletion
@@ -145,6 +145,8 @@
                     help='use NVIDIA amp for mixed precision training')
 parser.add_argument('--sync-bn', action='store_true',
                     help='enabling apex sync BN.')
+parser.add_argument('--reduce-bn', action='store_true',
+                    help='average BN running stats across all distributed nodes between train and validation.')
 parser.add_argument('--no-prefetcher', action='store_true', default=False,
                     help='disable fast prefetcher')
 parser.add_argument('--output', default='', type=str, metavar='PATH',
@@ -256,7 +258,7 @@ def main():
             if args.local_rank == 0:
                 logging.info('Restoring NVIDIA AMP state from checkpoint')
             amp.load_state_dict(resume_state['amp'])
-    resume_state = None  # clear it
+    del resume_state
 
     model_ema = None
     if args.model_ema:
@@ -388,9 +390,17 @@ def main():
                 lr_scheduler=lr_scheduler, saver=saver, output_dir=output_dir,
                 use_amp=use_amp, model_ema=model_ema)
 
+            if args.distributed and args.reduce_bn:
+                if args.local_rank == 0:
+                    logging.info("Averaging bn running means and vars")
+                reduce_bn(model, args.world_size)
+
             eval_metrics = validate(model, loader_eval, validate_loss_fn, args)
 
             if model_ema is not None and not args.model_ema_force_cpu:
+                if args.distributed and args.reduce_bn:
+                    reduce_bn(model_ema, args.world_size)
+
                 ema_eval_metrics = validate(
                     model_ema.ema, loader_eval, validate_loss_fn, args, log_suffix=' (EMA)')
                 eval_metrics = ema_eval_metrics
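As a usage note (an illustrative command, not from the commit; the dataset path, model, and process count are placeholders), the new flag slots into the usual torch.distributed.launch workflow:

python -m torch.distributed.launch --nproc_per_node=2 train.py \
    /data/imagenet --model resnet50 --sync-bn --reduce-bn

Each process then averages its BN buffers right before validation, and again before the EMA copy is evaluated, so the stats that get validated and checkpointed are identical on every rank.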
