Commit a63e4a9

Author: Donglai Wei (committed)
Commit message: fix DDP
Parent: a079203

File tree: 1 file changed (+15, -6 lines)

scripts/main.py

Lines changed: 15 additions & 6 deletions

```diff
@@ -468,19 +468,20 @@ def create_datamodule(
     try:
         from pathlib import Path
         import json
+
         json_path = Path(cfg.data.train_json)
         if not json_path.exists():
             train_json_empty = True
         else:
             # Check if JSON file is empty or has no images
-            with open(json_path, 'r') as f:
+            with open(json_path, "r") as f:
                 json_data = json.load(f)
                 image_files = json_data.get(cfg.data.train_image_key, [])
                 if not image_files:
                     train_json_empty = True
     except (FileNotFoundError, json.JSONDecodeError, KeyError):
         train_json_empty = True
-
+
     if train_json_empty:
         # Fallback to volume-based dataset when train_json is empty
         print(f" ⚠️ Train JSON is empty or invalid, falling back to volume-based dataset")
```

```diff
@@ -497,7 +498,7 @@ def create_datamodule(
     # Here we just need placeholder dicts
     train_data_dicts = [{"dataset_type": "filename"}]
     val_data_dicts = None  # Handled by train_val_split in DataModule
-
+
     if dataset_type != "filename":
         # Standard mode: separate train and val files (supports glob patterns)
         if cfg.data.train_image is None:
```

```diff
@@ -1014,16 +1015,18 @@ def create_trainer(
     # Multi-GPU training: configure DDP
     deep_supervision_enabled = getattr(cfg.model, "deep_supervision", False)
     ddp_find_unused_params = getattr(cfg.model, "ddp_find_unused_parameters", False)
-
+
     if deep_supervision_enabled or ddp_find_unused_params:
         # Deep supervision or explicit config requires find_unused_parameters=True
         # because auxiliary heads at different scales may not all be used
         from pytorch_lightning.strategies import DDPStrategy
+
         strategy = DDPStrategy(find_unused_parameters=True)
         reason = "deep supervision" if deep_supervision_enabled else "explicit config"
         print(f" Strategy: DDP with find_unused_parameters=True ({reason})")
     else:
         from pytorch_lightning.strategies import DDPStrategy
+
         strategy = DDPStrategy(find_unused_parameters=False)
         print(" Strategy: DDP (standard)")
 
```
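
The diff shows where the strategy object is built but not how it reaches the trainer. A minimal sketch of the usual wiring, assuming a plain `pytorch_lightning.Trainer`; every argument besides `strategy` is illustrative:

```python
# Hedged sketch: handing the DDPStrategy built above to the Lightning
# Trainer. The accelerator/devices/max_epochs values are illustrative.
import pytorch_lightning as pl
from pytorch_lightning.strategies import DDPStrategy

deep_supervision_enabled = True  # stand-in for getattr(cfg.model, "deep_supervision", False)

strategy = DDPStrategy(find_unused_parameters=deep_supervision_enabled)
trainer = pl.Trainer(
    accelerator="gpu",
    devices=2,          # one DDP subprocess per device
    strategy=strategy,
    max_epochs=10,
)
```

The split matters because `find_unused_parameters=True` makes DDP search the autograd graph after every backward pass for parameters that received no gradient, which adds overhead; the commit therefore keeps it `False` unless deep supervision (whose auxiliary heads may be inactive on a given step) or explicit config demands it.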

```diff
@@ -1099,6 +1102,7 @@ def main():
 
     # Check if this is a DDP re-launch (LOCAL_RANK is set by PyTorch Lightning)
     import os
+
     is_ddp_subprocess = "LOCAL_RANK" in os.environ
     local_rank = int(os.environ.get("LOCAL_RANK", 0))
 
```

```diff
@@ -1107,6 +1111,7 @@ def main():
     if not is_ddp_subprocess:
         # First invocation (main process) - create new timestamp
         from datetime import datetime
+
         timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
         run_dir = output_base / timestamp
 
```

```diff
@@ -1127,6 +1132,7 @@ def main():
     else:
         # DDP subprocess - read existing timestamp
         import time
+
         max_wait = 30  # Maximum 30 seconds
         waited = 0
         while not timestamp_file.exists() and waited < max_wait:
```

```diff
@@ -1139,7 +1145,9 @@ def main():
             cfg.monitor.checkpoint.dirpath = str(run_dir / "checkpoints")
             print(f"📁 [DDP Rank {local_rank}] Using run directory: {run_dir}")
         else:
-            raise RuntimeError(f"DDP subprocess (LOCAL_RANK={local_rank}) timed out waiting for timestamp file")
+            raise RuntimeError(
+                f"DDP subprocess (LOCAL_RANK={local_rank}) timed out waiting for timestamp file"
+            )
     else:
         # For test/predict mode, use a dummy run_dir (won't be created)
         output_base = Path(cfg.monitor.checkpoint.dirpath).parent
```
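
Taken together, the timestamp hunks implement a small rendezvous between the main process and its DDP subprocesses: the first invocation writes a `.latest_timestamp` file, and each re-launched rank polls for it until all ranks agree on the same run directory. A self-contained sketch of the pattern; the file name comes from the commit, while the paths and polling interval are assumptions:

```python
# Hedged sketch of the main-process / DDP-subprocess rendezvous.
# ".latest_timestamp" appears in the commit; "outputs" and the 1 s
# polling interval are assumptions.
import os
import time
from datetime import datetime
from pathlib import Path

output_base = Path("outputs")
timestamp_file = output_base / ".latest_timestamp"

if "LOCAL_RANK" not in os.environ:
    # Main process: mint the run timestamp and publish it for subprocesses.
    output_base.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    timestamp_file.write_text(timestamp)
else:
    # DDP subprocess: poll until the main process has written the file.
    max_wait, waited = 30, 0
    while not timestamp_file.exists() and waited < max_wait:
        time.sleep(1)
        waited += 1
    if not timestamp_file.exists():
        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        raise RuntimeError(
            f"DDP subprocess (LOCAL_RANK={local_rank}) timed out waiting for timestamp file"
        )
    timestamp = timestamp_file.read_text().strip()

run_dir = output_base / timestamp  # every rank resolves the same directory
```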

```diff
@@ -1325,8 +1333,9 @@ def main():
     # Cleanup: Remove timestamp file (only in main process, not DDP subprocesses)
     if args.mode == "train":
         import os
+
         is_ddp_subprocess = "LOCAL_RANK" in os.environ
-        if not is_ddp_subprocess and 'output_base' in locals():
+        if not is_ddp_subprocess and "output_base" in locals():
             timestamp_file = output_base / ".latest_timestamp"
             if timestamp_file.exists():
                 try:
```
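
The hunk is cut off at the `try:` by the diff context window, so the body of the cleanup is not shown. A hedged guess at how such a best-effort removal typically completes; the `unlink` call and the swallowed `OSError` are assumptions, not part of this commit:

```python
# Hedged completion of the cleanup step: only the main process removes the
# rendezvous file, and failure to remove it is treated as non-fatal.
import os
from pathlib import Path

if "LOCAL_RANK" not in os.environ:  # skip in DDP subprocesses
    timestamp_file = Path("outputs") / ".latest_timestamp"
    if timestamp_file.exists():
        try:
            timestamp_file.unlink()
        except OSError:
            pass  # best-effort cleanup; a stale file is harmless
```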
