Commit 8580f19

Authored by Samoed, github-actions[bot], voorhs, SeBorgey, and Darinochka
add interruption handling (#169)
* add interruption handling
* fix test
* fix test
* update
* fix test
* lint
* remove step
* use patch instead of monkeypatch
* add n_jobs as param
* change n_jobs to -1
* try fix
* remove old study
* add logging warning
* Update optimizer_config.schema.json
* lint
* try dumping
* lint
* np encoder
* update warning trigger
* Fix/n trials issue (#196)
* try to fix
* fix typing errors
* bug fix
* Update autointent/nodes/_node_optimizer.py
Co-authored-by: Roman Solomatin <[email protected]>
---------
Co-authored-by: Roman Solomatin <[email protected]>
* Fix/context not dumped error (#197)
* try to fix
* dump context constantly and fix serialization issues
* add exclude option to dumper
* fix codestyle and typing errors
* try to fix file exists error
* fix no fixture found error
* Update interruption handling (#198)
* full tuning (#165)
* Added code for full tuning
* work on review
* renaming
* fix ruff
* mypy test
* ignote mypy
* Feat/bert scorer config refactoring (#168)
* refactor configs
* add proper configs to BERTScorer
* fix typing
* fix tokenizer's parameters
* fix transformers and accelerate issue
* Update optimizer_config.schema.json
* bug fix
* update callback test
* fix tests
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* delete validate_task
* report_to
* batches
* Fix/docs building for bert scorer (#171)
* fix
* fix codestyle
---------
Co-authored-by: Алексеев Илья <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* bert-scorer ending (#172)
* batches
* tests check
* fix
* return to torch
* fix for tests
* Fix/bert scorer (#174)
* fix str and float issue and shrinken search space
* update `inference node config` overriding logic
* fix typing
* fix codestyle
* fix multilabel issue
* attempt to fix `inference node config` bugs
* another attempt
---------
Co-authored-by: Алексеев Илья <[email protected]>
* Feat/code carbon each node (#175)
* feat: update codecarbon
* feat: update codecarbon
* feat: added codecarbon
* Update optimizer_config.schema.json
* fix: fixed import mypy
* fix: codecarbon package
* fix: only float\integer log
* fix: codecarbon package
* fix: mypy
* fix: test
* fix: delete emissions
* fix: test
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* standartize pyproject & speedup tests (#176)
* speedup tests
* fix pyproject
* Update optimizer_config.schema.json
* move optional dependencies
* fixes
* add xdist
* fix ci
* download data from hub in doc
* add caching
* add doc cache
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
* add proper `omit` definition for tests coverage report (#179)
* add proper `omit` definition
* Update optimizer_config.schema.json
* exclude tmp from coverage report
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* add node validators (#177)
* add node validators
* add comments
* Update optimizer_config.schema.json
* rename bert model
* lint
* fixes
* fix test
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
* dumper saving (#180)
* added main code for saving models
* Update optimizer_config.schema.json
* checker fixes
* Revert "checker fixes". This reverts commit 6e32eb9.
* Revert "added main code for saving models". This reverts commit 5637fb8.
* drat main code for new dumper
* ruf fix
* comments
* added code for test dumper
* Check dumper (#182)
* Feat/code carbon each node (#175)
* feat: update codecarbon
* feat: update codecarbon
* feat: added codecarbon
* Update optimizer_config.schema.json
* fix: fixed import mypy
* fix: codecarbon package
* fix: only float\integer log
* fix: codecarbon package
* fix: mypy
* fix: test
* fix: delete emissions
* fix: test
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* standartize pyproject & speedup tests (#176)
* speedup tests
* fix pyproject
* Update optimizer_config.schema.json
* move optional dependencies
* fixes
* add xdist
* fix ci
* download data from hub in doc
* add caching
* add doc cache
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
* add proper `omit` definition for tests coverage report (#179)
* add proper `omit` definition
* Update optimizer_config.schema.json
* exclude tmp from coverage report
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* add node validators (#177)
* add node validators
* add comments
* Update optimizer_config.schema.json
* rename bert model
* lint
* fixes
* fix test
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
* update makefile
* update bert test
* mypy workaround
* attempt to fix windows permission error
* workaround
---------
Co-authored-by: Darinochka <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Roman Solomatin <[email protected]>
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Алексеев Илья <[email protected]>
Co-authored-by: Darinochka <[email protected]>
Co-authored-by: Roman Solomatin <[email protected]>
* Update embedder prompt (#183)
* Add trust remote code (#185)
* lint
* fix trust remote code
* Update optimizer_config.schema.json
* update fix trust remote code
* fix test cllback
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* Remove autointent org from docs (#186)
* lint
* update paths
* feat: added crossencoder (#181)
* feat: added crossencoder
* refactor
* feat: added arg similarity
* Update optimizer_config.schema.json
* feat: added tests
* feat: added errors
* fix: scoring test
* fix: description vectors error
* fix: description vectors error
* fix: lint
* fix: test
* add node validators (#177)
* add node validators
* add comments
* Update optimizer_config.schema.json
* rename bert model
* lint
* fixes
* fix test
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
* fix: unit tests
* feat: added test for description
* feat: delete encoder_type from the class args
* feat: update assets
* feat: update assets
* fix: fixed test
* Update optimizer_config.schema.json
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: voorhs <[email protected]>
* Add few shot (#187)
* init few shot
* Update optimizer_config.schema.json
* apply few shot to all
* Update optimizer_config.schema.json
* fix test
* lint
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* update numpy typing (#188)
* Lora scorer (#170)
* added lora scorer
* fix ruff
* Update __init__.py
* updated after mr #165
* Update pyproject.toml
* fixed requested changes
* fixed ruff failing
* fixed remarks
* Update optimizer_config.schema.json
* added test
* ruff fix
* convert labels to float
* Update autointent/modules/scoring/_lora/lora.py
Co-authored-by: Roman Solomatin <[email protected]>
* Update autointent/modules/scoring/_lora/lora.py
Co-authored-by: Roman Solomatin <[email protected]>
* change model_config name, added trust_remote_code
* Update lora.py
* inherited lora from bert
* fix ruff
* fix search space
* Update lora.py
* Update lora.py
* added dump check
* Update test_lora.py
* Update test_lora.py
* added docstring
* fix ruff
* Update test_lora.py
* Update test_lora.py
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Roman Solomatin <[email protected]>
* PTuningScorer (#178)
* Initial commit of PTuningScorer module
* Added peft (>=0.10.0, <0.15.0) in dependencies
* Implement fit/predict PTuningScorer
* Added PTuningScorer in __init__ file
* Update optimizer_config.schema.json
* Minor fixs
* PGH00
* Refactor clear_cache in fit method
* Refactor typing ignore + remove unnecessary
* Fix fit method status check
* Added test for PTuningScorer
* Fix mypy typing
* Update and fix peft version dependencies
* Fix mypy typing
* Added test in multiclass.yaml, multilabel.yaml
* Update docs strings
* Fix mypy typing
* Added trust_remote_code
* make proper rst reference
* Added test for dump lod
* feat: added crossencoder (#181)
* feat: added crossencoder
* refactor
* feat: added arg similarity
* Update optimizer_config.schema.json
* feat: added tests
* feat: added errors
* fix: scoring test
* fix: description vectors error
* fix: description vectors error
* fix: lint
* fix: test
* add node validators (#177)
* add node validators
* add comments
* Update optimizer_config.schema.json
* rename bert model
* lint
* fixes
* fix test
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
* fix: unit tests
* feat: added test for description
* feat: delete encoder_type from the class args
* feat: update assets
* feat: update assets
* fix: fixed test
* Update optimizer_config.schema.json
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: voorhs <[email protected]>
* Added fixed seed to test reproduction
* Pull LoraScorer and Bert Refactor
* Refactor PTuningScorer
* Refactor test for ptuning
* Fix typing
* Fix multilabel multiclass tests
* Fix typing
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: voorhs <[email protected]>
Co-authored-by: Darinochka <[email protected]>
Co-authored-by: Roman Solomatin <[email protected]>
* Rerank scorer: option to choose the source for computing the probability vector (#115)
* Enable rerank scorer to use crossencoder scores for the probability vector
* add cross encoder scores range options
* upd test
---------
Co-authored-by: voorhs <[email protected]>
* feat: add DISABLE_EMISSIONS_TRACKING (#191)
* feat: add DISABLE_EMISSIONS_TRACKING
* try to fix docs error
* Update optimizer_config.schema.json
* another attempt
* Update optimizer_config.schema.json
* i give up for now
* Update optimizer_config.schema.json
---------
Co-authored-by: voorhs <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* fix issue (#194)
* Refactor/embedding caching (#195)
* implement new hashing strategy
* fix codestyle
* Update optimizer_config.schema.json
* minor bug fix
* fix typing error
* refactor similarity calculation
* Update optimizer_config.schema.json
* upd callback test
* solve 429 error
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
* forgot something
---------
Co-authored-by: Сергей Малышев <[email protected]>
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Darinochka <[email protected]>
Co-authored-by: Roman Solomatin <[email protected]>
Co-authored-by: VALERIA RUBANOVA <[email protected]>
Co-authored-by: nikiduki <[email protected]>
Co-authored-by: Dmitryv-2024 <[email protected]>
---------
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Алексеев Илья <[email protected]>
Co-authored-by: Сергей Малышев <[email protected]>
Co-authored-by: Darinochka <[email protected]>
Co-authored-by: VALERIA RUBANOVA <[email protected]>
Co-authored-by: nikiduki <[email protected]>
Co-authored-by: Dmitryv-2024 <[email protected]>
Co-authored-by: voorhs <[email protected]>
1 parent 389eb40 commit 8580f19
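
Taken together, this commit makes pipeline optimization resumable after an interruption: `Dumper` gains `exists_ok` and `exclude` parameters, `Context` gains a `load()` counterpart to `dump()`, optimization info is serialized as trials complete, and the pipeline warns when module dumping is disabled (see the diffs below). A rough usage sketch follows; only `LoggingConfig.dump_modules` and `Context.load()` appear in this diff, while the `Context()` construction and the fitting entry point are schematic assumptions, not the library's documented API:

```python
# Schematic resume flow; constructor and fit entry point are assumed, see note above.
from autointent.context import Context


def run_optimization(ctx: Context) -> None:
    """Placeholder for pipeline fitting; the real entry point is the Pipeline class."""


context = Context()                          # assumed: configure data/logging as usual
context.logging_config.dump_modules = True   # from this diff: required for resuming

try:
    run_optimization(context)
except KeyboardInterrupt:
    pass  # state is dumped as trials complete, so a later process can pick it up

resumed = Context()  # a fresh process pointing at the same logging dirpath
resumed.load()       # from this diff: raises RuntimeError if nothing was dumped
```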

File tree

10 files changed (+511 additions, -81 deletions)


.gitignore

Lines changed: 2 additions & 0 deletions

```diff
@@ -179,4 +179,6 @@ tests_logs
 tests/logs
 runs/
 vector_db*
+*.db
+*.sqlite
 /wandb
```

autointent/_dump_tools.py

Lines changed: 43 additions & 42 deletions

```diff
@@ -1,9 +1,8 @@
-import inspect
+import importlib
 import json
 import logging
 from pathlib import Path
-from types import UnionType
-from typing import Any, TypeAlias, Union, get_args, get_origin
+from typing import Any, TypeAlias
 
 import joblib
 import numpy as np
@@ -37,11 +36,12 @@ class Dumper:
     hf_tokenizers = "hf_tokenizers"
 
     @staticmethod
-    def make_subdirectories(path: Path) -> None:
+    def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
         """Make subdirectories for dumping.
 
         Args:
             path: Path to make subdirectories in
+            exists_ok: If True, do not raise an error if the directory already exists
         """
         subdirectories = [
             path / Dumper.tags,
@@ -54,23 +54,27 @@ def make_subdirectories(path: Path) -> None:
             path / Dumper.hf_tokenizers,
         ]
         for subdir in subdirectories:
-            subdir.mkdir(parents=True, exist_ok=True)
+            subdir.mkdir(parents=True, exist_ok=exists_ok)
 
     @staticmethod
-    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
+    def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]] | None = None) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
         """Dump modules attributes to filestystem.
 
         Args:
             obj: Object to dump
             path: Path to dump to
+            exists_ok: If True, do not raise an error if the directory already exists
+            exclude: List of types to exclude from dumping
         """
         attrs: dict[str, ModuleAttributes] = vars(obj)
         simple_attrs = {}
         arrays: dict[str, npt.NDArray[Any]] = {}
 
-        Dumper.make_subdirectories(path)
+        Dumper.make_subdirectories(path, exists_ok)
 
         for key, val in attrs.items():
+            if exclude and isinstance(val, tuple(exclude)):
+                continue
             if isinstance(val, TagsList):
                 val.dump(path / Dumper.tags / key)
             elif isinstance(val, ModuleSimpleAttributes):
@@ -85,10 +89,14 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
                 joblib.dump(val, path / Dumper.estimators / key)
             elif isinstance(val, Ranker):
                 val.save(str(path / Dumper.cross_encoders / key))
-            elif isinstance(val, CrossEncoderConfig | EmbedderConfig):
+            elif isinstance(val, BaseModel):
                 try:
-                    pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
-                    with pydantic_path.open("w", encoding="utf-8") as file:
+                    class_info = {"name": val.__class__.__name__, "module": val.__class__.__module__}
+                    pydantic_path = path / Dumper.pydantic_models / key
+                    pydantic_path.mkdir(parents=True, exist_ok=exists_ok)
+                    with (pydantic_path / "class_info.json").open("w", encoding="utf-8") as file:
+                        json.dump(class_info, file, ensure_ascii=False, indent=4)
+                    with (pydantic_path / "model_dump.json").open("w", encoding="utf-8") as file:
                         json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
                 except Exception as e:
                     msg = f"Error dumping pydantic model {key}: {e}"
@@ -125,7 +133,7 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901, PLR0912, PLR0915
         np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)
 
     @staticmethod
-    def load(  # noqa: PLR0912, C901, PLR0915
+    def load(  # noqa: C901, PLR0912, PLR0915
         obj: Any,  # noqa: ANN401
         path: Path,
         embedder_config: EmbedderConfig | None = None,
@@ -166,41 +174,34 @@ def load(  # noqa: PLR0912, C901, PLR0915
                     for cross_encoder_dump in child.iterdir()
                 }
             elif child.name == Dumper.pydantic_models:
-                for model_file in child.iterdir():
-                    with model_file.open("r", encoding="utf-8") as file:
-                        content = json.load(file)
-                    variable_name = model_file.stem
-
-                    # First try to get the type annotation from the class annotations.
-                    model_type = obj.__class__.__annotations__.get(variable_name)
-
-                    # Fallback: inspect __init__ signature if not found in class-level annotations.
-                    if model_type is None:
-                        sig = inspect.signature(obj.__init__)
-                        if variable_name in sig.parameters:
-                            model_type = sig.parameters[variable_name].annotation
-                    if model_type is None:
-                        msg = f"No type annotation found for {variable_name}"
-                        logger.error(msg)
-                        continue
+                for model_dir in child.iterdir():
+                    try:
+                        with (model_dir / "model_dump.json").open("r", encoding="utf-8") as file:
+                            content = json.load(file)
+
+                        variable_name = model_dir.name
 
-                    # If the annotation is a Union, extract the pydantic model type.
-                    if get_origin(model_type) in (UnionType, Union):
-                        for arg in get_args(model_type):
-                            if isinstance(arg, type) and issubclass(arg, BaseModel):
-                                model_type = arg
-                                break
-                        else:
-                            msg = f"No pydantic type found in Union for {variable_name}"
-                            logger.error(msg)
+                        with (model_dir / "class_info.json").open("r", encoding="utf-8") as file:
+                            class_info = json.load(file)
+
+                        try:
+                            model_type = importlib.import_module(class_info["module"])
+                            model_type = getattr(model_type, class_info["name"])
+                        except (ImportError, AttributeError) as e:
+                            msg = f"Failed to import model type for {variable_name}: {e}"
+                            logger.exception(msg)
                             continue
 
-                    if not (isinstance(model_type, type) and issubclass(model_type, BaseModel)):
-                        msg = f"Type for {variable_name} is not a pydantic model: {model_type}"
-                        logger.error(msg)
+                        try:
+                            pydantic_models[variable_name] = model_type.model_validate(content)
+                        except Exception as e:
+                            msg = f"Failed to reconstruct Pydantic model {variable_name}: {e}"
+                            logger.exception(msg)
+                            continue
+                    except Exception as e:
+                        msg = f"Error loading Pydantic model from {model_dir}: {e}"
+                        logger.exception(msg)
                         continue
-
-                    pydantic_models[variable_name] = model_type(**content)
             elif child.name == Dumper.hf_models:
                 for model_dir in child.iterdir():
                     try:
```
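
The net effect of the `Dumper` changes above: instead of guessing a pydantic model's type from class or `__init__` annotations, `dump` now writes every `BaseModel` attribute to its own directory as `class_info.json` (import path) plus `model_dump.json` (field values), and `load` rebuilds the instance via `importlib`. A standalone sketch of that round trip; `EmbedderConfig` here is a local stand-in, not the library class:

```python
import importlib
import json
import tempfile
from pathlib import Path

from pydantic import BaseModel


class EmbedderConfig(BaseModel):
    """Local stand-in for an autointent config model (name borrowed from the diff)."""

    device: str = "cpu"


def dump_model(val: BaseModel, pydantic_path: Path) -> None:
    # Mirrors the new dump branch: one directory per model, with the class's
    # import path in class_info.json and its field values in model_dump.json.
    pydantic_path.mkdir(parents=True, exist_ok=True)
    class_info = {"name": val.__class__.__name__, "module": val.__class__.__module__}
    (pydantic_path / "class_info.json").write_text(json.dumps(class_info))
    (pydantic_path / "model_dump.json").write_text(json.dumps(val.model_dump()))


def load_model(pydantic_path: Path) -> BaseModel:
    # Mirrors the new load branch: import the recorded class, then
    # reconstruct the instance with model_validate.
    class_info = json.loads((pydantic_path / "class_info.json").read_text())
    content = json.loads((pydantic_path / "model_dump.json").read_text())
    model_type = getattr(importlib.import_module(class_info["module"]), class_info["name"])
    return model_type.model_validate(content)


model_dir = Path(tempfile.mkdtemp()) / "embedder_config"
dump_model(EmbedderConfig(), model_dir)
assert isinstance(load_model(model_dir), EmbedderConfig)
```

The import-path indirection is what lets `load` handle any `BaseModel` subclass, at the cost of trusting whatever module name was recorded on disk.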

autointent/_pipeline/_pipeline.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -144,6 +144,14 @@ def _fit(self, context: Context, sampler: SamplerType) -> None:
         """
         self.context = context
         self._logger.info("starting pipeline optimization...")
+
+        if not context.logging_config.dump_modules:
+            self._logger.warning(
+                "Memory storage is not compatible with resuming optimization. "
+                "Modules from previous runs won't be available. "
+                "Set dump_modules=True in LoggingConfig to enable proper resuming."
+            )
+
         self.context.callback_handler.start_run(
             run_name=self.context.logging_config.get_run_name(),
             dirpath=self.context.logging_config.dirpath,
```

autointent/context/_context.py

Lines changed: 18 additions & 9 deletions

```diff
@@ -1,6 +1,5 @@
 """Context manager for configuring and managing data handling, vector indexing, and optimization."""
 
-import json
 import logging
 from pathlib import Path
 
@@ -10,7 +9,6 @@
 from autointent._callbacks import CallbackHandler, get_callbacks
 from autointent.configs import CrossEncoderConfig, DataConfig, EmbedderConfig, LoggingConfig
 
-from ._utils import NumpyEncoder
 from .data_handler import DataHandler
 from .optimization_info import OptimizationInfo
 
@@ -77,15 +75,9 @@ def dump(self) -> None:
         Save metrics, hyperparameters, inference, configurations, and datasets to disk.
         """
         self._logger.debug("dumping logs...")
-        optimization_results = self.optimization_info.dump_evaluation_results()
-
         logs_dir = self.logging_config.dirpath
-        logs_dir.mkdir(parents=True, exist_ok=True)
-
-        logs_path = logs_dir / "logs.json"
-        with logs_path.open("w") as file:
-            json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
 
+        self.optimization_info.dump(logs_dir)
         self.data_handler.dataset.to_json(logs_dir / "dataset.json")
 
         self._logger.info("logs and other assets are saved to %s", logs_dir)
@@ -95,6 +87,23 @@ def dump(self) -> None:
         with inference_config_path.open("w") as file:
             yaml.dump(inference_config, file)
 
+    def load(self) -> None:
+        """Restore the context state to resume the optimization process.
+
+        Raises:
+            RuntimeError: If the modules artifacts are not found.
+        """
+        self._logger.debug("loading logs...")
+        logs_dir = self.logging_config.dirpath
+        self.optimization_info.load(logs_dir)
+        if not self.optimization_info.artifacts.has_artifacts():
+            msg = (
+                "It is impossible to continue from the previous point, "
+                "start again with dump_modules=True settings if you want to resume the run."
+                "To load optimization info only, use Context.optimization_info.load(logs_dir)."
+            )
+            raise RuntimeError(msg)
+
     def get_dump_dir(self) -> Path | None:
         """Get the directory for saving dumped modules.
```
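
The `load()` contract added above is worth spelling out: it restores `OptimizationInfo` from `logging_config.dirpath`, raises `RuntimeError` when no module artifacts were dumped, and its error text points at `optimization_info.load()` as the metrics-only fallback. A hedged handling sketch (the `Context` setup is assumed; the two `load` calls are from this diff):

```python
from autointent.context import Context

context = Context()  # assumed: already configured to point at an existing run

try:
    context.load()  # full restore; needs artifacts dumped with dump_modules=True
except RuntimeError:
    # Fallback suggested by the error message: restore trials/metrics only.
    context.optimization_info.load(context.logging_config.dirpath)
```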

autointent/context/optimization_info/_data_models.py

Lines changed: 60 additions & 0 deletions

```diff
@@ -52,6 +52,32 @@ class ScorerArtifact(Artifact):
         None, description="Scores for each fold from cross-validation"
     )
 
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
+        """Convert the model to a dictionary, converting numpy arrays to lists."""
+        data = super().model_dump(**kwargs)
+        if data["train_scores"] is not None:
+            data["train_scores"] = data["train_scores"].tolist()
+        if data["validation_scores"] is not None:
+            data["validation_scores"] = data["validation_scores"].tolist()
+        if data["test_scores"] is not None:
+            data["test_scores"] = data["test_scores"].tolist()
+        if data["folded_scores"] is not None:
+            data["folded_scores"] = [arr.tolist() for arr in data["folded_scores"]]
+        return data
+
+    @classmethod
+    def model_validate(cls, obj: dict[str, Any]) -> "ScorerArtifact":
+        """Convert lists back to numpy arrays during validation."""
+        if obj.get("train_scores") is not None:
+            obj["train_scores"] = np.array(obj["train_scores"])
+        if obj.get("validation_scores") is not None:
+            obj["validation_scores"] = np.array(obj["validation_scores"])
+        if obj.get("test_scores") is not None:
+            obj["test_scores"] = np.array(obj["test_scores"])
+        if obj.get("folded_scores") is not None:
+            obj["folded_scores"] = [np.array(arr) for arr in obj["folded_scores"]]
+        return super().model_validate(obj)
+
 
 class DecisionArtifact(Artifact):
     """Artifact containing outputs from the predictor node.
@@ -104,6 +130,31 @@ class Artifacts(BaseModel):
     scoring: list[ScorerArtifact] = []
     decision: list[DecisionArtifact] = []
 
+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
+        """Convert the model to a dictionary, ensuring nested artifacts are properly serialized."""
+        data = super().model_dump(**kwargs)
+        for node_type in [NodeType.regex, NodeType.embedding, NodeType.scoring, NodeType.decision]:
+            artifacts = getattr(self, node_type.value)
+            data[node_type.value] = [artifact.model_dump(**kwargs) for artifact in artifacts]
+        return data
+
+    @classmethod
+    def model_validate(cls, obj: dict[str, Any]) -> "Artifacts":
+        """Convert the dictionary back to an Artifacts instance, ensuring nested artifacts are properly deserialized."""
+        # First convert the lists back to numpy arrays in the scoring artifacts
+        if "scoring" in obj:
+            for artifact in obj["scoring"]:
+                if artifact.get("train_scores") is not None:
+                    artifact["train_scores"] = np.array(artifact["train_scores"])
+                if artifact.get("validation_scores") is not None:
+                    artifact["validation_scores"] = np.array(artifact["validation_scores"])
+                if artifact.get("test_scores") is not None:
+                    artifact["test_scores"] = np.array(artifact["test_scores"])
+                if artifact.get("folded_scores") is not None:
+                    artifact["folded_scores"] = [np.array(arr) for arr in artifact["folded_scores"]]
+
+        return super().model_validate(obj)
+
     def add_artifact(self, node_type: str, artifact: Artifact) -> None:
         """Add an artifact to the specified node type.
 
@@ -136,6 +187,15 @@ def get_best_artifact(self, node_type: str, idx: int) -> Artifact:
         """
         return self.get_artifacts(node_type)[idx]
 
+    def has_artifacts(self) -> bool:
+        """Check if any artifacts have been saved in RAM.
+
+        Returns:
+            True if any artifacts exist, False otherwise.
+        """
+        node_types = [NodeType.regex, NodeType.embedding, NodeType.scoring, NodeType.decision]
+        return any(len(self.get_artifacts(nt)) > 0 for nt in node_types)
+
 
 class Trial(BaseModel):
     """Representation of an individual optimization trial.
```

autointent/context/optimization_info/_optimization_info.py

Lines changed: 31 additions & 3 deletions

```diff
@@ -4,13 +4,16 @@
 trials, and modules during the pipeline's execution.
 """
 
+import json
 import logging
 from dataclasses import dataclass, field
+from pathlib import Path
 from typing import TYPE_CHECKING, Any
 
 import numpy as np
 from numpy.typing import NDArray
 
+from autointent._dump_tools import Dumper
 from autointent.configs import EmbedderConfig, InferenceNodeConfig
 from autointent.custom_types import NodeType
 
@@ -20,6 +23,9 @@
     from autointent.modules.base import BaseModule
 
 
+logger = logging.getLogger(__name__)
+
+
 @dataclass
 class ModulesList:
     """Container for managing lists of modules for each node type.
@@ -56,6 +62,19 @@ def add_module(self, node_type: str, module: "BaseModule") -> None:
         """
         self.get(node_type).append(module)
 
+    def model_dump(self) -> dict[str, list["BaseModule"]]:
+        """Dump the modules to a dictionary format.
+
+        Returns:
+            Dictionary representation of the modules.
+        """
+        return {
+            "regex": self.regex,
+            "embedding": self.embedding,
+            "scoring": self.scoring,
+            "decision": self.decision,
+        }
+
 
 class OptimizationInfo:
     """Tracks optimization results, including trials, artifacts, and modules.
@@ -73,8 +92,6 @@ class OptimizationInfo:
 
     def __init__(self) -> None:
        """Initialize optimization info."""
-        self._logger = logging.getLogger(__name__)
-
         self.artifacts = Artifacts()
         self.trials = Trials()
         self._trials_best_ids = TrialsIds()
@@ -115,7 +132,7 @@ def log_module_optimization(
             metrics=metrics,
         )
         self.trials.add_trial(node_type, trial)
-        self._logger.debug("module %s fitted and saved to optimization info", module_name, extra=trial.model_dump())
+        logger.debug("module %s fitted and saved to optimization info %s", module_name, json.dumps(trial.model_dump()))
 
         if module:
             self.modules.add_module(node_type, module)
@@ -225,8 +242,19 @@ def dump_evaluation_results(self) -> dict[str, Any]:
             "pipeline_metrics": self.pipeline_metrics,
             "metrics": node_wise_metrics,
             "configs": self.trials.model_dump(),
+            "artifacts": self.artifacts.model_dump(),
+            "modules": self.modules.model_dump(),
         }
 
+    def dump(self, path: Path) -> None:
+        """Dump the optimization information to a file."""
+        exclude = [ModulesList]
+        Dumper.dump(self, path / "optimization_info", exists_ok=True, exclude=exclude)
+
+    def load(self, path: Path) -> None:
+        """Load the optimization information from a file."""
+        Dumper.load(self, path / "optimization_info")
+
     def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNodeConfig]:
         """Generate configuration for inference nodes based on the best trials.
```

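Note how `OptimizationInfo.dump` above uses the new `exclude` parameter to keep the fitted modules (`ModulesList`) out of the serialized state, since modules are persisted through a separate mechanism. A minimal standalone illustration of that filter, with hypothetical stand-in classes:

```python
from typing import Any


class Modules:
    """Hypothetical stand-in for an attribute excluded from dumping, like ModulesList."""


def dump_attrs(obj: Any, exclude: list[type[Any]] | None = None) -> dict[str, Any]:
    dumped = {}
    for key, val in vars(obj).items():
        if exclude and isinstance(val, tuple(exclude)):
            continue  # same isinstance check the new Dumper.dump performs
        dumped[key] = val
    return dumped


class Holder:
    def __init__(self) -> None:
        self.name = "run-1"
        self.modules = Modules()


assert dump_attrs(Holder(), exclude=[Modules]) == {"name": "run-1"}
```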