Skip to content

Commit eef55f0

Browse files
A trick to update wandb after each step (#16)
Co-authored-by: Yuchang Sun <[email protected]>
1 parent 488d3b2 commit eef55f0

File tree

4 files changed

+13
-2
lines changed

4 files changed

+13
-2
lines changed

trinity/cli/launcher.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -103,6 +103,9 @@ def both(config: Config) -> None:
103 103
logger.error("Evaluation failed.")
104 104
raise e
105 105

106+
ray.get(explorer.log_finalize.remote(step=explore_iter_num))
107+
ray.get(trainer.log_finalize.remote(step=train_iter_num))
108+
106 109

107 110
def activate_data_module(config_path: str):
108 111
"""Check whether to activate data module and preprocess datasets."""

trinity/explorer/explorer.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -247,3 +247,7 @@ def sync_weight(self) -> None:
247 247
self._offline_weights_update()
248 248
else: # online weights update
249 249
self._online_weights_update()
250+
251+
def log_finalize(self, step: int) -> None:
252+
"""Commit the logging results to wandb"""
253+
self.monitor.log({"dummy_log_explorer": step}, step=step, commit=True)

trinity/trainer/trainer.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -105,6 +105,10 @@ def sync_weight(self) -> None:
105 105
if self.config.synchronizer.sync_method == "online":
106 106
self.engine.sync_weight()
107 107

108+
def log_finalize(self, step: int) -> None:
109+
"""Commit the logging results to wandb"""
110+
self.engine.logger.log({"dummy_log_trainer": step}, step=step, commit=True)
111+
108 112

109 113
class TrainEngineWrapper(ABC):
110 114
"""A wrapper class to wrap various training engines."""

trinity/utils/monitor.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -52,9 +52,9 @@ def calculate_metrics(
52 52
metrics[key] = val
53 53
return metrics
54 54

55-
def log(self, data: dict, step: int) -> None:
55+
def log(self, data: dict, step: int, commit: bool = False) -> None:
56 56
"""Log metrics."""
57-
self.logger.log(data, step=step)
57+
self.logger.log(data, step=step, commit=commit)
58 58
self.console_logger.info(f"Step {step}: {data}")
59 59

60 60
def __del__(self) -> None:

0 commit comments

Comments
 (0)