Skip to content

Commit 96822fd

Browse files
committed
try dumping
1 parent 0bdd785 commit 96822fd

File tree

5 files changed

+51
-14
lines changed

5 files changed

+51
-14
lines changed

autointent/_dump_tools.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ def dump(obj: Any, path: Path) -> None: # noqa: ANN401, C901
8181
joblib.dump(val, path / Dumper.estimators / key)
8282
elif isinstance(val, Ranker):
8383
val.save(str(path / Dumper.cross_encoders / key))
84-
elif isinstance(val, CrossEncoderConfig | EmbedderConfig):
84+
elif isinstance(val, BaseModel):
8585
try:
8686
pydantic_path = path / Dumper.pydantic_models / f"{key}.json"
8787
with pydantic_path.open("w", encoding="utf-8") as file:

autointent/context/_context.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,15 +77,9 @@ def dump(self) -> None:
7777
Save metrics, hyperparameters, inference, configurations, and datasets to disk.
7878
"""
7979
self._logger.debug("dumping logs...")
80-
optimization_results = self.optimization_info.dump_evaluation_results()
81-
8280
logs_dir = self.logging_config.dirpath
83-
logs_dir.mkdir(parents=True, exist_ok=True)
84-
85-
logs_path = logs_dir / "logs.json"
86-
with logs_path.open("w") as file:
87-
json.dump(optimization_results, file, indent=4, ensure_ascii=False, cls=NumpyEncoder)
8881

82+
self.optimization_info.dump(logs_dir)
8983
self.data_handler.dataset.to_json(logs_dir / "dataset.json")
9084

9185
self._logger.info("logs and other assets are saved to %s", logs_dir)
@@ -95,6 +89,18 @@ def dump(self) -> None:
9589
with inference_config_path.open("w") as file:
9690
yaml.dump(inference_config, file)
9791

92+
def load(self) -> None:
93+
"""Load all information about optimization process from disk."""
94+
self._logger.debug("loading logs...")
95+
logs_dir = self.logging_config.dirpath
96+
self.optimization_info.load(logs_dir)
97+
if not self.optimization_info.artifacts.has_artifacts():
98+
msg = (
99+
"It is impossible to continue from the previous point, "
100+
"start again with dump_modules=True settings if you want to resume the run"
101+
)
102+
raise RuntimeError(msg)
103+
98104
def get_dump_dir(self) -> Path | None:
99105
"""Get the directory for saving dumped modules.
100106

autointent/context/optimization_info/_data_models.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,15 @@ def get_best_artifact(self, node_type: str, idx: int) -> Artifact:
136136
"""
137137
return self.get_artifacts(node_type)[idx]
138138

139+
def has_artifacts(self) -> bool:
140+
"""Check if any artifacts have been saved in RAM.
141+
142+
Returns:
143+
True if any artifacts exist, False otherwise.
144+
"""
145+
node_types = [NodeType.regex, NodeType.embedding, NodeType.scoring, NodeType.decision]
146+
return any(len(self.get_artifacts(nt)) > 0 for nt in node_types)
147+
139148

140149
class Trial(BaseModel):
141150
"""Representation of an individual optimization trial.

autointent/context/optimization_info/_optimization_info.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010

1111
import numpy as np
1212
from numpy.typing import NDArray
13+
from pathlib import Path
1314

1415
from autointent.configs import EmbedderConfig, InferenceNodeConfig
1516
from autointent.custom_types import NodeType
1617

1718
from ._data_models import Artifact, Artifacts, EmbeddingArtifact, ScorerArtifact, Trial, Trials, TrialsIds
19+
from ..._dump_tools import Dumper
1820

1921
if TYPE_CHECKING:
2022
from autointent.modules.base import BaseModule
@@ -56,6 +58,19 @@ def add_module(self, node_type: str, module: "BaseModule") -> None:
5658
"""
5759
self.get(node_type).append(module)
5860

61+
def model_dump(self) -> dict[str, list["BaseModule"]]:
62+
"""Dump the modules to a dictionary format.
63+
64+
Returns:
65+
Dictionary representation of the modules.
66+
"""
67+
return {
68+
"regex": self.regex,
69+
"embedding": self.embedding,
70+
"scoring": self.scoring,
71+
"decision": self.decision,
72+
}
73+
5974

6075
class OptimizationInfo:
6176
"""Tracks optimization results, including trials, artifacts, and modules.
@@ -225,8 +240,18 @@ def dump_evaluation_results(self) -> dict[str, Any]:
225240
"pipeline_metrics": self.pipeline_metrics,
226241
"metrics": node_wise_metrics,
227242
"configs": self.trials.model_dump(),
243+
"artifacts": self.artifacts.model_dump(),
244+
"modules": self.modules.model_dump(),
228245
}
229246

247+
def dump(self, path: Path) -> None:
248+
"""Dump the optimization information to a file."""
249+
Dumper.dump(self, path / "optimization_info")
250+
251+
def load(self, path: Path) -> None:
252+
"""Load the optimization information from a file."""
253+
Dumper.load(self, path / "optimization_info")
254+
230255
def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNodeConfig]:
231256
"""Generate configuration for inference nodes based on the best trials.
232257

autointent/nodes/_node_optimizer.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,14 @@ def fit(self, context: Context, sampler: SamplerType = "brute", n_jobs: int = 1)
113113
)
114114
self._counter = max(self._counter, finished_trials)
115115

116-
# if n_trials == 0:
117-
# config = self.suggest(Trial(), search_space)
118-
#
119-
# module = self.node_info.modules_available[module_name]
120-
# module.load()
116+
if n_trials == 0:
117+
context.load()
121118

122119
optuna.logging.set_verbosity(optuna.logging.WARNING)
123120
obj = partial(self.objective, module_name=module_name, search_space=search_space, context=context)
124121

125122
study.optimize(obj, n_trials=n_trials, n_jobs=n_jobs)
126-
123+
context.dump()
127124
self._logger.info("%s node optimization is finished!", self.node_info.node_type)
128125

129126
def objective(

0 commit comments

Comments
 (0)