Skip to content

Commit 8959abd

Browse files
Change to google (#154)
* update docs * fix docs * Update optimizer_config.schema.json * remove * Update optimizer_config.schema.json --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b0e1e0c commit 8959abd

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

78 files changed

+2964
-1858
lines changed

autointent/_callbacks/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111

1212

1313
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
14-
"""
15-
Get the list of callbacks.
14+
"""Get the list of callbacks.
15+
16+
Args:
17+
reporters: List of reporters to use.
1618
17-
:param reporters: List of reporters to use.
18-
:return: Callback handler.
19+
Returns:
20+
CallbackHandler: Callback handler.
1921
"""
2022
if not reporters:
2123
return CallbackHandler()

autointent/_callbacks/base.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,37 @@ def __init__(self) -> None:
1717

1818
@abstractmethod
1919
def start_run(self, run_name: str, dirpath: Path) -> None:
20-
"""
21-
Start a new run.
20+
"""Start a new run.
2221
23-
:param run_name: Name of the run.
24-
:param dirpath: Path to the directory where the logs will be saved.
22+
Args:
23+
run_name: Name of the run.
24+
dirpath: Path to the directory where the logs will be saved.
2525
"""
2626

2727
@abstractmethod
2828
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
29-
"""
30-
Start a new module.
29+
"""Start a new module.
3130
32-
:param module_name: Name of the module.
33-
:param num: Number of the module.
34-
:param module_kwargs: Module parameters.
31+
Args:
32+
module_name: Name of the module.
33+
num: Number of the module.
34+
module_kwargs: Module parameters.
3535
"""
3636

3737
@abstractmethod
3838
def log_value(self, **kwargs: dict[str, Any]) -> None:
39-
"""
40-
Log data.
39+
"""Log data.
4140
42-
:param kwargs: Data to log.
41+
Args:
42+
kwargs: Data to log.
4343
"""
4444

4545
@abstractmethod
4646
def log_metrics(self, metrics: dict[str, Any]) -> None:
47-
"""
48-
Log metrics during training.
47+
"""Log metrics during training.
4948
50-
:param metrics: Metrics to log.
49+
Args:
50+
metrics: Metrics to log.
5151
"""
5252

5353
@abstractmethod
@@ -60,8 +60,8 @@ def end_run(self) -> None:
6060

6161
@abstractmethod
6262
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
63-
"""
64-
Log final metrics.
63+
"""Log final metrics.
6564
66-
:param metrics: Final metrics.
65+
Args:
66+
metrics: Final metrics.
6767
"""

autointent/_callbacks/callback_handler.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,49 @@ class CallbackHandler(OptimizerCallback):
1010
callbacks: list[OptimizerCallback]
1111

1212
def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
13-
"""Initialize the callback handler."""
13+
"""Initialize the callback handler.
14+
15+
Args:
16+
callbacks: List of callback classes.
17+
"""
1418
if not callbacks:
1519
self.callbacks = []
1620
return
1721

1822
self.callbacks = [cb() for cb in callbacks]
1923

2024
def start_run(self, run_name: str, dirpath: Path) -> None:
21-
"""
22-
Start a new run.
25+
"""Start a new run.
2326
24-
:param run_name: Name of the run.
25-
:param dirpath: Path to the directory where the logs will be saved.
27+
Args:
28+
run_name: Name of the run.
29+
dirpath: Path to the directory where the logs will be saved.
2630
"""
2731
self.call_events("start_run", run_name=run_name, dirpath=dirpath)
2832

2933
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
30-
"""
31-
Start a new module.
34+
"""Start a new module.
3235
33-
:param module_name: Name of the module.
34-
:param num: Number of the module.
35-
:param module_kwargs: Module parameters.
36+
Args:
37+
module_name: Name of the module.
38+
num: Number of the module.
39+
module_kwargs: Module parameters.
3640
"""
3741
self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)
3842

3943
def log_value(self, **kwargs: dict[str, Any]) -> None:
40-
"""
41-
Log data.
44+
"""Log data.
4245
43-
:param kwargs: Data to log.
46+
Args:
47+
kwargs: Data to log.
4448
"""
4549
self.call_events("log_value", **kwargs)
4650

4751
def log_metrics(self, metrics: dict[str, Any]) -> None:
48-
"""
49-
Log metrics during training.
52+
"""Log metrics during training.
5053
51-
:param metrics: Metrics to log.
54+
Args:
55+
metrics: Metrics to log.
5256
"""
5357
self.call_events("log_metrics", metrics=metrics)
5458

@@ -61,13 +65,19 @@ def end_run(self) -> None:
6165
self.call_events("end_run")
6266

