Skip to content

Commit d26327c

Browse files
committed
bug fix
1 parent 16b1ded commit d26327c

File tree

4 files changed

+8
-6
lines changed

4 files changed

+8
-6
lines changed

autointent/modules/embedding/_retrieval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,5 +150,5 @@ def predict(self, utterances: list[str]) -> list[ListOfLabels]:
150150
Returns:
151151
List of labels for each retrieved utterance
152152
"""
153-
predictions, _ = self._vector_index.query(utterances, self.k)
154-
return predictions
153+
_, documents = self._vector_index.query(utterances, self.k)
154+
return [[n.label for n in neigs] for neigs in documents]

tests/callback/test_callback.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from autointent import Context, Pipeline
77
from autointent._callbacks import CallbackHandler, OptimizerCallback
8-
from autointent.configs import DataConfig, HPOConfig, LoggingConfig
8+
from autointent.configs import DataConfig, FaissConfig, HPOConfig, LoggingConfig
99
from tests.conftest import setup_environment
1010

1111

@@ -96,6 +96,7 @@ def test_pipeline_callbacks(dataset):
9696
context.callback_handler = CallbackHandler([DummyCallback])
9797
context.set_dataset(dataset, DataConfig(scheme="ho"))
9898
context.configure_hpo(HPOConfig(n_trials=10))
99+
context.configure_vector_index(FaissConfig())
99100

100101
pipeline_optimizer._fit(context)
101102

tests/context/test_vector_index.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from autointent import VectorIndex
4-
from autointent.configs import EmbedderConfig
4+
from autointent.configs import EmbedderConfig, FaissConfig
55

66

77
@pytest.fixture
@@ -14,7 +14,7 @@ class MockDataHandler:
1414

1515

1616
def test_create_collection(data_handler):
17-
vector_index = VectorIndex(embedder_config=EmbedderConfig(model_name="bert-base-uncased"))
17+
vector_index = VectorIndex(embedder_config=EmbedderConfig(model_name="bert-base-uncased"), config=FaissConfig())
1818
vector_index.add(
1919
data_handler.utterances_train,
2020
data_handler.labels_train,

tests/modules/test_dumper.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ def check_attributes(self):
5252
tokenizer_predictions = self.tokenizer(["hello", "world"]).input_ids
5353
np.testing.assert_array_equal(self._tokenizer_predictions, tokenizer_predictions)
5454
with torch.no_grad():
55-
np.testing.assert_array_equal(
55+
np.testing.assert_almost_equal(
5656
self._transformer_predictions,
5757
self.transformer(input_ids=torch.tensor(tokenizer_predictions)).logits.cpu().numpy(),
58+
decimal=4,
5859
)
5960

6061

0 commit comments

Comments
 (0)