Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions autointent/_callbacks/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,13 @@


def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
"""
Get the list of callbacks.
"""Get the list of callbacks.

Args:
reporters: List of reporters to use.

:param reporters: List of reporters to use.
:return: Callback handler.
Returns:
CallbackHandler: Callback handler.
"""
if not reporters:
return CallbackHandler()
Expand Down
36 changes: 18 additions & 18 deletions autointent/_callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,37 @@ def __init__(self) -> None:

@abstractmethod
def start_run(self, run_name: str, dirpath: Path) -> None:
"""
Start a new run.
"""Start a new run.

:param run_name: Name of the run.
:param dirpath: Path to the directory where the logs will be saved.
Args:
run_name: Name of the run.
dirpath: Path to the directory where the logs will be saved.
"""

@abstractmethod
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
"""
Start a new module.
"""Start a new module.

:param module_name: Name of the module.
:param num: Number of the module.
:param module_kwargs: Module parameters.
Args:
module_name: Name of the module.
num: Number of the module.
module_kwargs: Module parameters.
"""

@abstractmethod
def log_value(self, **kwargs: dict[str, Any]) -> None:
"""
Log data.
"""Log data.

:param kwargs: Data to log.
Args:
kwargs: Data to log.
"""

@abstractmethod
def log_metrics(self, metrics: dict[str, Any]) -> None:
"""
Log metrics during training.
"""Log metrics during training.

:param metrics: Metrics to log.
Args:
metrics: Metrics to log.
"""

@abstractmethod
Expand All @@ -60,8 +60,8 @@ def end_run(self) -> None:

@abstractmethod
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
"""
Log final metrics.
"""Log final metrics.

:param metrics: Final metrics.
Args:
metrics: Final metrics.
"""
48 changes: 29 additions & 19 deletions autointent/_callbacks/callback_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,45 +10,49 @@ class CallbackHandler(OptimizerCallback):
callbacks: list[OptimizerCallback]

def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
"""Initialize the callback handler."""
"""Initialize the callback handler.

Args:
callbacks: List of callback classes.
"""
if not callbacks:
self.callbacks = []
return

self.callbacks = [cb() for cb in callbacks]

def start_run(self, run_name: str, dirpath: Path) -> None:
"""
Start a new run.
"""Start a new run.

:param run_name: Name of the run.
:param dirpath: Path to the directory where the logs will be saved.
Args:
run_name: Name of the run.
dirpath: Path to the directory where the logs will be saved.
"""
self.call_events("start_run", run_name=run_name, dirpath=dirpath)

def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
"""
Start a new module.
"""Start a new module.

:param module_name: Name of the module.
:param num: Number of the module.
:param module_kwargs: Module parameters.
Args:
module_name: Name of the module.
num: Number of the module.
module_kwargs: Module parameters.
"""
self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)

def log_value(self, **kwargs: dict[str, Any]) -> None:
"""
Log data.
"""Log data.

:param kwargs: Data to log.
Args:
kwargs: Data to log.
"""
self.call_events("log_value", **kwargs)

def log_metrics(self, metrics: dict[str, Any]) -> None:
"""
Log metrics during training.
"""Log metrics during training.

:param metrics: Metrics to log.
Args:
metrics: Metrics to log.
"""
self.call_events("log_metrics", metrics=metrics)

Expand All @@ -61,13 +65,19 @@ def end_run(self) -> None:
self.call_events("end_run")

def log_final_metrics(self, metrics: dict[str, Any]) -> None:
"""
Log final metrics.
"""Log final metrics.

:param metrics: Final metrics.
Args:
metrics: Final metrics.
"""
self.call_events("log_final_metrics", metrics=metrics)

def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
"""Call events for all callbacks.

Args:
event: Event name.
kwargs: Event parameters.
"""
for callback in self.callbacks:
getattr(callback, event)(**kwargs)
59 changes: 33 additions & 26 deletions autointent/_callbacks/tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,16 @@


class TensorBoardCallback(OptimizerCallback):
"""
TensorBoard callback.

This callback logs the optimization process to TensorBoard.
"""
"""TensorBoard callback for logging the optimization process."""

name = "tensorboard"

def __init__(self) -> None:
"""Initialize the callback."""
"""Initializes the TensorBoard callback.

Attempts to import `torch.utils.tensorboard` first. If unavailable, tries to import `tensorboardX`.
Raises an ImportError if neither are installed.
"""
try:
from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined]

Expand All @@ -32,22 +32,22 @@ def __init__(self) -> None:
raise ImportError(msg) from None

def start_run(self, run_name: str, dirpath: Path) -> None:
"""
Start a new run.
"""Starts a new run and sets the directory for storing logs.

:param run_name: Name of the run.
:param dirpath: Path to the directory where the logs will be saved.
Args:
run_name: Name of the run.
dirpath: Path to the directory where logs will be saved.
"""
self.run_name = run_name
self.dirpath = dirpath

def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
"""
Start a new module.
"""Starts a new module and initializes a TensorBoard writer for it.

:param module_name: Name of the module.
:param num: Number of the module.
:param module_kwargs: Module parameters.
Args:
module_name: Name of the module.
num: Identifier number of the module.
module_kwargs: Dictionary containing module parameters.
"""
module_run_name = f"{self.run_name}_{module_name}_{num}"
log_dir = Path(self.dirpath) / module_run_name
Expand All @@ -58,10 +58,10 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]

def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
"""
Log data.
"""Logs scalar or text values.

:param kwargs: Data to log.
Args:
**kwargs: Key-value pairs of data to log. Scalars will be logged as numerical values, others as text.
"""
for key, value in kwargs.items():
if isinstance(value, int | float):
Expand All @@ -70,10 +70,10 @@ def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]

def log_metrics(self, metrics: dict[str, Any]) -> None:
"""
Log metrics during training.
"""Logs training metrics.

:param metrics: Metrics to log.
Args:
metrics: Dictionary of metrics to log.
"""
for key, value in metrics.items():
if isinstance(value, int | float):
Expand All @@ -82,10 +82,13 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]

def log_final_metrics(self, metrics: dict[str, Any]) -> None:
"""
Log final metrics.
"""Logs final metrics at the end of training.

Args:
metrics: Dictionary of final metrics.

:param metrics: Final metrics.
Raises:
RuntimeError: If `start_run` has not been called before logging final metrics.
"""
if self.module_writer is None:
msg = "start_run must be called before log_final_metrics."
Expand All @@ -101,7 +104,11 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]

def end_module(self) -> None:
"""End a module."""
"""Ends the current module and closes the TensorBoard writer.

Raises:
RuntimeError: If `start_run` has not been called before ending the module.
"""
if self.module_writer is None:
msg = "start_run must be called before end_module."
raise RuntimeError(msg)
Expand All @@ -110,4 +117,4 @@ def end_module(self) -> None:
self.module_writer.close() # type: ignore[no-untyped-call]

def end_run(self) -> None:
pass
"""Ends the current run. This method is currently a placeholder."""
Loading