Skip to content

Commit 6cb6fe5

Browse files
authored
Refactor/testing logic (#140)
* disable testing if test data is not provided
* change message
1 parent 97b07aa commit 6cb6fe5

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

autointent/_pipeline/_pipeline.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,13 @@ def fit(
141141
context.configure_vector_index(self.vector_index_config)
142142

143143
self.validate_modules(dataset)
144+
145+
test_utterances = context.data_handler.test_utterances()
146+
if test_utterances is None:
147+
self._logger.warning(
148+
"Test data is not provided. Final test metrics won't be calculated after pipeline optimization."
149+
)
150+
144151
self._fit(context, sampler)
145152

146153
if context.is_ram_to_clear():
@@ -153,15 +160,17 @@ def fit(
153160
self.nodes = {node.node_type: node for node in nodes_list}
154161

155162
if refit_after:
163+
# TODO reflect this refitting in dumped version of pipeline
156164
self._refit(context)
157165

158-
predictions = self.predict(context.data_handler.test_utterances())
159-
for metric_name, metric in DECISION_METRICS.items():
160-
context.optimization_info.pipeline_metrics[metric_name] = metric(
161-
context.data_handler.test_labels(),
162-
predictions,
163-
)
164-
context.callback_handler.log_final_metrics(context.optimization_info.pipeline_metrics)
166+
if test_utterances is not None:
167+
predictions = self.predict(test_utterances)
168+
for metric_name, metric in DECISION_METRICS.items():
169+
context.optimization_info.pipeline_metrics[metric_name] = metric(
170+
context.data_handler.test_labels(),
171+
predictions,
172+
)
173+
context.callback_handler.log_final_metrics(context.optimization_info.pipeline_metrics)
165174

166175
return context
167176

autointent/context/data_handler/_data_handler.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def validation_labels(self, idx: int | None = None) -> ListOfGenericLabels:
150150
split = self._choose_split(Split.VALIDATION, idx)
151151
return cast(ListOfGenericLabels, self.dataset[split][self.dataset.label_feature])
152152

153-
def test_utterances(self) -> list[str]:
153+
def test_utterances(self) -> list[str] | None:
154154
"""
155155
Retrieve test utterances from the dataset.
156156
@@ -161,6 +161,8 @@ def test_utterances(self) -> list[str]:
161161
:param idx: Optional index for a specific test split.
162162
:return: List of test utterances, or None if the dataset has no test split.
163163
"""
164+
if Split.TEST not in self.dataset:
165+
return None
164166
return cast(list[str], self.dataset[Split.TEST][self.dataset.utterance_feature])
165167

166168
def test_labels(self) -> ListOfGenericLabels:

0 commit comments

Comments
 (0)