Skip to content

Commit 2bac917

Browse files
committed
resolve conflicts
2 parents acc5d41 + c66d97f commit 2bac917

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

57 files changed

+1624
-1839
lines changed

.github/workflows/build-docs.yaml

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ on:
44
push:
55
branches:
66
- dev
7+
pull_request:
8+
branches:
9+
- dev
710
workflow_dispatch:
811

912
concurrency:
@@ -37,6 +40,10 @@ jobs:
3740
run: |
3841
poetry install --with docs
3942
43+
- name: Test documentation
44+
run: |
45+
make test-docs
46+
4047
- name: build documentation
4148
run: |
4249
make docs
@@ -49,15 +56,16 @@ jobs:
4956
BRANCH_NAME=${BRANCH_NAME////_}
5057
echo BRANCH_NAME=${BRANCH_NAME} >> $GITHUB_ENV
5158
52-
- name: save artifact
59+
- name: Upload artifact
5360
uses: actions/upload-artifact@v4
5461
with:
5562
name: ${{ format('github-pages-for-branch-{0}', env.BRANCH_NAME) }}
5663
path: docs/build/
5764
retention-days: 3
5865

59-
- name: deploy website
66+
- name: Deploy to GitHub Pages
6067
uses: JamesIves/[email protected]
68+
if: ${{ github.ref == 'refs/heads/dev' }}
6169
with:
6270
branch: gh-pages
6371
folder: docs/build/html/

Makefile

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ sync:
2828

2929
.PHONY: docs
3030
docs:
31+
$(poetry) python -m sphinx build -b html docs/source docs/build/html
32+
33+
.PHONY: test-docs
34+
test-docs: docs
3135
$(poetry) python -m sphinx build -b doctest docs/source docs/build/html
3236

3337
.PHONY: serve-docs

autointent/_embedder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,8 @@ def delete(self) -> None:
7272
"""Delete the embedding model and its associated directory."""
7373
self.clear_ram()
7474
shutil.rmtree(
75-
self.dump_dir, ignore_errors=True
75+
self.dump_dir,
76+
ignore_errors=True,
7677
) # TODO: `ignore_errors=True` is workaround for PermissionError: [WinError 5] Access is denied
7778

7879
def dump(self, path: Path) -> None:

autointent/configs/_optimization_cli.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,8 +123,6 @@ class OptimizationConfig:
123123
"""Configuration for the logging"""
124124
vector_index: VectorIndexConfig = field(default_factory=VectorIndexConfig)
125125
"""Configuration for the vector index"""
126-
augmentation: AugmentationConfig = field(default_factory=AugmentationConfig)
127-
"""Configuration for the augmentation"""
128126
embedder: EmbedderConfig = field(default_factory=EmbedderConfig)
129127
"""Configuration for the embedder"""
130128

@@ -133,7 +131,7 @@ class OptimizationConfig:
133131
"_self_",
134132
{"override hydra/job_logging": "autointent_standard_job_logger"},
135133
{"override hydra/help": "autointent_help"},
136-
]
134+
],
137135
)
138136

139137

autointent/context/_context.py

Lines changed: 16 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,14 @@
99
import yaml
1010

1111
from autointent.configs import (
12-
AugmentationConfig,
1312
DataConfig,
1413
EmbedderConfig,
1514
LoggingConfig,
1615
VectorIndexConfig,
1716
)
1817

1918
from ._utils import NumpyEncoder, load_data
20-
from .data_handler import DataAugmenter, DataHandler, Dataset
19+
from .data_handler import DataHandler, Dataset
2120
from .optimization_info import OptimizationInfo
2221
from .vector_index_client import VectorIndex, VectorIndexClient
2322

@@ -71,43 +70,29 @@ def configure_vector_index(self, config: VectorIndexConfig, embedder_config: Emb
7170
self.embedder_config.max_length,
7271
)
7372

