Skip to content

Commit a602cd2

Browse files
committed
stage progress
1 parent 2924d09 commit a602cd2

File tree

4 files changed

+10
-18
lines changed

4 files changed

+10
-18
lines changed

autointent/nodes/_optimization/_node_optimizer.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -107,11 +107,15 @@ def fit(self, context: Context) -> None:
107107
self._logger.info("%s node optimization is finished! saving best assets", self.node_info.node_type)
108108
# TODO refactor the following code (via implementing `autointent.load_module(path)` utility)
109109
trial_idx = context.optimization_info.get_best_trial_idx(self.node_type)
110-
trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) # type: ignore[arg-type]
111-
module_type, module_kwargs = scored_modules[trial_idx] # type: ignore[index]
112-
best_module: Module = module_type(**module_kwargs)
113-
best_module.load(trial.module_dump_dir) # type: ignore[arg-type]
114-
context.optimization_info.artifacts.add_artifact(self.node_type, best_module.get_artifact(context))
110+
if context.is_ram_to_clear():
111+
trial = context.optimization_info.trials.get_trial(self.node_type, trial_idx) # type: ignore[arg-type]
112+
module_type, module_kwargs = scored_modules[trial_idx] # type: ignore[index]
113+
best_module: Module = module_type(**module_kwargs)
114+
best_module.load(trial.module_dump_dir) # type: ignore[arg-type]
115+
else:
116+
best_module = context.optimization_info.modules.get(self.node_type)[trial_idx]
117+
artifact = best_module.get_artifact(context)
118+
context.optimization_info.artifacts.add_artifact(self.node_type, artifact)
115119

116120
def get_module_dump_dir(self, dump_dir: Path, module_name: str, j_combination: int) -> str:
117121
"""

tests/modules/embedding/test_logreg.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,6 @@
11
from autointent.modules.embedding import LogregAimedEmbedding
22

33

4-
def test_get_artifact_returns_correct_artifact_for_logreg():
5-
module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo")
6-
artifact = module.get_artifact()
7-
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
8-
9-
104
def test_fit_trains_model():
115
module = LogregAimedEmbedding(embedder_name="sergeyzh/rubert-tiny-turbo")
126

tests/modules/embedding/test_retrieval.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,6 @@
44
from tests.conftest import setup_environment
55

66

7-
def test_get_artifact_returns_correct_artifact():
8-
module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")
9-
artifact = module.get_artifact()
10-
assert artifact.embedder_name == "sergeyzh/rubert-tiny-turbo"
11-
12-
137
def test_dump_and_load_preserves_model_state():
148
project_dir = setup_environment()
159
module = RetrievalAimedEmbedding(k=5, embedder_name="sergeyzh/rubert-tiny-turbo")

tests/nodes/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def get_context(multilabel):
8282
if multilabel:
8383
dataset = dataset.to_multilabel()
8484
res.set_dataset(dataset)
85-
res.configure_logging(LoggingConfig(project_dir=project_dir, keep_in_ram=False))
85+
res.configure_logging(LoggingConfig(project_dir=project_dir, keep_in_ram=True))
8686
res.configure_vector_index(VectorIndexConfig(), EmbedderConfig(device="cpu"))
8787
res.configure_cross_encoder(CrossEncoderConfig())
8888
return res

0 commit comments

Comments
 (0)