Skip to content

Commit 4c86e30

Browse files
committed
Merge branch 'dev' into feat/cross_encoder_refactor
2 parents 0e93c26 + b71a301 commit 4c86e30

File tree

28 files changed

+818
-67
lines changed

28 files changed

+818
-67
lines changed

.github/workflows/build-docs.yaml

Lines changed: 29 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
1-
name: Build and publish docs
1+
name: Build and Publish Multi-Version Docs
22

33
on:
44
push:
55
branches:
6-
- dev
6+
- dev
7+
release:
8+
types:
9+
- published
710
pull_request:
811
branches:
912
- dev
@@ -17,13 +20,18 @@ permissions:
1720
contents: write
1821

1922
jobs:
20-
publish:
21-
name: build and publish docs
23+
build-docs:
24+
name: Build Documentation
2225
runs-on: ubuntu-latest
26+
2327
steps:
24-
- uses: actions/checkout@v4
28+
- name: Checkout code
29+
uses: actions/checkout@v4
30+
with:
31+
fetch-depth: 0
32+
fetch-tags: true
2533

26-
- name: set up python 3.10
34+
- name: Set up Python 3.10
2735
uses: actions/setup-python@v5
2836
with:
2937
python-version: "3.10"
@@ -36,37 +44,31 @@ jobs:
3644
run: |
3745
sudo apt install pandoc
3846
39-
- name: install dependencies
47+
- name: Install dependencies
4048
run: |
4149
poetry install --with docs
4250
43-
- name: Test documentation
51+
- name: Run tests
52+
if: github.event_name != 'workflow_dispatch'
4453
run: |
54+
echo "Testing documentation build..."
4555
make test-docs
4656
47-
- name: build documentation
57+
- name: Build documentation
58+
# Keep the whole condition as one expression: text outside `${{ }}` would turn it into a non-empty (always truthy) string.
if: github.ref == 'refs/heads/dev' && github.event_name != 'workflow_dispatch'
4859
run: |
4960
make docs
5061
51-
- name: save branch name without slashes
52-
env:
53-
BRANCH_NAME: ${{ github.head_ref || github.ref_name }}
62+
- name: build multiversion documentation
63+
if: github.event_name == 'release' || github.event_name == 'workflow_dispatch'
5464
run: |
55-
BRANCH_NAME=${{ env.BRANCH_NAME }}
56-
BRANCH_NAME=${BRANCH_NAME////_}
57-
echo BRANCH_NAME=${BRANCH_NAME} >> $GITHUB_ENV
58-
59-
- name: Upload artifact
60-
uses: actions/upload-artifact@v4
61-
with:
62-
name: ${{ format('github-pages-for-branch-{0}', env.BRANCH_NAME) }}
63-
path: docs/build/
64-
retention-days: 3
65+
make multi-version-docs
6566
6667
- name: Deploy to GitHub Pages
67-
uses: JamesIves/github-pages-deploy-action@v4.6.4
68-
if: ${{ github.ref == 'refs/heads/dev' }}
68+
uses: peaceiris/actions-gh-pages@v3
69+
if: github.event_name == 'release' || github.event_name == 'workflow_dispatch'
6970
with:
70-
branch: gh-pages
71-
folder: docs/build/html/
72-
single-commit: True
71+
github_token: ${{ github.token }}
72+
publish_dir: docs/build/html/versions
73+
destination_dir: versions
74+
keep_files: true

.github/workflows/ruff.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@ jobs:
55
runs-on: ubuntu-latest
66
steps:
77
- uses: actions/checkout@v4
8-
- uses: astral-sh/ruff-action@v1
8+
- uses: astral-sh/ruff-action@v2
9+
with:
10+
version: "0.8.4"

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,5 @@ poetry.lock
177177
indexes_dirnames.json
178178
tests_logs
179179
tests/logs
180+
runs/
181+
vector_db*

Makefile

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,22 @@ docs:
3131
$(poetry) python -m sphinx build -b html docs/source docs/build/html
3232

3333
.PHONY: test-docs
34-
test-docs: docs
34+
test-docs:
3535
$(poetry) python -m sphinx build -b doctest docs/source docs/build/html
3636

3737
.PHONY: serve-docs
38-
serve-docs: docs
38+
serve-docs:
3939
$(poetry) python -m http.server -d docs/build/html 8333
4040

41+
.PHONY: multi-version-docs
42+
multi-version-docs:
43+
$(poetry) sphinx-multiversion docs/source docs/build/html
44+
4145
.PHONY: clean-docs
4246
clean-docs:
4347
rm -rf docs/build
4448
rm -rf docs/source/autoapi
4549
rm -rf docs/source/user_guides
4650

4751
.PHONY: all
48-
all: lint
52+
all: lint

autointent/_callbacks/__init__.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
from autointent._callbacks.base import OptimizerCallback
2+
from autointent._callbacks.callback_handler import CallbackHandler
3+
from autointent._callbacks.tensorboard import TensorBoardCallback
4+
from autointent._callbacks.wandb import WandbCallback
5+
6+
# Registry of the available reporter callbacks, keyed by each class's ``name`` attribute.
REPORTERS = {callback.name: callback for callback in (WandbCallback, TensorBoardCallback)}


def get_callbacks(reporters: list[str] | None) -> CallbackHandler:
    """
    Build a callback handler for the requested reporters.

    :param reporters: Names of the reporters to enable; ``None`` or an empty list enables none.
    :return: Callback handler wrapping the requested reporter classes.
    :raises ValueError: If a requested reporter name is not registered in ``REPORTERS``.
    """
    if reporters:
        unsupported = [requested for requested in reporters if requested not in REPORTERS]
        if unsupported:
            # Report the first offending name, in the order the caller listed them.
            reporter = unsupported[0]
            msg = f"Reporter {reporter} not supported. Supported reporters {','.join(REPORTERS)}"
            raise ValueError(msg)
        return CallbackHandler(callbacks=[REPORTERS[requested] for requested in reporters])

    return CallbackHandler()
26+
27+
28+
__all__ = [
29+
"CallbackHandler",
30+
"OptimizerCallback",
31+
"TensorBoardCallback",
32+
"WandbCallback",
33+
"get_callbacks",
34+
]

autointent/_callbacks/base.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
"""Base class for reporters (W&B, TensorBoard, etc)."""
2+
3+
from abc import ABC, abstractmethod
4+
from pathlib import Path
5+
from typing import Any
6+
7+
8+
class OptimizerCallback(ABC):
    """
    Base class for reporters (W&B, TensorBoard, etc).

    Subclasses implement every hook declared here; the class cannot be instantiated directly
    because all of its methods, including ``__init__``, are abstract.
    """

    # Implementation inspired by TrainerCallback from HuggingFace Transformers. https://github.com/huggingface/transformers/blob/91b8ab18b778ae9e2f8191866e018cd1dc7097be/src/transformers/trainer_callback.py#L260
    name: str

    @abstractmethod
    def __init__(self) -> None:
        """Initialize the callback."""

    @abstractmethod
    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """

    @abstractmethod
    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """

    @abstractmethod
    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        :param kwargs: Data to log, passed as ``name=value`` pairs. The annotation describes each
            individual value, so it is ``Any`` rather than ``dict[str, Any]``.
        """

    @abstractmethod
    def end_module(self) -> None:
        """End a module."""

    @abstractmethod
    def end_run(self) -> None:
        """End a run."""

    @abstractmethod
    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Log final metrics.

        :param metrics: Final metrics.
        """
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
from autointent._callbacks.base import OptimizerCallback
5+
6+
7+
class CallbackHandler(OptimizerCallback):
    """Internal class that just calls the list of callbacks in order."""

    callbacks: list[OptimizerCallback]

    def __init__(self, callbacks: list[type[OptimizerCallback]] | None = None) -> None:
        """
        Initialize the callback handler.

        :param callbacks: Callback classes (not instances); each one is instantiated with no
            arguments. ``None`` or an empty list yields a handler with no callbacks.
        """
        self.callbacks = [callback_cls() for callback_cls in callbacks or []]

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """
        self.call_events("start_run", run_name=run_name, dirpath=dirpath)

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        """
        self.call_events("start_module", module_name=module_name, num=num, module_kwargs=module_kwargs)

    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        :param kwargs: Data to log, forwarded unchanged to every callback.
        """
        self.call_events("log_value", **kwargs)

    def end_module(self) -> None:
        """End a module."""
        self.call_events("end_module")

    def end_run(self) -> None:
        """End a run."""
        self.call_events("end_run")

    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Log final metrics.

        :param metrics: Final metrics.
        """
        self.call_events("log_final_metrics", metrics=metrics)

    def call_events(self, event: str, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Invoke the method named ``event`` on every registered callback, in registration order.

        :param event: Name of the callback method to call.
        :param kwargs: Keyword arguments forwarded to that method.
        """
        for callback in self.callbacks:
            getattr(callback, event)(**kwargs)
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
from pathlib import Path
2+
from typing import Any
3+
4+
from autointent._callbacks.base import OptimizerCallback
5+
6+
7+
class TensorBoardCallback(OptimizerCallback):
    """
    TensorBoard callback.

    This callback logs the optimization process to TensorBoard. A separate writer is opened for
    every module (in ``start_module``) and closed again in ``end_module``; final metrics are
    written with their own short-lived writer.
    """

    name = "tensorboard"

    def __init__(self) -> None:
        """
        Initialize the callback.

        :raises ImportError: If neither ``torch.utils.tensorboard`` nor ``tensorboardX`` can be imported.
        """
        # Define all run state up front so the ``is None`` guards below raise the intended
        # RuntimeError instead of an AttributeError when hooks are called out of order.
        self.run_name: str | None = None
        self.dirpath: Path | None = None
        self.module_writer: Any = None

        try:
            from torch.utils.tensorboard import SummaryWriter  # type: ignore[attr-defined]

            self.writer = SummaryWriter
        except ImportError:
            try:
                from tensorboardX import SummaryWriter  # type: ignore[no-redef]

                self.writer = SummaryWriter
            except ImportError:
                msg = (
                    "TensorBoardCallback requires tensorboard to be installed. Either update your PyTorch version or"
                    " install tensorboardX."
                )
                raise ImportError(msg) from None

    def start_run(self, run_name: str, dirpath: Path) -> None:
        """
        Start a new run.

        :param run_name: Name of the run.
        :param dirpath: Path to the directory where the logs will be saved.
        """
        self.run_name = run_name
        self.dirpath = dirpath

    def start_module(self, module_name: str, num: int, module_kwargs: dict[str, Any]) -> None:
        """
        Start a new module.

        :param module_name: Name of the module.
        :param num: Number of the module.
        :param module_kwargs: Module parameters.
        :raises RuntimeError: If ``start_run`` has not been called yet.
        """
        if self.dirpath is None:
            msg = "start_run must be called before start_module."
            raise RuntimeError(msg)

        module_run_name = f"{self.run_name}_{module_name}_{num}"
        log_dir = Path(self.dirpath) / module_run_name
        self.module_writer = self.writer(log_dir=log_dir)  # type: ignore[no-untyped-call]

        self.module_writer.add_text("module_info", f"Starting module {module_name}_{num}")  # type: ignore[no-untyped-call]
        for key, value in module_kwargs.items():
            self.module_writer.add_text(f"module_params/{key}", str(value))  # type: ignore[no-untyped-call]

    def log_value(self, **kwargs: Any) -> None:  # noqa: ANN401
        """
        Log data.

        Numeric values are logged as scalars, everything else as text.

        :param kwargs: Data to log.
        :raises RuntimeError: If no module writer is open (``start_module`` was not called).
        """
        if self.module_writer is None:
            # The writer is created in start_module, so that is the call that is missing.
            msg = "start_module must be called before log_value."
            raise RuntimeError(msg)

        for key, value in kwargs.items():
            if isinstance(value, int | float):
                self.module_writer.add_scalar(key, value)
            else:
                self.module_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]

    def log_final_metrics(self, metrics: dict[str, Any]) -> None:
        """
        Log final metrics.

        The metrics are written to a ``final_metrics`` subdirectory of the run directory.

        :param metrics: Final metrics.
        :raises RuntimeError: If ``start_run`` has not been called yet.
        """
        if self.dirpath is None:
            msg = "start_run must be called before log_final_metrics."
            raise RuntimeError(msg)

        log_dir = Path(self.dirpath) / "final_metrics"
        # Use a dedicated writer and always close it, so the events are flushed and the
        # per-module writer is not silently replaced.
        final_writer = self.writer(log_dir=log_dir)  # type: ignore[no-untyped-call]
        try:
            for key, value in metrics.items():
                if isinstance(value, int | float):
                    final_writer.add_scalar(key, value)  # type: ignore[no-untyped-call]
                else:
                    final_writer.add_text(key, str(value))  # type: ignore[no-untyped-call]
        finally:
            final_writer.close()  # type: ignore[no-untyped-call]

    def end_module(self) -> None:
        """
        End a module.

        :raises RuntimeError: If no module writer is open (``start_module`` was not called).
        """
        if self.module_writer is None:
            msg = "start_module must be called before end_module."
            raise RuntimeError(msg)

        self.module_writer.add_text("module_info", "Ending module")  # type: ignore[no-untyped-call]
        self.module_writer.close()  # type: ignore[no-untyped-call]
        # Drop the closed writer so later log_value calls fail loudly instead of writing to it.
        self.module_writer = None

    def end_run(self) -> None:
        """End a run. Nothing to do: every writer is closed by the hook that opened it."""

0 commit comments

Comments
 (0)