Skip to content

Commit bfb9f94

Browse files
authored
Refactor/artifacts (#199)
* get rid of trials_ids storing
* bug fix
* fix typing errors
* adjust context loading to minimize errors
* optimize memory occupied by modules' dumps
* bug fix
* add proper modules dumping
* bug fix
* remove nodes tests
1 parent 8580f19 commit bfb9f94

File tree

11 files changed

+159
-643
lines changed

11 files changed

+159
-643
lines changed

.github/workflows/test-nodes.yaml

Lines changed: 0 additions & 13 deletions
This file was deleted.

autointent/context/_context.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ def set_dataset(self, dataset: Dataset, config: DataConfig) -> None:
6969
"""
7070
self.data_handler = DataHandler(dataset=dataset, random_seed=self.seed, config=config)
7171

72+
def dump_optimization_info(self) -> None:
73+
"""Save optimization info to disk."""
74+
self.optimization_info.dump(self.logging_config.dirpath)
75+
7276
def dump(self) -> None:
7377
"""Save all information about optimization process to disk.
7478
@@ -77,7 +81,7 @@ def dump(self) -> None:
7781
self._logger.debug("dumping logs...")
7882
logs_dir = self.logging_config.dirpath
7983

80-
self.optimization_info.dump(logs_dir)
84+
self.dump_optimization_info()
8185
self.data_handler.dataset.to_json(logs_dir / "dataset.json")
8286

8387
self._logger.info("logs and other assets are saved to %s", logs_dir)
@@ -87,7 +91,7 @@ def dump(self) -> None:
8791
with inference_config_path.open("w") as file:
8892
yaml.dump(inference_config, file)
8993

90-
def load(self) -> None:
94+
def load_optimization_info(self) -> None:
9195
"""Restore the context state to resume the optimization process.
9296
9397
Raises:
@@ -124,7 +128,7 @@ def is_ram_to_clear(self) -> bool:
124128
def has_saved_modules(self) -> bool:
125129
"""Check if any modules have been saved in RAM."""
126130
node_types = ["regex", "embedding", "scoring", "decision"]
127-
return any(len(self.optimization_info.modules.get(nt)) > 0 for nt in node_types)
131+
return any(self.optimization_info.modules.get(nt) is not None for nt in node_types)
128132

129133
def resolve_embedder(self) -> EmbedderConfig:
130134
"""Resolve the embedder configuration.

autointent/context/optimization_info/_data_models.py

Lines changed: 41 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -114,78 +114,87 @@ def validate_node_name(value: str) -> str:
114114
class Artifacts(BaseModel):
115115
"""Container for storing and managing artifacts generated by pipeline nodes.
116116
117-
Modules hyperparams and outputs. The best ones are transmitted between nodes of the pipeline.
117+
Only stores the best artifact for each node type to optimize memory usage.
118+
The best ones are transmitted between nodes of the pipeline.
118119
119120
Attributes:
120-
regex: List of artifacts from the regex node.
121-
embedding: List of artifacts from the embedding node.
122-
scoring: List of artifacts from the scoring node.
123-
decision: List of artifacts from the decision node.
121+
regex: Best artifact from the regex node.
122+
embedding: Best artifact from the embedding node.
123+
scoring: Best artifact from the scoring node.
124+
decision: Best artifact from the decision node.
124125
"""
125126

126127
model_config = ConfigDict(arbitrary_types_allowed=True)
127128

128-
regex: list[RegexArtifact] = []
129-
embedding: list[EmbeddingArtifact] = []
130-
scoring: list[ScorerArtifact] = []
131-
decision: list[DecisionArtifact] = []
129+
regex: RegexArtifact | None = None
130+
embedding: EmbeddingArtifact | None = None
131+
scoring: ScorerArtifact | None = None
132+
decision: DecisionArtifact | None = None
132133

133134
def model_dump(self, **kwargs: Any) -> dict[str, Any]: # noqa: ANN401
134135
"""Convert the model to a dictionary, ensuring nested artifacts are properly serialized."""
135136
data = super().model_dump(**kwargs)
136137
for node_type in [NodeType.regex, NodeType.embedding, NodeType.scoring, NodeType.decision]:
137-
artifacts = getattr(self, node_type.value)
138-
data[node_type.value] = [artifact.model_dump(**kwargs) for artifact in artifacts]
138+
artifact = getattr(self, node_type.value)
139+
if artifact is not None:
140+
data[node_type.value] = artifact.model_dump(**kwargs)
141+
else:
142+
data[node_type.value] = None
139143
return data
140144

141145
@classmethod
142146
def model_validate(cls, obj: dict[str, Any]) -> "Artifacts":
143147
"""Convert the dictionary back to an Artifacts instance, ensuring nested artifacts are properly deserialized."""
144148
# First convert the lists back to numpy arrays in the scoring artifacts
145-
if "scoring" in obj:
146-
for artifact in obj["scoring"]:
147-
if artifact.get("train_scores") is not None:
148-
artifact["train_scores"] = np.array(artifact["train_scores"])
149-
if artifact.get("validation_scores") is not None:
150-
artifact["validation_scores"] = np.array(artifact["validation_scores"])
151-
if artifact.get("test_scores") is not None:
152-
artifact["test_scores"] = np.array(artifact["test_scores"])
153-
if artifact.get("folded_scores") is not None:
154-
artifact["folded_scores"] = [np.array(arr) for arr in artifact["folded_scores"]]
149+
if "scoring" in obj and obj["scoring"] is not None:
150+
if obj["scoring"].get("train_scores") is not None:
151+
obj["scoring"]["train_scores"] = np.array(obj["scoring"]["train_scores"])
152+
if obj["scoring"].get("validation_scores") is not None:
153+
obj["scoring"]["validation_scores"] = np.array(obj["scoring"]["validation_scores"])
154+
if obj["scoring"].get("test_scores") is not None:
155+
obj["scoring"]["test_scores"] = np.array(obj["scoring"]["test_scores"])
156+
if obj["scoring"].get("folded_scores") is not None:
157+
obj["scoring"]["folded_scores"] = [np.array(arr) for arr in obj["scoring"]["folded_scores"]]
155158

156159
return super().model_validate(obj)
157160

158161
def add_artifact(self, node_type: str, artifact: Artifact) -> None:
159-
"""Add an artifact to the specified node type.
162+
"""Add an artifact to the specified node type, replacing any existing artifact.
160163
161164
Args:
162165
node_type: Node type as a string.
163166
artifact: The artifact to add.
164167
"""
165-
self.get_artifacts(node_type).append(artifact)
168+
setattr(self, validate_node_name(node_type), artifact)
166169

167-
def get_artifacts(self, node_type: str) -> list[Artifact]:
168-
"""Retrieve all artifacts for a specified node type.
170+
def get_artifact(self, node_type: str) -> Artifact | None:
171+
"""Retrieve the artifact for a specified node type.
169172
170173
Args:
171174
node_type: Node type as a string.
172175
173176
Returns:
174-
A list of artifacts for the node type.
177+
The artifact for the node type, or None if no artifact exists.
175178
"""
176179
return getattr(self, validate_node_name(node_type)) # type: ignore[no-any-return]
177180

178-
def get_best_artifact(self, node_type: str, idx: int) -> Artifact:
179-
"""Retrieve the best artifact for a specified node type and index.
181+
def get_best_artifact(self, node_type: str) -> Artifact:
182+
"""Retrieve the artifact for a specified node type.
180183
181184
Args:
182185
node_type: Node type as a string.
183-
idx: Index of the artifact.
184186
185187
Returns:
186-
The best artifact.
188+
The artifact for the node type.
189+
190+
Raises:
191+
ValueError: If no artifact exists for the node type.
187192
"""
188-
return self.get_artifacts(node_type)[idx]
193+
artifact = self.get_artifact(node_type)
194+
if artifact is None:
195+
msg = f"No artifact for {node_type}"
196+
raise ValueError(msg)
197+
return artifact
189198

190199
def has_artifacts(self) -> bool:
191200
"""Check if any artifacts have been saved in RAM.
@@ -194,7 +203,7 @@ def has_artifacts(self) -> bool:
194203
True if any artifacts exist, False otherwise.
195204
"""
196205
node_types = [NodeType.regex, NodeType.embedding, NodeType.scoring, NodeType.decision]
197-
return any(len(self.get_artifacts(nt)) > 0 for nt in node_types)
206+
return any(self.get_artifact(nt) is not None for nt in node_types)
198207

199208

200209
class Trial(BaseModel):
@@ -263,39 +272,3 @@ def add_trial(self, node_type: str, trial: Trial) -> None:
263272
trial: The trial to add.
264273
"""
265274
self.get_trials(node_type).append(trial)
266-
267-
268-
class TrialsIds(BaseModel):
269-
"""Representation of the best trial IDs for each pipeline node.
270-
271-
Attributes:
272-
regex: Best trial index for the regex node.
273-
embedding: Best trial index for the embedding node.
274-
scoring: Best trial index for the scoring node.
275-
decision: Best trial index for the decision node.
276-
"""
277-
278-
regex: int | None = None
279-
embedding: int | None = None
280-
scoring: int | None = None
281-
decision: int | None = None
282-
283-
def get_best_trial_idx(self, node_type: str) -> int | None:
284-
"""Retrieve the best trial index for a specified node type.
285-
286-
Args:
287-
node_type: Node type as a string.
288-
289-
Returns:
290-
The index of the best trial, or None if not set.
291-
"""
292-
return getattr(self, validate_node_name(node_type)) # type: ignore[no-any-return]
293-
294-
def set_best_trial_idx(self, node_type: str, idx: int) -> None:
295-
"""Set the best trial index for a specified node type.
296-
297-
Args:
298-
node_type: Node type as a string.
299-
idx: Index of the best trial.
300-
"""
301-
setattr(self, validate_node_name(node_type), idx)

0 commit comments

Comments
 (0)