Commit 2e72653

Add optional SyncBatchNorm support to ImageNet example
1 parent acc295d


imagenet/main.py

Lines changed: 12 additions & 1 deletion
@@ -79,6 +79,12 @@
                          'fastest way to use PyTorch for either single node or '
                          'multi node data parallel training')
 parser.add_argument('--dummy', action='store_true', help="use fake data to benchmark")
+parser.add_argument(
+    '--sync-bn',
+    action='store_true',
+    help='Convert BatchNorm layers to SyncBatchNorm for multi-GPU training'
+)
+
 
 best_acc1 = 0
 
@@ -160,10 +166,15 @@ def main_worker(gpu, ngpus_per_node, args):
     if args.pretrained:
         print("=> using pre-trained model '{}'".format(args.arch))
         model = models.__dict__[args.arch](pretrained=True)
+
     else:
         print("=> creating model '{}'".format(args.arch))
         model = models.__dict__[args.arch]()
-
+
+    # Convert BN → SyncBatchNorm if requested AND distributed training is enabled
+    if args.distributed and args.sync_bn:
+        print("=> Converting BatchNorm layers to SyncBatchNorm")
+        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
     if not use_accel:
         print('using CPU, this will be slow')
     elif args.distributed:
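
For reference, nn.SyncBatchNorm.convert_sync_batchnorm is a stock PyTorch classmethod that walks a module tree and replaces every BatchNorm*d layer with SyncBatchNorm, which averages batch statistics across all distributed processes instead of computing them per GPU. A minimal standalone sketch of what the new flag triggers (the toy Sequential model below is illustrative, not part of the example):

import torch.nn as nn

# Toy model with an ordinary BatchNorm layer (per-process statistics).
model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

# Swap every BatchNorm*d module for a SyncBatchNorm module. The conversion
# itself needs no process group; *running* the converted model does require
# torch.distributed to be initialized, which the args.distributed guard in
# the diff above ensures.
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
print(model)  # BatchNorm2d now appears as SyncBatchNorm

With this change, synchronized BN in the ImageNet example is opt-in from the command line, e.g. (reusing the distributed flags the example already defines):

python main.py -a resnet50 --dist-url 'tcp://127.0.0.1:23456' --dist-backend nccl --multiprocessing-distributed --world-size 1 --rank 0 --sync-bn [imagenet-folder with train and val folders]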