74-
def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
73+
def configure_data(self, config: DataConfig) -> None:
7574
"""
76-
Configure data handling and augmentation.
75+
Configure data handling.
7776
7877
:param config: Configuration for the data handling process.
79-
:param augmentation_config: Configuration for data augmentation. If None, no augmentation is applied.
80-
"""
81-
if augmentation_config is not None:
82-
self.augmentation_config = AugmentationConfig()
83-
augmenter = DataAugmenter(
84-
self.augmentation_config.multilabel_generation_config,
85-
self.augmentation_config.regex_sampling,
86-
self.seed,
87-
)
88-
else:
89-
augmenter = None
90-
78+
"""
9179
self.data_handler = DataHandler(
9280
dataset=load_data(config.train_path),
93-
test_dataset=None if config.test_path is None else load_data(config.test_path),
9481
random_seed=self.seed,
9582
force_multilabel=config.force_multilabel,
96-
augmenter=augmenter,
9783
)
9884

99-
def set_datasets(
100-
self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
101-
) -> None:
85+
def set_dataset(self, dataset: Dataset, force_multilabel: bool = False) -> None:
10286
"""
103-
Set the datasets for training and validation.
87+
Set the datasets for training, validation and testing.
10488
105-
:param train_data: Training dataset.
106-
:param val_data: Validation dataset. If None, only training data is used.
89+
:param dataset: Dataset.
10790
:param force_multilabel: Whether to force multilabel classification.
10891
"""
10992
self.data_handler = DataHandler(
110-
dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
93+
dataset=dataset,
94+
force_multilabel=force_multilabel,
95+
random_seed=self.seed,
11196
)
11297

11398
def get_best_index(self) -> VectorIndex:
@@ -159,13 +144,12 @@ def dump(self) -> None:
159144
with logs_path.open("w") as file:
160145
json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
161146

162-
train_data, test_data = self.data_handler.dump()
163-
train_path = logs_dir / "train_data.json"
164-
test_path = logs_dir / "test_data.json"
165-
with train_path.open("w") as file:
166-
json.dump(train_data, file, indent=4, ensure_ascii=False)
167-
with test_path.open("w") as file:
168-
json.dump(test_data, file, indent=4, ensure_ascii=False)
147+
# self._logger.info(make_report(optimization_results, nodes=nodes))
148+
149+
# dump train and test data splits
150+
dataset_path = logs_dir / "dataset.json"
151+
with dataset_path.open("w") as file:
152+
json.dump(self.data_handler.dump(), file, indent=4, ensure_ascii=False)
169153

170154
self._logger.info("logs and other assets are saved to %s", logs_dir)
171155

autointent/context/_utils.py

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def default(self, obj: Any) -> str | int | float | list[Any] | Any: # noqa: ANN
4141
return super().default(obj)
4242

4343

44-
def load_data(data_path: str | Path) -> Dataset:
44+
def load_data(filepath: str | Path) -> Dataset:
4545
"""
4646
Load data from a specified path or use default sample data.
4747
@@ -54,14 +54,12 @@ def load_data(data_path: str | Path) -> Dataset:
5454
- "default-multilabel": Loads sample multilabel dataset.
5555
:return: A `Dataset` object containing the loaded data.
5656
"""
57-
if data_path == "default-multiclass":
58-
with ires.files("autointent.datafiles").joinpath("banking77.json").open() as file:
59-
res = json.load(file)
60-
elif data_path == "default-multilabel":
61-
with ires.files("autointent.datafiles").joinpath("dstc3-20shot.json").open() as file:
62-
res = json.load(file)
63-
else:
64-
with Path(data_path).open() as file:
65-
res = json.load(file)
66-
67-
return Dataset.model_validate(res)
57+
if filepath == "default-multiclass":
58+
return Dataset.from_json(
59+
ires.files("autointent.datafiles").joinpath("banking77.json"), # type: ignore[arg-type]
60+
)
61+
if filepath == "default-multilabel":
62+
return Dataset.from_json(
63+
ires.files("autointent.datafiles").joinpath("dstc3-20shot.json"), # type: ignore[arg-type]
64+
)
65+
return Dataset.from_json(filepath)
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from ._data_handler import DataAugmenter, DataHandler
2-
from ._schemas import Dataset
3-
from ._tags import Tag
1+
from ._data_handler import DataHandler
2+
from ._dataset import Dataset
3+
from ._schemas import Intent, Sample, Tag
44

5-
__all__ = ["DataAugmenter", "DataHandler", "Dataset", "Tag"]
5+
__all__ = ["DataHandler", "Dataset", "Intent", "Sample", "Tag"]

0 commit comments

Comments
 (0)