Skip to content

Commit 1534056

Browse files
committed
upd inference test
1 parent 6b80fb2 commit 1534056

File tree

2 files changed

+6
-4
lines changed

2 files changed

+6
-4
lines changed

src/autointent/configs/_embedder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ class OpenaiEmbeddingConfig(EmbedderConfig):
101101

102102

103103
def get_default_embedder_config(**kwargs: Any) -> EmbedderConfig: # noqa: ANN401
104-
return SentenceTransformerEmbeddingConfig(**kwargs)
104+
return SentenceTransformerEmbeddingConfig.model_validate(kwargs)
105105

106106

107107
def initialize_embedder_config(values: dict[str, Any] | str | EmbedderConfig | None) -> EmbedderConfig:

tests/pipeline/test_inference.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pytest
44

55
from autointent import Pipeline
6-
from autointent.configs import EmbedderConfig, LoggingConfig, TokenizerConfig
6+
from autointent.configs import LoggingConfig, TokenizerConfig, get_default_embedder_config
77
from autointent.custom_types import NodeType
88
from tests.conftest import get_search_space, setup_environment
99

@@ -129,7 +129,8 @@ def test_load_with_overrided_params(dataset):
129129

130130
# case 1: simple inference from file system
131131
inference_pipeline = Pipeline.load(
132-
logging_config.dirpath, embedder_config=EmbedderConfig(tokenizer_config=TokenizerConfig(max_length=8))
132+
logging_config.dirpath,
133+
embedder_config=get_default_embedder_config(tokenizer_config=TokenizerConfig(max_length=8)),
133134
)
134135
utterances = ["123", "hello world"]
135136
prediction = inference_pipeline.predict(utterances)
@@ -146,7 +147,8 @@ def test_load_with_overrided_params(dataset):
146147
del pipeline_optimizer
147148

148149
loaded_pipe = Pipeline.load(
149-
logging_config.dirpath, embedder_config=EmbedderConfig(tokenizer_config=TokenizerConfig(max_length=8))
150+
logging_config.dirpath,
151+
embedder_config=get_default_embedder_config(tokenizer_config=TokenizerConfig(max_length=8)),
150152
)
151153
prediction_v2 = loaded_pipe.predict(utterances)
152154
assert prediction == prediction_v2

0 commit comments

Comments (0)