Skip to content

Commit 1fdb218

Browse files
committed
Merge branch 'refs/heads/dev' into add_dspy
2 parents 8b2f849 + de4d37d commit 1fdb218

File tree

131 files changed

+6470
-4239
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

131 files changed

+6470
-4239
lines changed
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: test presets
2+
3+
on:
4+
push:
5+
branches:
6+
- dev
7+
pull_request:
8+
9+
jobs:
10+
test:
11+
runs-on: ${{ matrix.os }}
12+
strategy:
13+
fail-fast: false
14+
matrix:
15+
os: [ ubuntu-latest ]
16+
python-version: [ "3.10", "3.11", "3.12" ]
17+
include:
18+
- os: windows-latest
19+
python-version: "3.10"
20+
21+
steps:
22+
- name: Checkout code
23+
uses: actions/checkout@v4
24+
25+
- name: Setup Python ${{ matrix.python-version }}
26+
uses: actions/setup-python@v5
27+
with:
28+
python-version: ${{ matrix.python-version }}
29+
cache: "pip"
30+
31+
- name: Install dependencies
32+
run: |
33+
pip install .
34+
pip install pytest pytest-asyncio
35+
36+
- name: Run tests
37+
run: |
38+
pytest tests/pipeline/test_presets.py

README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ Example of building an intent classifier in a couple of lines of code:
3030
from autointent import Pipeline, Dataset
3131

