Skip to content

Commit c60245c

Browse files
Feat/api embeddings (#263)
* add mypy pydantic plugin settings
* implement base class and interface class
* refactor embedder config
* add sentence transformer embedding backend
* add openai embedding backend
* re-refactor embedder configs
* re-refactor dump/load
* add proper dump/load to Embedder
* handle default embedder config usage
* fix some typing errors
* fix a couple more
* fix some more typing errors
* one more error
* is it all?
* Update optimizer_config.schema.json
* bug fix
* fix some tests
* temporary way to fix tests
* refactor embedder tests
* fix some tests
* Update optimizer_config.schema.json
* try to fix dynamic schema issues
* Update optimizer_config.schema.json
* upd vector index tests
* upd inference test
* upd tutorials
* ignore ds store
* set similarity_fn default to None
* upd callback test
* remove unnecessary import
* run code formatter
* remove unnecessary import
* add openai base url option
* remove openai api key everywhere for security reasons
* ignore extra envs in mcp server
* add typed marker

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 40d2986 commit c60245c

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1887
-661
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,4 @@ vector_db*
184184
/wandb
185185
model_output/
186186
my.py
187+
.DS_store

docs/optimizer_config.schema.json

Lines changed: 2 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -123,54 +123,8 @@
123123
},
124124
"EmbedderConfig": {
125125
"additionalProperties": false,
126+
"description": "Base class for embedder configurations.",
126127
"properties": {
127-
"model_name": {
128-
"default": "sentence-transformers/all-MiniLM-L6-v2",
129-
"description": "Name of the hugging face model.",
130-
"title": "Model Name",
131-
"type": "string"
132-
},
133-
"batch_size": {
134-
"default": 32,
135-
"description": "Batch size for model inference.",
136-
"exclusiveMinimum": 0,
137-
"title": "Batch Size",
138-
"type": "integer"
139-
},
140-
"device": {
141-
"anyOf": [
142-
{
143-
"type": "string"
144-
},
145-
{
146-
"type": "null"
147-
}
148-
],
149-
"default": null,
150-
"description": "Torch notation for CPU or CUDA.",
151-
"title": "Device"
152-
},
153-
"bf16": {
154-
"default": false,
155-
"description": "Whether to use mixed precision training (not all devices support this).",
156-
"title": "Bf16",
157-
"type": "boolean"
158-
},
159-
"fp16": {
160-
"default": false,
161-
"description": "Whether to use mixed precision training (not all devices support this).",
162-
"title": "Fp16",
163-
"type": "boolean"
164-
},
165-
"tokenizer_config": {
166-
"$ref": "#/$defs/TokenizerConfig"
167-
},
168-
"trust_remote_code": {
169-
"default": false,
170-
"description": "Whether to trust the remote code when loading the model.",
171-
"title": "Trust Remote Code",
172-
"type": "boolean"
173-
},
174128
"default_prompt": {
175129
"anyOf": [
176130
{
@@ -249,18 +203,6 @@
249203
"description": "Prompt for passage.",
250204
"title": "Passage Prompt"
251205
},
252-
"similarity_fn_name": {
253-
"default": "cosine",
254-
"description": "Name of the similarity function to use.",
255-
"enum": [
256-
"cosine",
257-
"dot",
258-
"euclidean",
259-
"manhattan"
260-
],
261-
"title": "Similarity Fn Name",
262-
"type": "string"
263-
},
264206
"use_cache": {
265207
"default": true,
266208
"description": "Whether to use embeddings caching.",
@@ -552,28 +494,7 @@
552494
}
553495
},
554496
"embedder_config": {
555-
"$ref": "#/$defs/EmbedderConfig",
556-
"default": {
557-
"model_name": "sentence-transformers/all-MiniLM-L6-v2",
558-
"batch_size": 32,
559-
"device": null,
560-
"bf16": false,
561-
"fp16": false,
562-
"tokenizer_config": {
563-
"max_length": null,
564-
"padding": true,
565-
"truncation": true
566-
},
567-
"trust_remote_code": false,
568-
"default_prompt": null,
569-
"classification_prompt": null,
570-
"cluster_prompt": null,
571-
"sts_prompt": null,
572-
"query_prompt": null,
573-
"passage_prompt": null,
574-
"similarity_fn_name": "cosine",
575-
"use_cache": true
576-
}
497+
"$ref": "#/$defs/EmbedderConfig"
577498
},
578499
"cross_encoder_config": {
579500
"$ref": "#/$defs/CrossEncoderConfig",

pyproject.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,10 @@ plugins = [
240240
mypy_path = "src/autointent"
241241
disable_error_code = ["override"]
242242

243+
[tool.pydantic-mypy]
244+
init_forbid_extra = true
245+
init_typed = true
246+
243247
[[tool.mypy.overrides]]
244248
module = [
245249
"scipy",

src/autointent/_optimization_config.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
11
from typing import Any
22

3-
from pydantic import BaseModel, PositiveInt
3+
from pydantic import BaseModel, Field, PositiveInt, field_validator
44

5-
from .configs import CrossEncoderConfig, DataConfig, EmbedderConfig, HFModelConfig, HPOConfig, LoggingConfig
5+
from .configs import (
6+
CrossEncoderConfig,
7+
DataConfig,
8+
EmbedderConfig,
9+
HFModelConfig,
10+
HPOConfig,
11+
LoggingConfig,
12+
initialize_embedder_config,
13+
)
614

715

816
class OptimizationConfig(BaseModel):
@@ -20,7 +28,13 @@ class OptimizationConfig(BaseModel):
2028
logging_config: LoggingConfig = LoggingConfig()
2129
"""See tutorial on logging configuration."""
2230

23-
embedder_config: EmbedderConfig = EmbedderConfig()
31+
embedder_config: EmbedderConfig = Field(default_factory=lambda: initialize_embedder_config(None))
32+
33+
@field_validator("embedder_config", mode="before")
@classmethod
def validate_embedder_config(cls, v: Any) -> EmbedderConfig:  # noqa: ANN401
    """Normalize the raw ``embedder_config`` value before field validation.

    Registered as a ``mode="before"`` validator for the ``embedder_config``
    field and delegates the conversion to ``initialize_embedder_config``.

    Args:
        v: Value supplied for the ``embedder_config`` field.

    Returns:
        Whatever ``initialize_embedder_config`` produces for ``v``.
    """
    embedder_config = initialize_embedder_config(v)
    return embedder_config
2438

2539
cross_encoder_config: CrossEncoderConfig = CrossEncoderConfig()
2640

src/autointent/_pipeline/_pipeline.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
InferenceNodeConfig,
2020
LoggingConfig,
2121
VectorIndexConfig,
22+
get_default_embedder_config,
2223
get_default_vector_index_config,
2324
)
2425
from autointent.custom_types import ListOfGenericLabels, NodeType, SearchSpacePreset, SearchSpaceValidationMode
@@ -56,7 +57,7 @@ def __init__(
5657

5758
if isinstance(nodes[0], NodeOptimizer):
5859
self.logging_config = LoggingConfig()
59-
self.embedder_config = EmbedderConfig()
60+
self.embedder_config = get_default_embedder_config()
6061
self.cross_encoder_config = CrossEncoderConfig()
6162
self.data_config = DataConfig()
6263
self.transformer_config = HFModelConfig()
@@ -111,7 +112,7 @@ def from_search_space(cls, search_space: list[dict[str, Any]] | Path | str, seed
111112
return cls(nodes=nodes, seed=seed)
112113

113114
@classmethod
114-
def from_preset(cls, name: SearchSpacePreset, seed: int | None = 42) -> "Pipeline":
115+
def from_preset(cls, name: SearchSpacePreset, seed: int = 42) -> "Pipeline":
115116
"""Instantiate pipeline optimizer from a preset."""
116117
optimization_config = load_preset(name)
117118
config = OptimizationConfig(seed=seed, **optimization_config)
@@ -395,6 +396,19 @@ def _refit(self, context: Context) -> None:
395396
decision_module.clear_cache()
396397
decision_module.fit(scores, context.data_handler.train_labels(1), context.data_handler.tags)
397398

399+
def _convert_score_to_float_list(self, score: Any) -> list[float]: # noqa: ANN401
400+
"""Convert score to list of floats for InferencePipelineUtteranceOutput."""
401+
if hasattr(score, "tolist"):
402+
result = score.tolist()
403+
return result if isinstance(result, list) else [float(result)]
404+
if score is None:
405+
return []
406+
if isinstance(score, int | float):
407+
return [float(score)]
408+
if hasattr(score, "__iter__") and not isinstance(score, str):
409+
return [float(x) for x in score]
410+
return [float(score)]
411+
398412
def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutput:
399413
"""Predict the labels for the utterances with metadata.
400414
@@ -422,13 +436,13 @@ def predict_with_metadata(self, utterances: list[str]) -> InferencePipelineOutpu
422436
regex_prediction_metadata=regex_predictions_metadata[idx]
423437
if regex_predictions_metadata is not None
424438
else None,
425-
score=scores[idx],
439+
score=self._convert_score_to_float_list(scores[idx]),
426440
score_metadata=scores_metadata[idx] if scores_metadata is not None else None,
427441
)
428442
outputs.append(output)
429443

430444
return InferencePipelineOutput(
431-
predictions=predictions,
445+
predictions=predictions, # type: ignore[arg-type]
432446
regex_predictions=regex_predictions,
433447
utterances=outputs,
434448
)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""Embedder module with multiple backend support."""
2+
3+
from .base import BaseEmbeddingBackend
4+
from .embedder import Embedder
5+
from .openai import OpenaiEmbeddingBackend
6+
from .sentence_transformers import SentenceTransformerEmbeddingBackend
7+
8+
__all__ = [
9+
"BaseEmbeddingBackend",
10+
"Embedder",
11+
"OpenaiEmbeddingBackend",
12+
"SentenceTransformerEmbeddingBackend",
13+
]
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
from abc import ABC, abstractmethod
2+
from pathlib import Path
3+
from typing import Literal, overload
4+
5+
import numpy as np
6+
import numpy.typing as npt
7+
import torch
8+
9+
from autointent.configs import EmbedderConfig, TaskTypeEnum
10+
11+
12+
class BaseEmbeddingBackend(ABC):
    """Interface that every embedding backend has to implement.

    All methods declared here are abstract; concrete backends provide the
    actual behavior.
    """

    # Training-capability flag; False at this level (presumably overridden by
    # backends that can be trained - confirm against the concrete backends).
    supports_training: bool = False

    @abstractmethod
    def __init__(self, config: EmbedderConfig) -> None:
        """Set up the backend from an embedder configuration.

        Args:
            config: Embedder configuration for this backend.
        """

    @abstractmethod
    def clear_ram(self) -> None:
        """Remove the backend from RAM."""

    @overload
    @abstractmethod
    def embed(
        self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[True]
    ) -> torch.Tensor: ...

    @overload
    @abstractmethod
    def embed(
        self, utterances: list[str], task_type: TaskTypeEnum | None = None, *, return_tensors: Literal[False] = False
    ) -> npt.NDArray[np.float32]: ...

    @abstractmethod
    def embed(
        self,
        utterances: list[str],
        task_type: TaskTypeEnum | None = None,
        return_tensors: bool = False,
    ) -> npt.NDArray[np.float32] | torch.Tensor:
        """Compute embeddings for the given utterances.

        Args:
            utterances: Input texts to embed.
            task_type: Task the embeddings are computed for.
            return_tensors: When True the result is a PyTorch tensor,
                otherwise a numpy array.

        Returns:
            The embeddings as a numpy array or a PyTorch tensor.
        """

    @abstractmethod
    def similarity(
        self, embeddings1: npt.NDArray[np.float32], embeddings2: npt.NDArray[np.float32]
    ) -> npt.NDArray[np.float32]:
        """Compute similarities between two sets of embeddings.

        Args:
            embeddings1: First set of embeddings (size n).
            embeddings2: Second set of embeddings (size m).

        Returns:
            Numpy array of similarities (size n x m).
        """

    @abstractmethod
    def get_hash(self) -> int:
        """Compute a hash of the backend configuration and model state.

        Returns:
            Hash value of the backend.
        """

    @abstractmethod
    def dump(self, path: Path) -> None:
        """Persist the backend state on disk.

        Args:
            path: Directory the backend is saved to.
        """

    @classmethod
    @abstractmethod
    def load(cls, path: Path) -> "BaseEmbeddingBackend":
        """Restore a backend from disk.

        Args:
            path: Directory the backend is stored in.

        Returns:
            The loaded backend instance.
        """

0 commit comments

Comments
 (0)