
Commit d4249aa

Fix/context not dumped error (#197)
* try to fix
* dump context constantly and fix serialization issues
* add exclude option to dumper
* fix codestyle and typing errors
* try to fix file exists error
* fix no fixture found error
1 parent b11f845 commit d4249aa

File tree: 6 files changed, +129 −67 lines

autointent/_dump_tools.py

Lines changed: 44 additions & 45 deletions
@@ -1,9 +1,8 @@
-import inspect
+import importlib
 import json
 import logging
 from pathlib import Path
-from types import UnionType
-from typing import Any, TypeAlias, Union, get_args, get_origin
+from typing import Any, TypeAlias

 import joblib
 import numpy as np
@@ -13,7 +12,6 @@

 from autointent import Embedder, Ranker, VectorIndex
 from autointent.configs import CrossEncoderConfig, EmbedderConfig
-from autointent.context._utils import NumpyEncoder
 from autointent.schemas import TagsList

 ModuleSimpleAttributes = None | str | int | float | bool | list  # type: ignore[type-arg]
@@ -36,11 +34,12 @@ class Dumper:
     pydantic_models: str = "pydantic"

     @staticmethod
-    def make_subdirectories(path: Path) -> None:
+    def make_subdirectories(path: Path, exists_ok: bool = False) -> None:
         """Make subdirectories for dumping.

         Args:
             path: Path to make subdirectories in
+            exists_ok: If True, do not raise an error if the directory already exists
         """
         subdirectories = [
             path / Dumper.tags,
@@ -51,23 +50,27 @@ def make_subdirectories(path: Path) -> None:
             path / Dumper.pydantic_models,
         ]
         for subdir in subdirectories:
-            subdir.mkdir(parents=True, exist_ok=True)
+            subdir.mkdir(parents=True, exist_ok=exists_ok)

     @staticmethod
-    def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
+    def dump(obj: Any, path: Path, exists_ok: bool = False, exclude: list[type[Any]] | None = None) -> None:  # noqa: ANN401, C901
         """Dump modules attributes to filestystem.

         Args:
             obj: Object to dump
             path: Path to dump to
+            exists_ok: If True, do not raise an error if the directory already exists
+            exclude: List of types to exclude from dumping
         """
         attrs: dict[str, ModuleAttributes] = vars(obj)
         simple_attrs = {}
         arrays: dict[str, npt.NDArray[Any]] = {}

-        Dumper.make_subdirectories(path)
+        Dumper.make_subdirectories(path, exists_ok)

         for key, val in attrs.items():
+            if exclude and isinstance(val, tuple(exclude)):
+                continue
             if isinstance(val, TagsList):
                 val.dump(path / Dumper.tags / key)
             elif isinstance(val, ModuleSimpleAttributes):
@@ -84,9 +87,13 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
                 val.save(str(path / Dumper.cross_encoders / key))
             elif isinstance(val, BaseModel):
                 try:
-                    pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
-                    with pydantic_path.open("w", encoding="utf-8") as file:
-                        json.dump(val.model_dump(), file, ensure_ascii=False, indent=4, cls=NumpyEncoder)
+                    class_info = {"name": val.__class__.__name__, "module": val.__class__.__module__}
+                    pydantic_path = path / Dumper.pydantic_models / key
+                    pydantic_path.mkdir(parents=True, exist_ok=exists_ok)
+                    with (pydantic_path / "class_info.json").open("w", encoding="utf-8") as file:
+                        json.dump(class_info, file, ensure_ascii=False, indent=4)
+                    with (pydantic_path / "model_dump.json").open("w", encoding="utf-8") as file:
+                        json.dump(val.model_dump(), file, ensure_ascii=False, indent=4)
                 except Exception as e:
                     msg = f"Error dumping pydantic model {key}: {e}"
                     logging.exception(msg)
@@ -100,7 +107,7 @@ def dump(obj: Any, path: Path) -> None:  # noqa: ANN401, C901
         np.savez(path / Dumper.arrays, allow_pickle=False, **arrays)

     @staticmethod
-    def load(  # noqa: PLR0912, C901, PLR0915
+    def load(  # noqa: C901, PLR0912, PLR0915
         obj: Any,  # noqa: ANN401
         path: Path,
         embedder_config: EmbedderConfig | None = None,
@@ -139,42 +146,34 @@ def load(  # noqa: PLR0912, C901, PLR0915
                     for cross_encoder_dump in child.iterdir()
                 }
             elif child.name == Dumper.pydantic_models:
-                for model_file in child.iterdir():
-                    with model_file.open("r", encoding="utf-8") as file:
-                        content = json.load(file)
-                    variable_name = model_file.stem
-
-                    # First try to get the type annotation from the class annotations.
-                    model_type = obj.__class__.__annotations__.get(variable_name)
-
-                    # Fallback: inspect __init__ signature if not found in class-level annotations.
-                    if model_type is None:
-                        sig = inspect.signature(obj.__init__)
-                        if variable_name in sig.parameters:
-                            model_type = sig.parameters[variable_name].annotation
-
-                    if model_type is None:
-                        msg = f"No type annotation found for {variable_name}"
-                        logger.error(msg)
-                        continue
-
-                    # If the annotation is a Union, extract the pydantic model type.
-                    if get_origin(model_type) in (UnionType, Union):
-                        for arg in get_args(model_type):
-                            if isinstance(arg, type) and issubclass(arg, BaseModel):
-                                model_type = arg
-                                break
-                        else:
-                            msg = f"No pydantic type found in Union for {variable_name}"
-                            logger.error(msg)
+                for model_dir in child.iterdir():
+                    try:
+                        with (model_dir / "model_dump.json").open("r", encoding="utf-8") as file:
+                            content = json.load(file)
+
+                        variable_name = model_dir.name
+
+                        with (model_dir / "class_info.json").open("r", encoding="utf-8") as file:
+                            class_info = json.load(file)
+
+                        try:
+                            model_type = importlib.import_module(class_info["module"])
+                            model_type = getattr(model_type, class_info["name"])
+                        except (ImportError, AttributeError) as e:
+                            msg = f"Failed to import model type for {variable_name}: {e}"
+                            logger.exception(msg)
                             continue

-                    if not (isinstance(model_type, type) and issubclass(model_type, BaseModel)):
-                        msg = f"Type for {variable_name} is not a pydantic model: {model_type}"
-                        logger.error(msg)
+                        try:
+                            pydantic_models[variable_name] = model_type.model_validate(content)
+                        except Exception as e:
+                            msg = f"Failed to reconstruct Pydantic model {variable_name}: {e}"
+                            logger.exception(msg)
+                            continue
+                    except Exception as e:
+                        msg = f"Error loading Pydantic model from {model_dir}: {e}"
+                        logger.exception(msg)
                         continue
-
-                    pydantic_models[variable_name] = model_type(**content)
             else:
                 msg = f"Found unexpected child {child}"
                 logger.error(msg)
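
The change above replaces annotation-based reconstruction with a small per-model directory: class_info.json records which class to import and model_dump.json records its payload. Below is a minimal standalone sketch of that round-trip; the DummyConfig model and the dump_dir path are illustrative assumptions, not part of the library.

# Sketch of the class_info.json / model_dump.json layout (DummyConfig and dump_dir are hypothetical).
import importlib
import json
from pathlib import Path

from pydantic import BaseModel


class DummyConfig(BaseModel):
    """Hypothetical stand-in for a config model such as EmbedderConfig."""

    model_name: str
    batch_size: int = 32


def dump_model(val: BaseModel, path: Path) -> None:
    path.mkdir(parents=True, exist_ok=True)
    class_info = {"name": val.__class__.__name__, "module": val.__class__.__module__}
    (path / "class_info.json").write_text(json.dumps(class_info, indent=4), encoding="utf-8")
    (path / "model_dump.json").write_text(json.dumps(val.model_dump(), indent=4), encoding="utf-8")


def load_model(path: Path) -> BaseModel:
    class_info = json.loads((path / "class_info.json").read_text(encoding="utf-8"))
    content = json.loads((path / "model_dump.json").read_text(encoding="utf-8"))
    model_type = getattr(importlib.import_module(class_info["module"]), class_info["name"])
    return model_type.model_validate(content)


dump_dir = Path("pydantic") / "embedder_config"  # illustrative path
dump_model(DummyConfig(model_name="some-embedder"), dump_dir)
restored = load_model(dump_dir)
assert isinstance(restored, DummyConfig)

Storing the import path alongside the payload removes the need to inspect type annotations or Union members when loading.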

autointent/context/_context.py

Lines changed: 7 additions & 2 deletions
@@ -88,14 +88,19 @@ def dump(self) -> None:
             yaml.dump(inference_config, file)

     def load(self) -> None:
-        """Load all information about optimization process from disk."""
+        """Restore the context state to resume the optimization process.
+
+        Raises:
+            RuntimeError: If the modules artifacts are not found.
+        """
         self._logger.debug("loading logs...")
         logs_dir = self.logging_config.dirpath
         self.optimization_info.load(logs_dir)
         if not self.optimization_info.artifacts.has_artifacts():
             msg = (
                 "It is impossible to continue from the previous point, "
-                "start again with dump_modules=True settings if you want to resume the run"
+                "start again with dump_modules=True settings if you want to resume the run."
+                "To load optimization info only, use Context.optimization_info.load(logs_dir)."
             )
             raise RuntimeError(msg)

autointent/context/optimization_info/_data_models.py

Lines changed: 51 additions & 0 deletions
@@ -52,6 +52,32 @@ class ScorerArtifact(Artifact):
         None, description="Scores for each fold from cross-validation"
     )

+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
+        """Convert the model to a dictionary, converting numpy arrays to lists."""
+        data = super().model_dump(**kwargs)
+        if data["train_scores"] is not None:
+            data["train_scores"] = data["train_scores"].tolist()
+        if data["validation_scores"] is not None:
+            data["validation_scores"] = data["validation_scores"].tolist()
+        if data["test_scores"] is not None:
+            data["test_scores"] = data["test_scores"].tolist()
+        if data["folded_scores"] is not None:
+            data["folded_scores"] = [arr.tolist() for arr in data["folded_scores"]]
+        return data
+
+    @classmethod
+    def model_validate(cls, obj: dict[str, Any]) -> "ScorerArtifact":
+        """Convert lists back to numpy arrays during validation."""
+        if obj.get("train_scores") is not None:
+            obj["train_scores"] = np.array(obj["train_scores"])
+        if obj.get("validation_scores") is not None:
+            obj["validation_scores"] = np.array(obj["validation_scores"])
+        if obj.get("test_scores") is not None:
+            obj["test_scores"] = np.array(obj["test_scores"])
+        if obj.get("folded_scores") is not None:
+            obj["folded_scores"] = [np.array(arr) for arr in obj["folded_scores"]]
+        return super().model_validate(obj)
+

 class DecisionArtifact(Artifact):
     """Artifact containing outputs from the predictor node.
@@ -104,6 +130,31 @@ class Artifacts(BaseModel):
     scoring: list[ScorerArtifact] = []
     decision: list[DecisionArtifact] = []

+    def model_dump(self, **kwargs: Any) -> dict[str, Any]:  # noqa: ANN401
+        """Convert the model to a dictionary, ensuring nested artifacts are properly serialized."""
+        data = super().model_dump(**kwargs)
+        for node_type in [NodeType.regex, NodeType.embedding, NodeType.scoring, NodeType.decision]:
+            artifacts = getattr(self, node_type.value)
+            data[node_type.value] = [artifact.model_dump(**kwargs) for artifact in artifacts]
+        return data
+
+    @classmethod
+    def model_validate(cls, obj: dict[str, Any]) -> "Artifacts":
+        """Convert the dictionary back to an Artifacts instance, ensuring nested artifacts are properly deserialized."""
+        # First convert the lists back to numpy arrays in the scoring artifacts
+        if "scoring" in obj:
+            for artifact in obj["scoring"]:
+                if artifact.get("train_scores") is not None:
+                    artifact["train_scores"] = np.array(artifact["train_scores"])
+                if artifact.get("validation_scores") is not None:
+                    artifact["validation_scores"] = np.array(artifact["validation_scores"])
+                if artifact.get("test_scores") is not None:
+                    artifact["test_scores"] = np.array(artifact["test_scores"])
+                if artifact.get("folded_scores") is not None:
+                    artifact["folded_scores"] = [np.array(arr) for arr in artifact["folded_scores"]]
+
+        return super().model_validate(obj)
+
     def add_artifact(self, node_type: str, artifact: Artifact) -> None:
         """Add an artifact to the specified node type.
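
The same serialization idea in isolation: overriding model_dump/model_validate so numpy fields become plain lists in JSON and come back as arrays. This is a minimal sketch with a hypothetical ExampleArtifact, not the library's class.

# Standalone sketch of the numpy <-> list round-trip (ExampleArtifact is hypothetical).
import json
from typing import Any

import numpy as np
from pydantic import BaseModel, ConfigDict


class ExampleArtifact(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    train_scores: np.ndarray | None = None

    def model_dump(self, **kwargs: Any) -> dict[str, Any]:
        data = super().model_dump(**kwargs)
        if data["train_scores"] is not None:
            data["train_scores"] = data["train_scores"].tolist()  # array -> JSON-safe list
        return data

    @classmethod
    def model_validate(cls, obj: dict[str, Any]) -> "ExampleArtifact":
        if obj.get("train_scores") is not None:
            obj["train_scores"] = np.array(obj["train_scores"])  # list -> array
        return super().model_validate(obj)


artifact = ExampleArtifact(train_scores=np.array([0.1, 0.9]))
payload = json.dumps(artifact.model_dump())  # no custom JSON encoder needed
restored = ExampleArtifact.model_validate(json.loads(payload))
assert isinstance(restored.train_scores, np.ndarray)

Unlike a dump-only custom encoder, the paired override also restores arrays on load, which is what the model_validate additions above do.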

autointent/context/optimization_info/_optimization_info.py

Lines changed: 7 additions & 4 deletions
@@ -4,6 +4,7 @@
 trials, and modules during the pipeline's execution.
 """

+import json
 import logging
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -22,6 +23,9 @@
 from autointent.modules.base import BaseModule


+logger = logging.getLogger(__name__)
+
+
 @dataclass
 class ModulesList:
     """Container for managing lists of modules for each node type.
@@ -88,8 +92,6 @@ class OptimizationInfo:

     def __init__(self) -> None:
         """Initialize optimization info."""
-        self._logger = logging.getLogger(__name__)
-
         self.artifacts = Artifacts()
         self.trials = Trials()
         self._trials_best_ids = TrialsIds()
@@ -130,7 +132,7 @@ def log_module_optimization(
             metrics=metrics,
         )
         self.trials.add_trial(node_type, trial)
-        self._logger.debug("module %s fitted and saved to optimization info", module_name, extra=trial.model_dump())
+        logger.debug("module %s fitted and saved to optimization info %s", module_name, json.dumps(trial.model_dump()))

         if module:
             self.modules.add_module(node_type, module)
@@ -246,7 +248,8 @@ def dump_evaluation_results(self) -> dict[str, Any]:

     def dump(self, path: Path) -> None:
         """Dump the optimization information to a file."""
-        Dumper.dump(self, path / "optimization_info")
+        exclude = [ModulesList]
+        Dumper.dump(self, path / "optimization_info", exists_ok=True, exclude=exclude)

     def load(self, path: Path) -> None:
         """Load the optimization information from a file."""

autointent/nodes/_node_optimizer.py

Lines changed: 11 additions & 10 deletions
@@ -2,6 +2,7 @@

 import gc
 import itertools as it
+import json
 import logging
 from abc import ABC, abstractmethod
 from copy import deepcopy
@@ -140,7 +141,7 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)
         Raises:
             AssertionError: If an invalid sampler type is provided.
         """
-        self._logger.info("Starting %s node optimization...", self.node_info.node_type)
+        self._logger.info("Starting %s node optimization...", self.node_info.node_type.value)
         for search_space in deepcopy(self.modules_search_spaces):
             self._counter: int = 0
             module_name = search_space.pop("module_name")
@@ -163,21 +164,18 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)

             study, finished_trials, n_trials = load_or_create_study(
                 study_name=f"{self.node_info.node_type}_{module_name}",
-                storage_dir=context.get_dump_dir(),
+                context=context,
                 direction="maximize",
                 sampler=sampler_instance,
                 n_trials=n_trials,
             )
             self._counter = max(self._counter, finished_trials)

-            if n_trials == 0:
-                context.load()
-
             optuna.logging.set_verbosity(optuna.logging.WARNING)
             obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)

             study.optimize(obj, n_trials=n_trials, n_jobs=n_jobs)
-            context.dump()
+
         self._logger.info("%s node optimization is finished!", self.node_info.node_type)

     def objective(
@@ -200,7 +198,7 @@ def objective(
         """
         config = self.suggest(trial, search_space)

-        self._logger.debug("Initializing %s module...", module_name)
+        self._logger.debug("Initializing %s module with config: %s", module_name, json.dumps(config))
         module = self.node_info.modules_available[module_name].from_context(context, **config)

         embedder_config = module.get_embedder_config()
@@ -235,6 +233,7 @@ def objective(
             module_dump_dir,
             module=module if not context.is_ram_to_clear() else None,
         )
+        context.dump()

         if context.is_ram_to_clear():
             module.clear_cache()
@@ -416,7 +415,7 @@ def get_storage_url(study_name: str, storage_dir: Path | None) -> str | None:

 def load_or_create_study(
     study_name: str,
-    storage_dir: Path | None,
+    context: Context,
     sampler: optuna.samplers.BaseSampler,
     direction: str = "maximize",
     n_trials: int = 10,
@@ -425,7 +424,7 @@

     Args:
         study_name: Name of the study
-        storage_dir: Directory where study databases are stored
+        context: Context object
         direction: Optimization direction (maximize or minimize)
         sampler: Optuna sampler instance
         n_trials: n_trials
@@ -436,7 +435,7 @@
     remaining_trials = n_trials
     finished_trials = 0

-    storage_url = get_storage_url(study_name, storage_dir)
+    storage_url = get_storage_url(study_name, context.get_dump_dir())

     try:
         # will catch exception if study does not exist
@@ -451,6 +450,8 @@
         finished_trials = max(t.number for t in study.trials) + 1
         # Calculate remaining trials if n_trials is specified
         remaining_trials = n_trials if n_trials is None else max(0, n_trials - len(study.trials))
+
+        context.load()
         return study, finished_trials, remaining_trials  # noqa: TRY300
     except Exception:  # noqa: BLE001
         # Create a new study if none exists
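
To make the resume flow concrete, here is a minimal optuna-only sketch of the load-or-create pattern, leaving out the Context wiring; the study name and the SQLite storage URL are illustrative assumptions.

# Minimal load-or-create sketch with plain optuna; names and storage URL are illustrative.
import optuna


def load_or_create(study_name: str, storage_url: str, n_trials: int) -> tuple[optuna.Study, int, int]:
    sampler = optuna.samplers.TPESampler(seed=0)
    try:
        # optuna.load_study raises KeyError when the study is not in the storage yet.
        study = optuna.load_study(study_name=study_name, storage=storage_url, sampler=sampler)
        finished = max((t.number for t in study.trials), default=-1) + 1
        remaining = max(0, n_trials - len(study.trials))
        # The library restores its optimization state here (context.load() in the diff above).
    except KeyError:
        study = optuna.create_study(
            study_name=study_name, storage=storage_url, direction="maximize", sampler=sampler
        )
        finished, remaining = 0, n_trials
    return study, finished, remaining


study, finished, remaining = load_or_create("scoring_linear", "sqlite:///optuna.db", n_trials=10)
study.optimize(lambda trial: trial.suggest_float("x", 0.0, 1.0), n_trials=remaining)

Per the diff, the context is now also dumped after each objective call rather than once per search space, so an interrupted run leaves a usable checkpoint behind.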
