Skip to content

Commit a079203

Browse files
author
Donglai Wei
committed
fix DDP bug
1 parent 1c687f4 commit a079203

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

connectomics/lightning/lit_trainer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,12 @@ def create_trainer(
6464

6565
# Check if deep supervision is enabled (requires DDP with find_unused_parameters=True)
6666
deep_supervision_enabled = False
67-
if hasattr(cfg, 'model') and hasattr(cfg.model, 'deep_supervision'):
68-
deep_supervision_enabled = cfg.model.deep_supervision
67+
ddp_find_unused_params = False
68+
if hasattr(cfg, 'model'):
69+
if hasattr(cfg.model, 'deep_supervision'):
70+
deep_supervision_enabled = cfg.model.deep_supervision
71+
if hasattr(cfg.model, 'ddp_find_unused_parameters'):
72+
ddp_find_unused_params = cfg.model.ddp_find_unused_parameters
6973

7074
# Configure DDP strategy for multi-GPU training
7175
strategy = 'auto' # Default strategy
@@ -83,8 +87,8 @@ def create_trainer(
8387

8488
if num_gpus > 1:
8589
# Multi-GPU training: use DDP
86-
if deep_supervision_enabled:
87-
# Deep supervision requires find_unused_parameters=True
90+
if deep_supervision_enabled or ddp_find_unused_params:
91+
# Deep supervision or explicit config requires find_unused_parameters=True
8892
# because auxiliary heads at different scales may not all be used
8993
strategy = DDPStrategy(find_unused_parameters=True)
9094
else:

scripts/main.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,12 +1013,15 @@ def create_trainer(
10131013
if system_cfg.num_gpus > 1:
10141014
# Multi-GPU training: configure DDP
10151015
deep_supervision_enabled = getattr(cfg.model, "deep_supervision", False)
1016-
if deep_supervision_enabled:
1017-
# Deep supervision requires find_unused_parameters=True
1016+
ddp_find_unused_params = getattr(cfg.model, "ddp_find_unused_parameters", False)
1017+
1018+
if deep_supervision_enabled or ddp_find_unused_params:
1019+
# Deep supervision or explicit config requires find_unused_parameters=True
10181020
# because auxiliary heads at different scales may not all be used
10191021
from pytorch_lightning.strategies import DDPStrategy
10201022
strategy = DDPStrategy(find_unused_parameters=True)
1021-
print(" Strategy: DDP with find_unused_parameters=True (deep supervision enabled)")
1023+
reason = "deep supervision" if deep_supervision_enabled else "explicit config"
1024+
print(f" Strategy: DDP with find_unused_parameters=True ({reason})")
10221025
else:
10231026
from pytorch_lightning.strategies import DDPStrategy
10241027
strategy = DDPStrategy(find_unused_parameters=False)

0 commit comments

Comments (0)