
Commit ad097e8

voorhs, Samoed, and Darinochka authored
Feat/pipeline simpler fitting (#36)
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: Darinka <[email protected]>
1 parent eeb30f4 commit ad097e8


61 files changed: +2737 -526 lines changed
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: test inference
+
+on:
+  push:
+    branches:
+      - dev
+  pull_request:
+    branches:
+      - dev
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ ubuntu-latest ]
+        python-version: [ "3.10", "3.11", "3.12" ]
+        include:
+          - os: windows-latest
+            python-version: "3.10"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+
+      - name: Install dependencies
+        run: |
+          pip install .
+          pip install pytest pytest-asyncio
+
+      - name: Run tests
+        run: |
+          pytest tests/pipeline/test_inference.py
Lines changed: 2 additions & 2 deletions
@@ -1,4 +1,4 @@
-name: Run Tests
+name: test nodes

on:
  push:
@@ -37,4 +37,4 @@ jobs:

      - name: Run tests
        run: |
-          pytest
+          pytest tests/nodes
Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: test optimization
+
+on:
+  push:
+    branches:
+      - dev
+  pull_request:
+    branches:
+      - dev
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ ubuntu-latest ]
+        python-version: [ "3.10", "3.11", "3.12" ]
+        include:
+          - os: windows-latest
+            python-version: "3.10"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+
+      - name: Install dependencies
+        run: |
+          pip install .
+          pip install pytest pytest-asyncio
+
+      - name: Run tests
+        run: |
+          pytest tests/pipeline/test_optimization.py

.github/workflows/unit-tests.yaml

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
+name: unit tests
+
+on:
+  push:
+    branches:
+      - dev
+  pull_request:
+    branches:
+      - dev
+
+jobs:
+  test:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ ubuntu-latest ]
+        python-version: [ "3.10", "3.11", "3.12" ]
+        include:
+          - os: windows-latest
+            python-version: "3.10"
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v4
+
+      - name: Setup Python ${{ matrix.python-version }}
+        uses: actions/setup-python@v5
+        with:
+          python-version: ${{ matrix.python-version }}
+          cache: "pip"
+
+      - name: Install dependencies
+        run: |
+          pip install .
+          pip install pytest pytest-asyncio
+
+      - name: Run tests
+        run: |
+          pytest --ignore=tests/nodes --ignore=tests/pipeline

autointent/pipeline/optimization/utils/name.py renamed to autointent/configs/name.py

Lines changed: 7 additions & 0 deletions
@@ -1,4 +1,5 @@
import random
+from datetime import datetime

adjectives = [
    "adorable",
@@ -342,3 +343,9 @@ def generate_name() -> str:
    adjective = random.choice(adjectives)
    noun = random.choice(nouns)
    return f"{adjective}_{noun}"
+
+
+def get_run_name(run_name: str | None = None) -> str:
+    if run_name is None:
+        run_name = generate_name()
+    return f"{run_name}_{datetime.now().strftime('%m-%d-%Y_%H-%M-%S')}"  # noqa: DTZ005

autointent/configs/node.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ class InferenceNodeConfig:
    node_type: str = MISSING
    module_type: str = MISSING
    module_config: dict[str, Any] = MISSING
-    load_path: str = MISSING
+    load_path: str | None = None
    _target_: str = "autointent.nodes.InferenceNode"

autointent/configs/optimization_cli.py

Lines changed: 4 additions & 2 deletions
@@ -6,7 +6,7 @@
from hydra.core.config_store import ConfigStore
from omegaconf import MISSING

-from autointent.pipeline.optimization.utils import generate_name
+from .name import generate_name


@dataclass
@@ -28,6 +28,8 @@ class LoggingConfig:
    run_name: str | None = None
    dirpath: Path | None = None
    dump_dir: Path | None = None
+    dump_modules: bool = False
+    clear_ram: bool = True

    def __post_init__(self) -> None:
        self.define_run_name()
@@ -44,7 +46,6 @@ def define_dirpath(self) -> None:
        if self.run_name is None:
            raise ValueError
        self.dirpath = dirpath / self.run_name
-        self.dirpath.mkdir(parents=True)

    def define_dump_dir(self) -> None:
        if self.dump_dir is None:
@@ -57,6 +58,7 @@ def define_dump_dir(self) -> None:
class VectorIndexConfig:
    db_dir: Path | None = None
    device: str = "cpu"
+    save_db: bool = False


@dataclass
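
For illustration only (not from the commit): a sketch of how the new flags might be set when the dataclasses are instantiated directly; the field names come from the diff above, the values are arbitrary.

# Sketch: exercises the new LoggingConfig and VectorIndexConfig fields.
from autointent.configs.optimization_cli import LoggingConfig, VectorIndexConfig

logging_config = LoggingConfig(
    run_name="demo",     # if left as None, a name is generated in __post_init__
    dump_modules=True,   # new flag; Context.get_dump_dir() returns dump_dir only when this is True
    clear_ram=False,     # new flag; exposed via Context.is_ram_to_clear()
)

vector_index_config = VectorIndexConfig(
    device="cpu",
    save_db=True,        # new flag; presumably keeps the vector index database after the run
)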

autointent/context/context.py

Lines changed: 133 additions & 27 deletions
@@ -1,42 +1,79 @@
+import json
+import logging
from dataclasses import asdict
from pathlib import Path
from typing import Any

+import yaml
+
+from autointent.configs.optimization_cli import (
+    AugmentationConfig,
+    DataConfig,
+    EmbedderConfig,
+    LoggingConfig,
+    VectorIndexConfig,
+)
+
from .data_handler import DataAugmenter, DataHandler, Dataset
from .optimization_info import OptimizationInfo
+from .utils import NumpyEncoder, load_data
from .vector_index_client import VectorIndex, VectorIndexClient


class Context:
-    def __init__(  # noqa: PLR0913
+    data_handler: DataHandler
+    vector_index_client: VectorIndexClient
+    optimization_info: OptimizationInfo
+
+    def __init__(
        self,
-        dataset: Dataset,
-        test_dataset: Dataset | None = None,
-        device: str = "cpu",
-        multilabel_generation_config: str | None = None,
-        regex_sampling: int = 0,
        seed: int = 42,
-        db_dir: str | Path | None = None,
-        dump_dir: str | Path | None = None,
-        force_multilabel: bool = False,
-        embedder_batch_size: int = 32,
-        embedder_max_length: int | None = None,
    ) -> None:
-        augmenter = DataAugmenter(multilabel_generation_config, regex_sampling, seed)
+        self.seed = seed
+        self._logger = logging.getLogger(__name__)
+
+    def configure_logging(self, config: LoggingConfig) -> None:
+        self.logging_config = config
+        self.optimization_info = OptimizationInfo()
+
+    def configure_vector_index(self, config: VectorIndexConfig, embedder_config: EmbedderConfig | None = None) -> None:
+        self.vector_index_config = config
+        if embedder_config is None:
+            embedder_config = EmbedderConfig()
+        self.embedder_config = embedder_config
+
+        self.vector_index_client = VectorIndexClient(
+            self.vector_index_config.device,
+            self.vector_index_config.db_dir,
+            self.embedder_config.batch_size,
+            self.embedder_config.max_length,
+        )
+
+    def configure_data(self, config: DataConfig, augmentation_config: AugmentationConfig | None = None) -> None:
+        if augmentation_config is not None:
+            self.augmentation_config = AugmentationConfig()
+            augmenter = DataAugmenter(
+                self.augmentation_config.multilabel_generation_config,
+                self.augmentation_config.regex_sampling,
+                self.seed,
+            )
+        else:
+            augmenter = None
+
        self.data_handler = DataHandler(
-            dataset, test_dataset, random_seed=seed, force_multilabel=force_multilabel, augmenter=augmenter
+            dataset=load_data(config.train_path),
+            test_dataset=None if config.test_path is None else load_data(config.test_path),
+            random_seed=self.seed,
+            force_multilabel=config.force_multilabel,
+            augmenter=augmenter,
+        )
+
+    def set_datasets(
+        self, train_data: Dataset, val_data: Dataset | None = None, force_multilabel: bool = False
+    ) -> None:
+        self.data_handler = DataHandler(
+            dataset=train_data, test_dataset=val_data, random_seed=self.seed, force_multilabel=force_multilabel
        )
-        self.optimization_info = OptimizationInfo()
-        self.vector_index_client = VectorIndexClient(device, db_dir, embedder_batch_size, embedder_max_length)
-
-        self.db_dir = self.vector_index_client.db_dir
-        self.embedder_max_length = embedder_max_length
-        self.embedder_batch_size = embedder_batch_size
-        self.device = device
-        self.multilabel = self.data_handler.multilabel
-        self.n_classes = self.data_handler.n_classes
-        self.seed = seed
-        self.dump_dir = Path.cwd() / "modules_dumps" if dump_dir is None else Path(dump_dir)

    def get_best_index(self) -> VectorIndex:
        model_name = self.optimization_info.get_best_embedder()
@@ -48,10 +85,79 @@ def get_inference_config(self) -> dict[str, Any]:
        cfg.pop("_target_")
        return {
            "metadata": {
-                "device": self.device,
-                "multilabel": self.multilabel,
-                "n_classes": self.n_classes,
+                "device": self.get_device(),
+                "multilabel": self.is_multilabel(),
+                "n_classes": self.get_n_classes(),
                "seed": self.seed,
            },
            "nodes_configs": nodes_configs,
        }
+
+    def dump(self) -> None:
+        self._logger.debug("dumping logs...")
+        optimization_results = self.optimization_info.dump_evaluation_results()
+
+        logs_dir = self.logging_config.dirpath
+        if logs_dir is None:
+            msg = "something's wrong with LoggingConfig"
+            raise ValueError(msg)
+
+        # create appropriate directory
+        logs_dir.mkdir(parents=True, exist_ok=True)
+
+        # dump search space and evaluation results
+        logs_path = logs_dir / "logs.json"
+        with logs_path.open("w") as file:
+            json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
+        # config_path = logs_dir / "config.yaml"
+        # with config_path.open("w") as file:
+        #     yaml.dump(self.config, file)
+
+        # self._logger.info(make_report(optimization_results, nodes=nodes))
+
+        # dump train and test data splits
+        train_data, test_data = self.data_handler.dump()
+        train_path = logs_dir / "train_data.json"
+        test_path = logs_dir / "test_data.json"
+        with train_path.open("w") as file:
+            json.dump(train_data, file, indent=4, ensure_ascii=False)
+        with test_path.open("w") as file:
+            json.dump(test_data, file, indent=4, ensure_ascii=False)
+
+        self._logger.info("logs and other assets are saved to %s", logs_dir)
+
+        # dump optimization results (config for inference)
+        inference_config = self.get_inference_config()
+        inference_config_path = logs_dir / "inference_config.yaml"
+        with inference_config_path.open("w") as file:
+            yaml.dump(inference_config, file)
+
+    def get_db_dir(self) -> Path:
+        return self.vector_index_client.db_dir
+
+    def get_device(self) -> str:
+        return self.vector_index_client.device
+
+    def get_batch_size(self) -> int:
+        return self.vector_index_client.embedder_batch_size
+
+    def get_max_length(self) -> int | None:
+        return self.vector_index_client.embedder_max_length
+
+    def get_dump_dir(self) -> Path | None:
+        if self.logging_config.dump_modules:
+            return self.logging_config.dump_dir
+        return None
+
+    def is_multilabel(self) -> bool:
+        return self.data_handler.multilabel
+
+    def get_n_classes(self) -> int:
+        return self.data_handler.n_classes
+
+    def is_ram_to_clear(self) -> bool:
+        return self.logging_config.clear_ram
+
+    def has_saved_modules(self) -> bool:
+        node_types = ["regexp", "retrieval", "scoring", "prediction"]
+        return any(len(self.optimization_info.modules.get(nt)) > 0 for nt in node_types)
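
To make the new flow concrete, a rough usage sketch (not from this commit), pieced together from the signatures shown above; the DataConfig field names are read from configure_data, the data path is hypothetical, and the exact import locations are assumptions.

# Sketch: wiring a Context the way the new configure_* methods suggest.
from autointent.configs.optimization_cli import DataConfig, LoggingConfig, VectorIndexConfig
from autointent.context.context import Context  # assumed import path

context = Context(seed=42)

# configure_logging also creates the OptimizationInfo store.
context.configure_logging(LoggingConfig(run_name="demo"))

# A default EmbedderConfig() is used when no embedder config is passed.
context.configure_vector_index(VectorIndexConfig(device="cpu"))

# Either load data from paths via DataConfig ...
context.configure_data(DataConfig(train_path="data/train.json"))  # hypothetical path

# ... or hand over already-loaded Dataset objects directly:
# context.set_datasets(train_data=my_dataset, val_data=None, force_multilabel=False)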