6367
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
64-
"""
65-
Log final metrics.
68+
"""Log final metrics.
6669
67-
:param metrics: Final metrics.
70+
Args:
71+
metrics: Final metrics.
6872
"""
6973
self.call_events("log_final_metrics", metrics=metrics)
7074

7175
def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
76+
"""Call events for all callbacks.
77+
78+
Args:
79+
event: Event name.
80+
kwargs: Event parameters.
81+
"""
7282
for callback in self.callbacks:
7383
getattr(callback, event)(**kwargs)

autointent/_callbacks/tensorboard.py

Lines changed: 33 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55

66

77
class TensorBoardCallback(OptimizerCallback):
8-
"""
9-
TensorBoard callback.
10-
11-
This callback logs the optimization process to TensorBoard.
12-
"""
8+
"""TensorBoard callback for logging the optimization process."""
139

1410
name = "tensorboard"
1511

1612
def __init__(self) -> None:
17-
"""Initialize the callback."""
13+
"""Initializes the TensorBoard callback.
14+
15+
Attempts to import `torch.utils.tensorboard` first. If unavailable, tries to import `tensorboardX`.
16+
Raises an ImportError if neither are installed.
17+
"""
1818
try:
1919
from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined]
2020

@@ -32,22 +32,22 @@ def __init__(self) -> None:
3232
raise ImportError(msg) from None
3333

3434
def start_run(self, run_name: str, dirpath: Path) -> None:
35-
"""
36-
Start a new run.
35+
"""Starts a new run and sets the directory for storing logs.
3736
38-
:param run_name: Name of the run.
39-
:param dirpath: Path to the directory where the logs will be saved.
37+
Args:
38+
run_name: Name of the run.
39+
dirpath: Path to the directory where logs will be saved.
4040
"""
4141
self.run_name = run_name
4242
self.dirpath = dirpath
4343

4444
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
45-
"""
46-
Start a new module.
45+
"""Starts a new module and initializes a TensorBoard writer for it.
4746
48-
:param module_name: Name of the module.
49-
:param num: Number of the module.
50-
:param module_kwargs: Module parameters.
47+
Args:
48+
module_name: Name of the module.
49+
num: Identifier number of the module.
50+
module_kwargs: Dictionary containing module parameters.
5151
"""
5252
module_run_name = f"{self.run_name}_{module_name}_{num}"
5353
log_dir = Path(self.dirpath) / module_run_name
@@ -58,10 +58,10 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
5858
self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]
5959

6060
def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
61-
"""
62-
Log data.
61+
"""Logs scalar or text values.
6362
64-
:param kwargs: Data to log.
63+
Args:
64+
**kwargs: Key-value pairs of data to log. Scalars will be logged as numerical values, others as text.
6565
"""
6666
for key, value in kwargs.items():
6767
if isinstance(value, int | float):
@@ -70,10 +70,10 @@ def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
7070
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
7171

7272
def log_metrics(self, metrics: dict[str, Any]) -> None:
73-
"""
74-
Log metrics during training.
73+
"""Logs training metrics.
7574
76-
:param metrics: Metrics to log.
75+
Args:
76+
metrics: Dictionary of metrics to log.
7777
"""
7878
for key, value in metrics.items():
7979
if isinstance(value, int | float):
@@ -82,10 +82,13 @@ def log_metrics(self, metrics: dict[str, Any]) -> None:
8282
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
8383

8484
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
85-
"""
86-
Log final metrics.
85+
"""Logs final metrics at the end of training.
86+
87+
Args:
88+
metrics: Dictionary of final metrics.
8789
88-
:param metrics: Final metrics.
90+
Raises:
91+
RuntimeError: If `start_run` has not been called before logging final metrics.
8992
"""
9093
if self.module_writer is None:
9194
msg = "start_run must be called before log_final_metrics."
@@ -101,7 +104,11 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
101104
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
102105

103106
def end_module(self) -> None:
104-
"""End a module."""
107+
"""Ends the current module and closes the TensorBoard writer.
108+
109+
Raises:
110+
RuntimeError: If `start_run` has not been called before ending the module.
111+
"""
105112
if self.module_writer is None:
106113
msg = "start_run must be called before end_module."
107114
raise RuntimeError(msg)
@@ -110,4 +117,4 @@ def end_module(self) -> None:
110117
self.module_writer.close() # type: ignore[no-untyped-call]
111118

112119
def end_run(self) -> None:
113-
pass
120+
"""Ends the current run. This method is currently a placeholder."""

0 commit comments

Comments
 (0)