Commit 94f81a2

fix tbd usage in dist env (#1202)
1 parent f91995c commit 94f81a2

File tree

1 file changed: 12 additions, 11 deletions

ppsci/solver/solver.py

Lines changed: 12 additions & 11 deletions
@@ -428,9 +428,7 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
                 raise ModuleNotFoundError(
                     "Please install 'visualdl' with `pip install visualdl` first."
                 )
-            with misc.RankZeroOnly(self.rank) as is_master:
-                if is_master:
-                    self.vdl_writer = vdl.LogWriter(osp.join(self.output_dir, "vdl"))
+            self.vdl_writer = vdl.LogWriter(osp.join(self.output_dir, "vdl"))
             logger.info(
                 "VisualDL is enabled for logging, you can view it by "
                 f"running:\nvisualdl --logdir {self.vdl_writer._logdir} --port 8080"
@@ -448,6 +446,7 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
                 raise ModuleNotFoundError(
                     "Please install 'wandb' with `pip install wandb` first."
                 )
+            # FIXME: wandb may hang here in distributed env
             with misc.RankZeroOnly(self.rank) as is_master:
                 if is_master:
                     self.wandb_writer = wandb.init(**self.wandb_config)
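
One possible mitigation for the FIXME above, sketched here rather than taken from the commit, is to give every rank a wandb run but put non-zero ranks into wandb's disabled mode (init_wandb is a hypothetical helper; paddle.distributed.get_rank() and wandb.init(mode=...) are the only real APIs assumed):

import paddle.distributed as dist
import wandb


def init_wandb(**wandb_config):
    # Rank 0 creates a real run; every other rank gets a no-op run,
    # so later wandb calls neither block nor write duplicate logs.
    wandb_config.setdefault("mode", "online" if dist.get_rank() == 0 else "disabled")
    return wandb.init(**wandb_config)
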
@@ -463,11 +462,11 @@ def dist_wrapper(model: nn.Layer) -> paddle.DataParallel:
                 raise ModuleNotFoundError(
                     "Please install 'tensorboardX' with `pip install tensorboardX` first."
                 )
-            with misc.RankZeroOnly(self.rank) as is_master:
-                if is_master:
-                    self.tbd_writer = tensorboardX.SummaryWriter(
-                        osp.join(self.output_dir, "tensorboard")
-                    )
+            # NOTE: To prevent program hangs, initialize the tensorboardX writer across all processes,
+            # but it will only be used on rank 0
+            self.tbd_writer = tensorboardX.SummaryWriter(
+                osp.join(self.output_dir, "tensorboard")
+            )
             logger.message(
                 "TensorboardX is enabled for logging, you can view it by "
                 f"running:\ntensorboard --logdir {self.tbd_writer.logdir}"
@@ -565,9 +564,11 @@ def train(self) -> None:
             start_epoch = self.best_metric["epoch"] + 1
 
         if self.use_tbd and isinstance(self.cfg, DictConfig):
-            self.tbd_writer.add_text(
-                "config", f"<pre>{str(OmegaConf.to_yaml(self.cfg))}</pre>"
-            )
+            with misc.RankZeroOnly(self.rank) as is_master:
+                if is_master:
+                    self.tbd_writer.add_text(
+                        "config", f"<pre>{str(OmegaConf.to_yaml(self.cfg))}</pre>"
+                    )
 
         if self.nvtx_flag:
             core.nvprof_start()
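
The behaviour difference only shows up when the solver runs under a multi-process launch; a typical way to reproduce such a distributed environment is Paddle's launcher, e.g. `python -m paddle.distributed.launch --gpus "0,1" your_script.py` (two GPUs and the script name are assumptions for illustration).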
