Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,10 @@ def both(config: Config) -> None:
logger.error("Evaluation failed.")
raise e

# Trick: log one final committed entry from each side so wandb is updated promptly.
ray.get(explorer.log_finalize.remote(step=train_iter_num))
ray.get(trainer.log_finalize.remote(step=train_iter_num))


def activate_data_module(config_path: str):
"""Check whether to activate data module and preprocess datasets."""
Expand Down
4 changes: 4 additions & 0 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,7 @@ def sync_weight(self) -> None:
self._offline_weights_update()
else: # online weights update
self._online_weights_update()

def log_finalize(self, step: int) -> None:
    """Push a committed placeholder record so pending logs reach wandb.

    Args:
        step: Step index; used both as the placeholder value and as the
            step the record is logged under.
    """
    # The payload itself carries no information; the point is commit=True.
    placeholder = {"dummy_log_explorer": step}
    self.monitor.log(placeholder, step=step, commit=True)
4 changes: 4 additions & 0 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def sync_weight(self) -> None:
if self.config.synchronizer.sync_method == "online":
self.engine.sync_weight()

def log_finalize(self, step: int) -> None:
    """Push a committed placeholder record so pending logs reach wandb.

    Args:
        step: Step index; used both as the placeholder value and as the
            step the record is logged under.
    """
    # The payload itself carries no information; the point is commit=True.
    placeholder = {"dummy_log_trainer": step}
    engine_logger = self.engine.logger
    engine_logger.log(placeholder, step=step, commit=True)


class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
Expand Down
4 changes: 2 additions & 2 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def calculate_metrics(
metrics[key] = val
return metrics

def log(self, data: dict, step: int, commit: bool = False) -> None:
    """Log metrics to the backing logger and echo them to the console.

    Args:
        data: Mapping of metric names to values.
        step: Step index the metrics are recorded under.
        commit: Forwarded unchanged to the backing logger's ``log`` call.
            Defaults to ``False`` so existing two-argument callers keep
            working.
    """
    # Exactly one logger call: the pre-change variant without ``commit``
    # must not be kept alongside this one, or every record is logged twice.
    self.logger.log(data, step=step, commit=commit)
    self.console_logger.info(f"Step {step}: {data}")

def __del__(self) -> None:
Expand Down