3232
dataset = Dataset.from_json(path_to_json)
33-
pipeline = Pipeline.default_optimizer(multilabel=False)
33+
pipeline = Pipeline.from_preset("light")
3434
pipeline.fit(dataset)
35-
pipeline.predict(["show me my latest recent transactions"])
35+
pipeline.predict(["show me my latest transactions"])
3636
```

autointent/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from ._dataset import Dataset
88
from ._hash import Hasher
99
from .context import Context, load_dataset
10+
from ._optimization_config import OptimizationConfig
1011
from ._pipeline import Pipeline
1112

1213

@@ -15,6 +16,7 @@
1516
"Dataset",
1617
"Embedder",
1718
"Hasher",
19+
"OptimizationConfig",
1820
"Pipeline",
1921
"Ranker",
2022
"VectorIndex",

autointent/_callbacks/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77

88
REPORTERS = {cb.name: cb for cb in [WandbCallback, TensorBoardCallback]}
99

10-
REPORTERS_NAMES = list(REPORTERS.keys())
10+
REPORTERS_NAMES = Literal[tuple(REPORTERS.keys())] # type: ignore[valid-type]
1111

1212

1313
def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
14-
"""
15-
Get the list of callbacks.
14+
"""Get the list of callbacks.
15+
16+
Args:
17+
reporters: List of reporters to use.
1618
17-
:param reporters: List of reporters to use.
18-
:return: Callback handler.
19+
Returns:
20+
CallbackHandler: Callback handler.
1921
"""
2022
if not reporters:
2123
return CallbackHandler()

autointent/_callbacks/base.py

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,37 +17,37 @@ def __init__(self) -> None:
1717

1818
@abstractmethod
1919
def start_run(self, run_name: str, dirpath: Path) -> None:
20-
"""
21-
Start a new run.
20+
"""Start a new run.
2221
23-
:param run_name: Name of the run.
24-
:param dirpath: Path to the directory where the logs will be saved.
22+
Args:
23+
run_name: Name of the run.
24+
dirpath: Path to the directory where the logs will be saved.
2525
"""
2626

2727
@abstractmethod
2828
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
29-
"""
30-
Start a new module.
29+
"""Start a new module.
3130
32-
:param module_name: Name of the module.
33-
:param num: Number of the module.
34-
:param module_kwargs: Module parameters.
31+
Args:
32+
module_name: Name of the module.
33+
num: Number of the module.
34+
module_kwargs: Module parameters.
3535
"""
3636

3737
@abstractmethod
3838
def log_value(self, **kwargs: dict[str, Any]) -> None:
39-
"""
40-
Log data.
39+
"""Log data.
4140
42-
:param kwargs: Data to log.
41+
Args:
42+
kwargs: Data to log.
4343
"""
4444

4545
@abstractmethod
4646
def log_metrics(self, metrics: dict[str, Any]) -> None:
47-
"""
48-
Log metrics during training.
47+
"""Log metrics during training.
4948
50-
:param metrics: Metrics to log.
49+
Args:
50+
metrics: Metrics to log.
5151
"""
5252

5353
@abstractmethod
@@ -60,8 +60,8 @@ def end_run(self) -> None:
6060

6161
@abstractmethod
6262
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
63-
"""
64-
Log final metrics.
63+
"""Log final metrics.
6564
66-
:param metrics: Final metrics.
65+
Args:
66+
metrics: Final metrics.
6767
"""

autointent/_callbacks/callback_handler.py

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -10,45 +10,49 @@ class CallbackHandler(OptimizerCallback):
1010
callbacks: list[OptimizerCallback]
1111

1212
def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
13-
"""Initialize the callback handler."""
13+
"""Initialize the callback handler.
14+
15+
Args:
16+
callbacks: List of callback classes.
17+
"""
1418
if not callbacks:
1519
self.callbacks = []
1620
return
1721

1822
self.callbacks = [cb() for cb in callbacks]
1923

2024
def start_run(self, run_name: str, dirpath: Path) -> None:
21-
"""
22-
Start a new run.
25+
"""Start a new run.
2326
24-
:param run_name: Name of the run.
25-
:param dirpath: Path to the directory where the logs will be saved.
27+
Args:
28+
run_name: Name of the run.
29+
dirpath: Path to the directory where the logs will be saved.
2630
"""
2731
self.call_events("start_run", run_name=run_name, dirpath=dirpath)
2832

2933
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
30-
"""
31-
Start a new module.
34+
"""Start a new module.
3235
33-
:param module_name: Name of the module.
34-
:param num: Number of the module.
35-
:param module_kwargs: Module parameters.
36+
Args:
37+
module_name: Name of the module.
38+
num: Number of the module.
39+
module_kwargs: Module parameters.
3640
"""
3741
self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)
3842

3943
def log_value(self, **kwargs: dict[str, Any]) -> None:
40-
"""
41-
Log data.
44+
"""Log data.
4245
43-
:param kwargs: Data to log.
46+
Args:
47+
kwargs: Data to log.
4448
"""
4549
self.call_events("log_value", **kwargs)
4650

4751
def log_metrics(self, metrics: dict[str, Any]) -> None:
48-
"""
49-
Log metrics during training.
52+
"""Log metrics during training.
5053
51-
:param metrics: Metrics to log.
54+
Args:
55+
metrics: Metrics to log.
5256
"""
5357
self.call_events("log_metrics", metrics=metrics)
5458

@@ -61,13 +65,19 @@ def end_run(self) -> None:
6165
self.call_events("end_run")
6266

6367
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
64-
"""
65-
Log final metrics.
68+
"""Log final metrics.
6669
67-
:param metrics: Final metrics.
70+
Args:
71+
metrics: Final metrics.
6872
"""
6973
self.call_events("log_final_metrics", metrics=metrics)
7074

7175
def call_events(self, event: str, **kwargs: Any) -> None: # noqa: ANN401
76+
"""Call events for all callbacks.
77+
78+
Args:
79+
event: Event name.
80+
kwargs: Event parameters.
81+
"""
7282
for callback in self.callbacks:
7383
getattr(callback, event)(**kwargs)

autointent/_callbacks/tensorboard.py

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,16 @@
55

66

77
class TensorBoardCallback(OptimizerCallback):
8-
"""
9-
TensorBoard callback.
10-
11-
This callback logs the optimization process to TensorBoard.
12-
"""
8+
"""TensorBoard callback for logging the optimization process."""
139

1410
name = "tensorboard"
1511

1612
def __init__(self) -> None:
17-
"""Initialize the callback."""
13+
"""Initializes the TensorBoard callback.
14+
15+
Attempts to import `torch.utils.tensorboard` first. If unavailable, tries to import `tensorboardX`.
16+
Raises an ImportError if neither are installed.
17+
"""
1818
try:
1919
from torch.utils.tensorboard import SummaryWriter # type: ignore[attr-defined]
2020

@@ -32,22 +32,22 @@ def __init__(self) -> None:
3232
raise ImportError(msg) from None
3333

3434
def start_run(self, run_name: str, dirpath: Path) -> None:
35-
"""
36-
Start a new run.
35+
"""Starts a new run and sets the directory for storing logs.
3736
38-
:param run_name: Name of the run.
39-
:param dirpath: Path to the directory where the logs will be saved.
37+
Args:
38+
run_name: Name of the run.
39+
dirpath: Path to the directory where logs will be saved.
4040
"""
4141
self.run_name = run_name
4242
self.dirpath = dirpath
4343

4444
def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
45-
"""
46-
Start a new module.
45+
"""Starts a new module and initializes a TensorBoard writer for it.
4746
48-
:param module_name: Name of the module.
49-
:param num: Number of the module.
50-
:param module_kwargs: Module parameters.
47+
Args:
48+
module_name: Name of the module.
49+
num: Identifier number of the module.
50+
module_kwargs: Dictionary containing module parameters.
5151
"""
5252
module_run_name = f"{self.run_name}_{module_name}_{num}"
5353
log_dir = Path(self.dirpath) / module_run_name
@@ -57,43 +57,38 @@ def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]
5757
for key, value in module_kwargs.items():
5858
self.module_writer.add_text(f"module_params/{key}", str(value)) # type: ignore[no-untyped-call]
5959

60-
def log_value(self, **kwargs: dict[str, Any]) -> None:
61-
"""
62-
Log data.
60+
def log_value(self, **kwargs: dict[str, int | float | Any]) -> None:
61+
"""Logs scalar or text values.
6362
64-
:param kwargs: Data to log.
63+
Args:
64+
**kwargs: Key-value pairs of data to log. Scalars will be logged as numerical values, others as text.
6565
"""
66-
if self.module_writer is None:
67-
msg = "start_run must be called before log_value."
68-
raise RuntimeError(msg)
69-
7066
for key, value in kwargs.items():
7167
if isinstance(value, int | float):
7268
self.module_writer.add_scalar(key, value)
7369
else:
7470
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
7571

7672
def log_metrics(self, metrics: dict[str, Any]) -> None:
77-
"""
78-
Log metrics during training.
73+
"""Logs training metrics.
7974
80-
:param metrics: Metrics to log.
75+
Args:
76+
metrics: Dictionary of metrics to log.
8177
"""
82-
if self.module_writer is None:
83-
msg = "start_run must be called before log_value."
84-
raise RuntimeError(msg)
85-
8678
for key, value in metrics.items():
8779
if isinstance(value, int | float):
8880
self.module_writer.add_scalar(key, value) # type: ignore[no-untyped-call]
8981
else:
9082
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
9183

9284
def log_final_metrics(self, metrics: dict[str, Any]) -> None:
93-
"""
94-
Log final metrics.
85+
"""Logs final metrics at the end of training.
86+
87+
Args:
88+
metrics: Dictionary of final metrics.
9589
96-
:param metrics: Final metrics.
90+
Raises:
91+
RuntimeError: If `start_run` has not been called before logging final metrics.
9792
"""
9893
if self.module_writer is None:
9994
msg = "start_run must be called before log_final_metrics."
@@ -109,7 +104,11 @@ def log_final_metrics(self, metrics: dict[str, Any]) -> None:
109104
self.module_writer.add_text(key, str(value)) # type: ignore[no-untyped-call]
110105

111106
def end_module(self) -> None:
112-
"""End a module."""
107+
"""Ends the current module and closes the TensorBoard writer.
108+
109+
Raises:
110+
RuntimeError: If `start_run` has not been called before ending the module.
111+
"""
113112
if self.module_writer is None:
114113
msg = "start_run must be called before end_module."
115114
raise RuntimeError(msg)
@@ -118,4 +117,4 @@ def end_module(self) -> None:
118117
self.module_writer.close() # type: ignore[no-untyped-call]
119118

120119
def end_run(self) -> None:
121-
pass
120+
"""Ends the current run. This method is currently a placeholder."""

0 commit comments

Comments
 (0)