Skip to content

Commit 9a337f9

Browse files
committed
refactor: artifacts are added only for best module
1 parent 6b168c4 commit 9a337f9

File tree

10 files changed

+38
-31
lines changed

10 files changed

+38
-31
lines changed

autointent/context/optimization_info/_optimization_info.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ def log_module_optimization(
7474
module_params: dict[str, Any],
7575
metric_value: float,
7676
metric_name: str,
77-
artifact: Artifact,
7877
module_dump_dir: str | None,
7978
module: "Module | None" = None,
8079
) -> None:
@@ -103,13 +102,11 @@ def log_module_optimization(
103102
if module:
104103
self.modules.add_module(node_type, module)
105104

106-
self.artifacts.add_artifact(node_type, artifact)
107-
108105
def _get_metrics_values(self, node_type: str) -> list[float]:
109106
"""Retrieve all metric values for a specific node type."""
110107
return [trial.metric_value for trial in self.trials.get_trials(node_type)]
111108

112-
def _get_best_trial_idx(self, node_type: str) -> int | None:
109+
def get_best_trial_idx(self, node_type: str) -> int | None:
113110
"""
114111
Retrieve the index of the best trial for a node type.
115112
@@ -133,7 +130,7 @@ def _get_best_artifact(self, node_type: str) -> RetrieverArtifact | ScorerArtifa
133130
:return: The best artifact for the node type.
134131
:raises ValueError: If no best trial exists for the node type.
135132
"""
136-
best_idx = self._get_best_trial_idx(node_type)
133+
best_idx = self.get_best_trial_idx(node_type)
137134
if best_idx is None:
138135
msg = f"No best trial for {node_type}"
139136
raise ValueError(msg)
@@ -194,7 +191,7 @@ def get_inference_nodes_config(self, asdict: bool = False) -> list[InferenceNode
194191
195192
:return: List of `InferenceNodeConfig` objects for inference nodes.
196193
"""
197-
trial_ids = [self._get_best_trial_idx(node_type) for node_type in NodeType]
194+
trial_ids = [self.get_best_trial_idx(node_type) for node_type in NodeType]
198195
res = []
199196
for idx, node_type in zip(trial_ids, NodeType, strict=True):
200197
if idx is None:
@@ -216,7 +213,7 @@ def _get_best_module(self, node_type: str) -> "Module | None":
216213
:param node_type: Type of the node.
217214
:return: The best module, or None if no best trial exists.
218215
"""
219-
idx = self._get_best_trial_idx(node_type)
216+
idx = self.get_best_trial_idx(node_type)
220217
if idx is not None:
221218
return self.modules.get(node_type)[idx]
222219
return None

autointent/modules/abc/_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
4444
"""
4545

4646
@abstractmethod
47-
def get_assets(self) -> Artifact:
47+
def get_artifact(self, context: Context) -> Artifact:
4848
"""Return the artifact representing intermediate data to be stored in the context."""
4949

5050
@abstractmethod

autointent/modules/abc/_decision.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,13 +49,14 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
4949
:return: Computed metrics value for the test set or error code of metrics
5050
"""
5151
labels, scores = get_decision_evaluation_data(context, split)
52-
self._decisions = self.predict(scores)
52+
decisions = self.predict(scores)
5353
chosen_metrics = {name: fn for name, fn in PREDICTION_METRICS_MULTICLASS.items() if name in metrics}
54-
return self.score_metrics((labels, self._decisions), chosen_metrics)
54+
return self.score_metrics((labels, decisions), chosen_metrics)
5555

56-
def get_assets(self) -> DecisionArtifact:
56+
def get_artifact(self, context: Context) -> DecisionArtifact:
5757
"""Return the artifact representing intermediate data to be stored in the context."""
58-
return DecisionArtifact(labels=self._decisions)
58+
_, scores = get_decision_evaluation_data(context, split="test")
59+
return DecisionArtifact(labels=self.predict(scores))
5960

6061
def clear_cache(self) -> None:
6162
"""Clear cache."""

autointent/modules/abc/_scoring.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,24 +41,24 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
4141

4242
scores = self.predict(utterances)
4343

44-
self._train_scores = self.predict(context.data_handler.train_utterances(1))
45-
self._validation_scores = self.predict(context.data_handler.validation_utterances(1))
46-
self._test_scores = self.predict(context.data_handler.test_utterances())
47-
4844
metrics_dict = SCORING_METRICS_MULTILABEL if context.is_multilabel() else SCORING_METRICS_MULTICLASS
4945
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
5046
return self.score_metrics((labels, scores), chosen_metrics)
5147

52-
def get_assets(self) -> ScorerArtifact:
48+
def get_artifact(self, context: Context) -> ScorerArtifact:
5349
"""
5450
Retrieve assets generated during scoring.
5551
5652
:return: ScorerArtifact containing train, validation and test scores.
5753
"""
54+
train_scores = self.predict(context.data_handler.train_utterances(1))
55+
validation_scores = self.predict(context.data_handler.validation_utterances(1))
56+
test_scores = self.predict(context.data_handler.test_utterances())
57+
5858
return ScorerArtifact(
59-
train_scores=self._train_scores,
60-
validation_scores=self._validation_scores,
61-
test_scores=self._test_scores,
59+
train_scores=train_scores,
60+
validation_scores=validation_scores,
61+
test_scores=test_scores,
6262
)
6363

6464
@abstractmethod

autointent/modules/embedding/_logreg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
152152
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
153153
return self.score_metrics((labels, probas), chosen_metrics)
154154

155-
def get_assets(self) -> RetrieverArtifact:
155+
def get_artifact(self, context: Context) -> RetrieverArtifact:
156156
"""
157157
Get the classifier artifacts for this module.
158158

autointent/modules/embedding/_retrieval.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ def score(self, context: Context, split: Literal["validation", "test"], metrics:
132132
chosen_metrics = {name: fn for name, fn in metrics_dict.items() if name in metrics}
133133
return self.score_metrics((labels, predictions), chosen_metrics)
134134

135-
def get_assets(self) -> RetrieverArtifact:
135+
def get_artifact(self, context: Context) -> RetrieverArtifact:
136136
"""
137137
Get the retriever artifacts for this module.
138138

autointent/modules/regexp/_regexp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def clear_cache(self) -> None:
133133
"""Clear cache."""
134134
del self.regexp_patterns
135135

136-
def get_assets(self) -> Artifact:
136+
def get_artifact(self) -> Artifact:
137137
"""Get artifact."""
138138
return Artifact()
139139

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def fit(self, context: Context) -> None:
5151
:param context: Context
5252
"""
5353
self._logger.info("starting %s node optimization...", self.node_info.node_type)
54-
54+
scored_modules = []
5555
for search_space in deepcopy(self.modules_search_spaces):
5656
module_name = search_space.pop("module_name")
5757

@@ -62,7 +62,10 @@ def fit(self, context: Context) -> None:
6262
context.callback_handler.start_module(
6363
module_name=module_name, num=j_combination, module_kwargs=module_kwargs
6464
)
65-
module = self.node_info.modules_available[module_name].from_context(context, **module_kwargs)
65+
module_type = self.node_info.modules_available[module_name]
66+
module = module_type.from_context(context, **module_kwargs)
67+
68+
scored_modules.append((module_type, module_kwargs))
6669

6770
embedder_name = module.get_embedder_name()
6871
if embedder_name is not None:
@@ -92,7 +95,6 @@ def fit(self, context: Context) -> None:
9295
module_kwargs,
9396
metric_value,
9497
self.target_metric,
95-
module.get_assets(), # retriever name / scores / predictions
9698
module_dump_dir,
9799
module=module if not context.is_ram_to_clear() else None,
98100
)
@@ -102,7 +104,14 @@ def fit(self, context: Context) -> None:
102104
gc.collect()
103105
torch.cuda.empty_cache()
104106

105-
self._logger.info("%s node optimization is finished!", self.node_info.node_type)
107+
self._logger.info("%s node optimization is finished! saving best assets", self.node_info.node_type)
108+
# TODO refactor the following code (via implementing `autointent.load_module(path)` utility)
109+
trial_idx = context.optimization_info.get_best_trial_idx(self.node_type)
110+
trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx)
111+
module_type, module_kwargs = scored_modules[trial_idx]
112+
best_module: Module = module_type(**module_kwargs)
113+
best_module.load(trial.module_dump_dir)
114+
context.optimization_info.artifacts.add_artifact(self.node_type, best_module.get_artifact(context))
106115

107116
def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str:
108117
"""

tests/modules/embedding/test_logreg.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
from autointent.modules.embedding import LogregAimedEmbedding
22

33

4-
def test_get_assets_returns_correct_artifact_for_logreg():
4+
def test_get_artifact_returns_correct_artifact_for_logreg():
55
module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo")
6-
artifact = module.get_assets()
6+
artifact = module.get_artifact()
77
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
88

99

tests/modules/embedding/test_retrieval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from tests.conftest import setup_environment
55

66

7-
def test_get_assets_returns_correct_artifact():
7+
def test_get_artifact_returns_correct_artifact():
88
module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
9-
artifact = module.get_assets()
9+
artifact = module.get_artifact()
1010
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
1111

1212

0 commit comments

Comments
 (0)