Skip to content
Draft
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
3dbad0d
Draft implementation of support for embeddings APIs
dmontagu Oct 24, 2025
467bb8e
Merge branch 'main' into embeddings-api
DouweM Nov 14, 2025
00d8e26
Progress is made
DouweM Nov 14, 2025
a133796
Merge branch 'main' into embeddings-api
DouweM Nov 14, 2025
6d9e2a5
fix typing
DouweM Nov 14, 2025
9ffddf8
fix tests
DouweM Nov 14, 2025
d777138
Extension of embeddings draft implementation to support local models …
tomaarsen Nov 18, 2025
6973b28
Merge branch 'main' into embeddings-api
DouweM Nov 24, 2025
5aa6d87
Split query and documents methods; add tests for SentenceTransformers
DouweM Nov 24, 2025
1e66742
Instrumentation
DouweM Nov 24, 2025
bd65c4d
Add sentence-transformers
DouweM Nov 24, 2025
35a533f
tweaks
DouweM Nov 24, 2025
7392c38
Add max_input_tokens and count_tokens
DouweM Nov 25, 2025
45b2a6d
Implement OpenAI token counting using tiktoken
DouweM Nov 27, 2025
c336093
Test known embedding model names
DouweM Nov 27, 2025
0e015fa
Extract OpenAI usage and calculate cost
DouweM Nov 27, 2025
4ec0b32
Fix Cohere, SentenceTransformers
DouweM Nov 28, 2025
1717115
Merge branch 'main' into embeddings-api
DouweM Dec 10, 2025
500fc38
Various fixes
DouweM Dec 10, 2025
bb4eb3e
Follow otel gen_ai convention
DouweM Dec 10, 2025
cc8aaf1
Add metrics
DouweM Dec 10, 2025
d6fdbcf
Error handling
DouweM Dec 10, 2025
79c6157
simplification
DouweM Dec 10, 2025
880e0fa
Fix tests
DouweM Dec 10, 2025
add6444
Merge branch 'main' into embeddings-api
DouweM Dec 12, 2025
849fb19
Address feedback; download small embedding model from HF
DouweM Dec 12, 2025
bcc090b
Fixes
DouweM Dec 12, 2025
6878484
fix tests
DouweM Dec 12, 2025
581e2b6
fix tests
DouweM Dec 12, 2025
1d15644
fix tests
DouweM Dec 13, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@
WebSearchTool,
WebSearchUserLocation,
)
from .embeddings import (
Embedder,
)
from .exceptions import (
AgentRunError,
ApprovalRequired,
Expand Down Expand Up @@ -123,6 +126,8 @@
'UserPromptNode',
'capture_run_messages',
'InstrumentationSettings',
# embeddings
'Embedder',
# exceptions
'AgentRunError',
'CallDeferred',
Expand Down
237 changes: 237 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,237 @@
from collections.abc import Callable, Iterator, Sequence
from contextlib import contextmanager
from contextvars import ContextVar
from dataclasses import dataclass
from typing import Any, ClassVar, Literal, get_args

from typing_extensions import TypeAliasType

from pydantic_ai import _utils
from pydantic_ai.exceptions import UserError
from pydantic_ai.models import OpenAIChatCompatibleProvider, OpenAIResponsesCompatibleProvider
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.providers import Provider, infer_provider

from .base import EmbeddingModel
from .instrumented import InstrumentedEmbeddingModel, instrument_embedding_model
from .result import EmbeddingResult, EmbedInputType
from .settings import EmbeddingSettings, merge_embedding_settings
from .wrapper import WrapperEmbeddingModel

# Public API of the embeddings package: the high-level `Embedder`, the abstract
# `EmbeddingModel` and its wrappers, settings/result types, and model inference helpers.
__all__ = [
    'Embedder',
    'EmbeddingModel',
    'EmbeddingSettings',
    'EmbeddingResult',
    'merge_embedding_settings',
    'KnownEmbeddingModelName',
    'infer_model',
    'WrapperEmbeddingModel',
    'InstrumentedEmbeddingModel',
    'instrument_embedding_model',
]

# NOTE: the original span had review-UI text interleaved inside the call; this is the
# clean, syntactically valid definition.
KnownEmbeddingModelName = TypeAliasType(
    'KnownEmbeddingModelName',
    Literal[
        'openai:text-embedding-ada-002',
        'openai:text-embedding-3-small',
        'openai:text-embedding-3-large',
        'cohere:embed-v4.0',
        'cohere:embed-english-v3.0',
        'cohere:embed-english-light-v3.0',
        'cohere:embed-multilingual-v3.0',
        'cohere:embed-multilingual-light-v3.0',
    ],
)
"""Known model names that can be used with the `model` parameter of [`Embedder`][pydantic_ai.embeddings.Embedder].

`KnownEmbeddingModelName` is provided as a concise way to specify an embedding model.
"""

# For now, we assume that every chat and completions-compatible provider also
# supports the embeddings endpoint, as at worst the user would get a `ModelHTTPError`.
OpenAIEmbeddingsCompatibleProvider = OpenAIChatCompatibleProvider | OpenAIResponsesCompatibleProvider


def infer_model(
    model: EmbeddingModel | KnownEmbeddingModelName | str,
    *,
    provider_factory: Callable[[str], Provider[Any]] = infer_provider,
) -> EmbeddingModel:
    """Resolve a `'<provider>:<model_name>'` string into a concrete [`EmbeddingModel`][pydantic_ai.embeddings.EmbeddingModel].

    Instances that are already `EmbeddingModel`s are returned unchanged.

    Args:
        model: A model instance or a provider-prefixed model name string.
        provider_factory: Callable used to build the provider from its name prefix.

    Raises:
        ValueError: If a bare model name without a provider prefix is given.
        UserError: If the provider prefix is not recognized.
    """
    # Concrete instances pass straight through.
    if isinstance(model, EmbeddingModel):
        return model

    try:
        prefix, name = model.split(':', maxsplit=1)
    except ValueError as exc:
        raise ValueError('You must provide a provider prefix when specifying an embedding model name') from exc

    resolved_provider = provider_factory(prefix)

    # Gateway-routed providers are dispatched by their normalized underlying provider name.
    kind = prefix
    if kind.startswith('gateway/'):
        from ..providers.gateway import normalize_gateway_provider

        kind = normalize_gateway_provider(kind)

    # Every chat/completions-compatible provider is assumed to also expose the
    # embeddings endpoint; at worst the user gets a `ModelHTTPError` back.
    openai_compatible = {
        'openai',
        *get_args(OpenAIChatCompatibleProvider.__value__),
        *get_args(OpenAIResponsesCompatibleProvider.__value__),
    }

    if kind in openai_compatible:
        from .openai import OpenAIEmbeddingModel

        return OpenAIEmbeddingModel(name, provider=resolved_provider)
    if kind == 'cohere':
        from .cohere import CohereEmbeddingModel

        return CohereEmbeddingModel(name, provider=resolved_provider)
    if kind == 'sentence-transformers':
        from .sentence_transformers import SentenceTransformerEmbeddingModel

        return SentenceTransformerEmbeddingModel(name)
    raise UserError(f'Unknown embeddings model: {model}')  # pragma: no cover


@dataclass(init=False)
class Embedder:
    """High-level entry point for generating embeddings with a configured model.

    Wraps an [`EmbeddingModel`][pydantic_ai.embeddings.EmbeddingModel] (or a model name that is
    resolved lazily via [`infer_model`][pydantic_ai.embeddings.infer_model]) and adds:
    default settings merged with per-call settings, a context-local model override for testing,
    optional OpenTelemetry instrumentation, and synchronous wrappers around the async API.
    """

    instrument: InstrumentationSettings | bool | None
    """Options to automatically instrument with OpenTelemetry.

    Set to `True` to use default instrumentation settings, which will use Logfire if it's configured.
    Set to an instance of [`InstrumentationSettings`][pydantic_ai.models.instrumented.InstrumentationSettings] to customize.
    If this isn't set, then the last value set by
    [`Embedder.instrument_all()`][pydantic_ai.embeddings.Embedder.instrument_all]
    will be used, which defaults to False.
    See the [Debugging and Monitoring guide](https://ai.pydantic.dev/logfire/) for more info.
    """

    # Class-wide fallback used when an instance's `instrument` is None; set via `instrument_all()`.
    _instrument_default: ClassVar[InstrumentationSettings | bool] = False

    def __init__(
        self,
        model: EmbeddingModel | KnownEmbeddingModelName | str,
        *,
        settings: EmbeddingSettings | None = None,
        defer_model_check: bool = True,
        instrument: InstrumentationSettings | bool | None = None,
    ) -> None:
        """Initialize an Embedder.

        Args:
            model: The embedding model to use - can be a model instance, model name, or string.
            settings: Optional embedding settings to use as defaults.
            defer_model_check: Whether to defer model validation until first use.
            instrument: OpenTelemetry instrumentation settings. Set to `True` to enable with defaults,
                or pass an `InstrumentationSettings` instance to customize. If `None`, uses the value
                from `Embedder.instrument_all()`.
        """
        self._model = model if defer_model_check else infer_model(model)
        self._settings = settings
        self.instrument = instrument

        # Context-local override slot, set/reset by the `override` context manager.
        self._override_model: ContextVar[EmbeddingModel | None] = ContextVar('_override_model', default=None)

    @staticmethod
    def instrument_all(instrument: InstrumentationSettings | bool = True) -> None:
        """Set the instrumentation options for all embedders where `instrument` is not set.

        Args:
            instrument: Instrumentation settings to use as the default. Set to `True` for default settings,
                `False` to disable, or pass an `InstrumentationSettings` instance to customize.
        """
        Embedder._instrument_default = instrument

    @property
    def model(self) -> EmbeddingModel | KnownEmbeddingModelName | str:
        """The configured model: a concrete `EmbeddingModel` once resolved, or the original name string."""
        return self._model

    @contextmanager
    def override(
        self,
        *,
        model: EmbeddingModel | KnownEmbeddingModelName | str | _utils.Unset = _utils.UNSET,
    ) -> Iterator[None]:
        """Temporarily override the embedding model within this context, e.g. for testing.

        Args:
            model: The model (instance or name) to use inside the `with` block; the previous
                value is restored on exit.
        """
        if _utils.is_set(model):
            model_token = self._override_model.set(infer_model(model))
        else:
            model_token = None

        try:
            yield
        finally:
            if model_token is not None:
                self._override_model.reset(model_token)

    async def embed_query(
        self, query: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Embed the given text(s) as queries; shorthand for `embed(..., input_type='query')`."""
        return await self.embed(query, input_type='query', settings=settings)

    async def embed_documents(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Embed the given text(s) as documents; shorthand for `embed(..., input_type='document')`."""
        return await self.embed(documents, input_type='document', settings=settings)

    async def embed(
        self, documents: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Embed the given text(s) with the configured model.

        Args:
            documents: A single string or a sequence of strings to embed.
            input_type: Whether the inputs are queries or documents; passed through to the model.
            settings: Per-call settings, merged over this embedder's default settings.
        """
        model = self._get_model()
        settings = merge_embedding_settings(self._settings, settings)
        return await model.embed(documents, input_type=input_type, settings=settings)

    def embed_query_sync(
        self, query: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Synchronous wrapper around [`embed_query`][pydantic_ai.embeddings.Embedder.embed_query]."""
        return _utils.get_event_loop().run_until_complete(self.embed_query(query, settings=settings))

    def embed_documents_sync(
        self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Synchronous wrapper around [`embed_documents`][pydantic_ai.embeddings.Embedder.embed_documents]."""
        return _utils.get_event_loop().run_until_complete(self.embed_documents(documents, settings=settings))

    def embed_sync(
        self, documents: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Synchronous wrapper around [`embed`][pydantic_ai.embeddings.Embedder.embed]."""
        return _utils.get_event_loop().run_until_complete(
            self.embed(documents, input_type=input_type, settings=settings)
        )

    async def max_input_tokens(self) -> int | None:
        """Get the maximum number of input tokens the model accepts; `None` means unknown."""
        model = self._get_model()
        return await model.max_input_tokens()

    def max_input_tokens_sync(self) -> int | None:
        """Synchronous wrapper around [`max_input_tokens`][pydantic_ai.embeddings.Embedder.max_input_tokens]."""
        return _utils.get_event_loop().run_until_complete(self.max_input_tokens())

    async def count_tokens(self, text: str) -> int:
        """Count the number of tokens in `text` using the configured model's tokenizer."""
        model = self._get_model()
        return await model.count_tokens(text)

    def count_tokens_sync(self, text: str) -> int:
        """Synchronous wrapper around [`count_tokens`][pydantic_ai.embeddings.Embedder.count_tokens]."""
        return _utils.get_event_loop().run_until_complete(self.count_tokens(text))

    def _get_model(self) -> EmbeddingModel:
        """Create a model configured for this embedder.

        A context-local override (see `override`) takes precedence; otherwise the configured
        model is resolved via `infer_model` and cached back onto `self._model`.

        Returns:
            The embedding model to use, with instrumentation applied if configured.
        """
        model_: EmbeddingModel
        if some_model := self._override_model.get():
            model_ = some_model
        else:
            model_ = self._model = infer_model(self.model)

        # Fall back to the class-wide default set by `instrument_all()` when unset.
        instrument = self.instrument
        if instrument is None:
            instrument = self._instrument_default

        return instrument_embedding_model(model_, instrument)
73 changes: 73 additions & 0 deletions pydantic_ai_slim/pydantic_ai/embeddings/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence

from .result import EmbeddingResult, EmbedInputType
from .settings import EmbeddingSettings, merge_embedding_settings


class EmbeddingModel(ABC):
    """Abstract base class for embedding models.

    Concrete subclasses implement `embed` for a specific provider and identify themselves
    via the `model_name` and `system` properties.
    """

    # Default settings for this model instance; merged with per-call settings in `prepare_embed`.
    _settings: EmbeddingSettings | None = None

    def __init__(
        self,
        *,
        settings: EmbeddingSettings | None = None,
    ) -> None:
        """Initialize the model with optional default settings.

        Args:
            settings: Model-specific settings that will be used as defaults for this model.
        """
        self._settings = settings

    @property
    def settings(self) -> EmbeddingSettings | None:
        """Get the model settings."""
        return self._settings

    @property
    def base_url(self) -> str | None:
        """The base URL for the provider API, if available."""
        return None

    @property
    @abstractmethod
    def model_name(self) -> str:
        """The model name."""
        raise NotImplementedError()

    @property
    @abstractmethod
    def system(self) -> str:
        """The embedding model provider."""
        raise NotImplementedError()

    @abstractmethod
    async def embed(
        self, documents: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
    ) -> EmbeddingResult:
        """Embed the given text(s).

        Args:
            documents: A single string or a sequence of strings to embed.
            input_type: Whether the inputs are queries or documents.
            settings: Per-call settings; implementations can merge them with the model's
                defaults via `prepare_embed`.
        """
        raise NotImplementedError

    def prepare_embed(
        self, documents: str | Sequence[str], settings: EmbeddingSettings | None = None
    ) -> tuple[list[str], EmbeddingSettings]:
        """Prepare the documents and settings for the embedding.

        Normalizes a single string into a one-element list and merges the per-call settings
        over the model defaults (falling back to an empty dict).
        """
        documents = [documents] if isinstance(documents, str) else list(documents)

        settings = merge_embedding_settings(self._settings, settings) or {}

        return documents, settings

    async def max_input_tokens(self) -> int | None:
        """Get the maximum number of tokens that can be input to the model.

        `None` means unknown.
        """
        return None

    async def count_tokens(self, text: str) -> int:
        """Count the number of tokens in the text.

        Raises:
            NotImplementedError: If the model does not support token counting.
        """
        raise NotImplementedError
Loading
Loading