Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions trinity/cli/launcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,9 @@ def both(config: Config) -> None:
logger.error("Evaluation failed.")
raise e

ray.get(explorer.log_finalize.remote(step=explore_iter_num))
ray.get(trainer.log_finalize.remote(step=train_iter_num))


def activate_data_module(config_path: str):
"""Check whether to activate data module and preprocess datasets."""
Expand Down
4 changes: 4 additions & 0 deletions trinity/explorer/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,7 @@ def sync_weight(self) -> None:
self._offline_weights_update()
else: # online weights update
self._online_weights_update()

def log_finalize(self, step: int) -> None:
    """Ask the monitor to commit the results logged so far (wandb).

    A placeholder metric keyed ``dummy_log_explorer`` with value ``step`` is
    logged through ``self.monitor`` with ``commit=True``; the placeholder only
    exists so that there is something to log alongside the commit request.

    Args:
        step: Step index under which the placeholder entry is logged.
    """
    placeholder = {"dummy_log_explorer": step}
    self.monitor.log(placeholder, step=step, commit=True)
4 changes: 4 additions & 0 deletions trinity/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,10 @@ def sync_weight(self) -> None:
if self.config.synchronizer.sync_method == "online":
self.engine.sync_weight()

def log_finalize(self, step: int) -> None:
    """Ask the engine's logger to commit the results logged so far (wandb).

    A placeholder metric keyed ``dummy_log_trainer`` with value ``step`` is
    logged through ``self.engine.logger`` with ``commit=True``; the
    placeholder only exists so that there is something to log alongside the
    commit request.

    Args:
        step: Step index under which the placeholder entry is logged.
    """
    placeholder = {"dummy_log_trainer": step}
    engine_logger = self.engine.logger
    engine_logger.log(placeholder, step=step, commit=True)


class TrainEngineWrapper(ABC):
"""A wrapper class to wrap various training engines."""
Expand Down
4 changes: 2 additions & 2 deletions trinity/utils/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def calculate_metrics(
metrics[key] = val
return metrics

def log(self, data: dict, step: int) -> None:
def log(self, data: dict, step: int, commit: bool = False) -> None:
"""Log metrics."""
self.logger.log(data, step=step)
self.logger.log(data, step=step, commit=commit)
self.console_logger.info(f"Step {step}: {data}")

def __del__(self) -> None:
Expand Down