This repository was archived by the owner on Feb 7, 2025. It is now read-only.

Commit bdca1bd

Add warning when initialising the patchgan discriminator with batchnorm in a distributed environment (#454)
* Add warning
* Update generative/networks/nets/patchgan_discriminator.py
* Adds missing warning

Signed-off-by: Mark Graham <[email protected]>
Co-authored-by: Eric Kerfoot <[email protected]>
1 parent 894f2ec commit bdca1bd

File tree

1 file changed: +7 −0 lines


generative/networks/nets/patchgan_discriminator.py

Lines changed: 7 additions & 0 deletions
@@ -11,6 +11,7 @@
 
 from __future__ import annotations
 
+import warnings
 from collections.abc import Sequence
 
 import torch
@@ -218,6 +219,12 @@ def __init__(
         )
 
         self.apply(self.initialise_weights)
+        if norm.lower() == "batch" and torch.distributed.is_initialized():
+            warnings.warn(
+                "WARNING: Discriminator is using BatchNorm and a distributed training environment has been detected. "
+                "To train with DDP, convert discriminator to SyncBatchNorm using "
+                "torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)."
+            )
 
     def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
         """
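The warning added by this commit points users at `torch.nn.SyncBatchNorm.convert_sync_batchnorm`. A minimal sketch of that conversion, using a hypothetical stand-in discriminator rather than the real PatchDiscriminator, might look like:

```python
import torch.nn as nn

# Hypothetical stand-in for a discriminator that uses BatchNorm;
# any module containing BatchNorm layers is handled the same way.
disc = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3),
    nn.BatchNorm2d(8),
    nn.LeakyReLU(0.2),
)

# Replace every BatchNorm layer with SyncBatchNorm, as the warning
# suggests, before wrapping the model with DistributedDataParallel.
disc = nn.SyncBatchNorm.convert_sync_batchnorm(disc)

print(type(disc[1]).__name__)  # SyncBatchNorm
```

The conversion traverses the module tree recursively, so nested submodules are converted as well; it is safe to call before `torch.distributed` is initialized, since only the layer types are swapped.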

0 commit comments
