Skip to content

Commit 7176660

Browse files
committed
handle default embedder config usage
1 parent 1091044 commit 7176660

File tree

4 files changed

+14
-3
lines changed

4 files changed

+14
-3
lines changed

src/autointent/_pipeline/_pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
InferenceNodeConfig,
2020
LoggingConfig,
2121
VectorIndexConfig,
22+
get_default_embedder_config,
2223
get_default_vector_index_config,
2324
)
2425
from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode
@@ -56,7 +57,7 @@ def __init__(
5657

5758
if isinstance(nodes[0], NodeOptimizer):
5859
self.logging_config = LoggingConfig()
59-
self.embedder_config = EmbedderConfig()
60+
self.embedder_config = get_default_embedder_config()
6061
self.cross_encoder_config = CrossEncoderConfig()
6162
self.data_config = DataConfig()
6263
self.transformer_config = HFModelConfig()

src/autointent/_wrappers/embedder/embedder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,6 @@ def _load_model(self) -> BaseEmbeddingBackend:
5555
return OpenaiEmbeddingBackend(self.config)
5656
# Check if it's exactly the abstract base config (not a subclass)
5757
if type(self.config) is EmbedderConfig:
58-
# Handle abstract base config case
5958
msg = f"Cannot instantiate abstract EmbedderConfig: {self.config.__repr__()}"
6059
raise TypeError(msg)
6160
assert_never(self.config)

src/autointent/configs/__init__.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
"""Dataclasses for the configuration of the :class:`autointent.Embedder` and other objects."""
22

3-
from ._embedder import EmbedderConfig, OpenaiEmbeddingConfig, SentenceTransformerEmbeddingConfig, TaskTypeEnum
3+
from ._embedder import (
4+
EmbedderConfig,
5+
OpenaiEmbeddingConfig,
6+
SentenceTransformerEmbeddingConfig,
7+
TaskTypeEnum,
8+
get_default_embedder_config,
9+
)
410
from ._inference_node import InferenceNodeConfig
511
from ._optimization import DataConfig, HPOConfig, LoggingConfig
612
from ._torch import TorchTrainingConfig, VocabConfig
@@ -32,5 +38,6 @@
3238
"TorchTrainingConfig",
3339
"VectorIndexConfig",
3440
"VocabConfig",
41+
"get_default_embedder_config",
3542
"get_default_vector_index_config",
3643
]

src/autointent/configs/_embedder.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,7 @@ class OpenaiEmbeddingConfig(EmbedderConfig):
9696
max_per_second: float | None = Field(
9797
None, description="Maximum number of API requests per second. Only used with async processing."
9898
)
99+
100+
101+
def get_default_embedder_config() -> EmbedderConfig:
102+
return SentenceTransformerEmbeddingConfig()

0 commit comments

Comments
 (0)