Skip to content

Commit 00d8e26

Browse files
committed
Progress is made
1 parent 467bb8e commit 00d8e26

File tree

13 files changed

+5141
-145
lines changed

13 files changed

+5141
-145
lines changed

pydantic_ai_slim/pydantic_ai/embeddings/__init__.py

Lines changed: 33 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,28 @@
1-
from collections.abc import Iterator, Sequence
1+
from collections.abc import Callable, Iterator, Sequence
22
from contextlib import contextmanager
33
from contextvars import ContextVar
44
from dataclasses import dataclass
5-
from typing import Literal, overload
5+
from typing import Any, Literal, get_args, overload
66

77
from typing_extensions import TypeAliasType
88

99
from pydantic_ai import _utils
10-
from pydantic_ai.embeddings.embedding_model import EmbeddingModel
10+
from pydantic_ai.embeddings.base import EmbeddingModel
1111
from pydantic_ai.embeddings.settings import EmbeddingSettings, merge_embedding_settings
1212
from pydantic_ai.exceptions import UserError
13+
from pydantic_ai.models import OpenAIChatCompatibleProvider
1314
from pydantic_ai.models.instrumented import InstrumentationSettings
14-
from pydantic_ai.providers import infer_provider
15+
from pydantic_ai.providers import Provider, infer_provider
16+
17+
__all__ = [
18+
'Embedder',
19+
'EmbeddingModel',
20+
'EmbeddingSettings',
21+
'merge_embedding_settings',
22+
'KnownEmbeddingModelName',
23+
'OpenAIEmbeddingsCompatibleProvider',
24+
'infer_model',
25+
]
1526

1627
KnownEmbeddingModelName = TypeAliasType(
1728
'KnownEmbeddingModelName',
@@ -21,13 +32,20 @@
2132
'openai:text-embedding-3-largecohere:embed-v4.0',
2233
],
2334
)
24-
"""Known model names that can be used with the `model` parameter of [`Agent`][pydantic_ai.Agent].
35+
"""Known model names that can be used with the `model` parameter of [`Embedder`][pydantic_ai.embeddings.Embedder].
2536
26-
`KnownModelName` is provided as a concise way to specify a model.
37+
`KnownEmbeddingModelName` is provided as a concise way to specify an embedding model.
2738
"""
2839

40+
# For now, we assume that every chat completions-compatible provider also supports the embeddings endpoint.
41+
OpenAIEmbeddingsCompatibleProvider = OpenAIChatCompatibleProvider
42+
2943

30-
def infer_model(model: EmbeddingModel | KnownEmbeddingModelName | str) -> EmbeddingModel:
44+
def infer_model(
45+
model: EmbeddingModel | KnownEmbeddingModelName | str,
46+
*,
47+
provider_factory: Callable[[str], Provider[Any]] = infer_provider,
48+
) -> EmbeddingModel:
3149
"""Infer the model from the name."""
3250
if isinstance(model, EmbeddingModel):
3351
return model
@@ -37,14 +55,15 @@ def infer_model(model: EmbeddingModel | KnownEmbeddingModelName | str) -> Embedd
3755
except ValueError as e:
3856
raise ValueError('You must provide a provider prefix when specifying an embedding model name') from e
3957

40-
provider = infer_provider(provider_name)
58+
provider = provider_factory(provider_name)
4159

4260
model_kind = provider_name
4361
if model_kind.startswith('gateway/'):
44-
model_kind = provider_name.removeprefix('gateway/')
62+
from ..providers.gateway import normalize_gateway_provider
4563

46-
# TODO: extend the following list for other providers as appropriate
47-
if model_kind in ('openai',):
64+
model_kind = normalize_gateway_provider(model_kind)
65+
66+
if model_kind in get_args(OpenAIEmbeddingsCompatibleProvider.__value__):
4867
model_kind = 'openai'
4968

5069
if model_kind == 'openai':
@@ -59,8 +78,10 @@ def infer_model(model: EmbeddingModel | KnownEmbeddingModelName | str) -> Embedd
5978
raise UserError(f'Unknown embeddings model: {model}') # pragma: no cover
6079

6180

62-
@dataclass
81+
@dataclass(init=False)
6382
class Embedder:
83+
"""TODO: Docstring."""
84+
6485
instrument: InstrumentationSettings | bool | None
6586
"""Options to automatically instrument with OpenTelemetry."""
6687

pydantic_ai_slim/pydantic_ai/embeddings/embedding_model.py renamed to pydantic_ai_slim/pydantic_ai/embeddings/base.py

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections.abc import Sequence
33
from typing import overload
44

5-
from pydantic_ai.embeddings.settings import EmbeddingSettings
5+
from pydantic_ai.embeddings.settings import EmbeddingSettings, merge_embedding_settings
66

77

88
class EmbeddingModel(ABC):
@@ -28,18 +28,22 @@ def settings(self) -> EmbeddingSettings | None:
2828
"""Get the model settings."""
2929
return self._settings
3030

31+
@property
32+
def base_url(self) -> str | None:
33+
"""The base URL for the provider API, if available."""
34+
return None
35+
3136
@property
3237
@abstractmethod
3338
def model_name(self) -> str:
3439
"""The model name."""
3540
raise NotImplementedError()
3641

37-
# TODO: Add system?
38-
3942
@property
40-
def base_url(self) -> str | None:
41-
"""The base URL for the provider API, if available."""
42-
return None
43+
@abstractmethod
44+
def system(self) -> str:
45+
"""The embedding model provider."""
46+
raise NotImplementedError()
4347

4448
@overload
4549
async def embed(self, documents: str, *, settings: EmbeddingSettings | None = None) -> list[float]:
@@ -49,7 +53,20 @@ async def embed(self, documents: str, *, settings: EmbeddingSettings | None = No
4953
async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None) -> list[list[float]]:
5054
pass
5155

56+
@abstractmethod
5257
async def embed(
5358
self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
5459
) -> list[float] | list[list[float]]:
5560
raise NotImplementedError
61+
62+
def prepare_embed(
63+
self, documents: str | Sequence[str], settings: EmbeddingSettings | None = None
64+
) -> tuple[Sequence[str], bool, EmbeddingSettings]:
65+
"""Prepare the documents and settings for the embedding."""
66+
is_single_document = isinstance(documents, str)
67+
if is_single_document:
68+
documents = [documents]
69+
70+
settings = merge_embedding_settings(self._settings, settings) or {}
71+
72+
return documents, is_single_document, settings

pydantic_ai_slim/pydantic_ai/embeddings/cohere.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from collections.abc import Sequence
22
from dataclasses import dataclass, field
3-
from typing import Literal, cast, overload
3+
from typing import Any, Literal, cast, overload
44

5-
from pydantic_ai.embeddings.embedding_model import EmbeddingModel
5+
from pydantic_ai.embeddings.base import EmbeddingModel
66
from pydantic_ai.embeddings.settings import EmbeddingSettings
77
from pydantic_ai.providers import Provider, infer_provider
88

9-
from .settings import merge_embedding_settings
10-
119
try:
1210
from cohere import AsyncClientV2
11+
from cohere.core.request_options import RequestOptions
12+
from cohere.types.embed_input_type import EmbedInputType
1313
except ImportError as _import_error:
1414
raise ImportError(
1515
'Please install `cohere` to use the Cohere embeddings model, '
@@ -18,16 +18,36 @@
1818

1919
LatestCohereEmbeddingModelNames = Literal[
2020
'cohere:embed-v4.0',
21-
# TODO: Add the others
21+
'embed-english-v3.0embed-english-light-v3.0',
22+
'embed-multilingual-v3.0',
23+
'embed-multilingual-light-v3.0',
2224
]
2325
"""Latest Cohere embeddings models."""
2426

2527
CohereEmbeddingModelName = str | LatestCohereEmbeddingModelNames
2628
"""Possible Cohere embeddings model names."""
2729

2830

31+
class CohereEmbeddingSettings(EmbeddingSettings, total=False):
32+
"""Settings used for a Cohere embedding model request."""
33+
34+
# ALL FIELDS MUST BE `cohere_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
35+
36+
# TODO: Possibly move to base EmbeddingSettings if supported by more providers
37+
cohere_max_tokens: int
38+
"""The maximum number of tokens to generate before stopping."""
39+
40+
# We don't support `embedding_types for now because it doesn't affect the user-facing API today..
41+
# cohere_embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary", "base64"]
42+
43+
cohere_input_type: EmbedInputType
44+
"""The input type of the embedding."""
45+
46+
2947
@dataclass(init=False)
3048
class CohereEmbeddingModel(EmbeddingModel):
49+
"""Cohere embedding model."""
50+
3151
_model_name: CohereEmbeddingModelName = field(repr=False)
3252
_provider: Provider[AsyncClientV2] = field(repr=False)
3353

@@ -42,11 +62,10 @@ def __init__(
4262
4363
Args:
4464
model_name: The name of the Cohere model to use. List of model names
45-
available [here](https://docs.cohere.com/docs/models#command).
65+
available [here](https://docs.cohere.com/docs/cohere-embed).
4666
provider: The provider to use for authentication and API access. Can be either the string
4767
'cohere' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
4868
created using the other parameters.
49-
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
5069
settings: Model-specific settings that will be used as defaults for this model.
5170
"""
5271
self._model_name = model_name
@@ -84,21 +103,26 @@ async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings |
84103
async def embed(
85104
self, documents: Sequence[str], *, settings: EmbeddingSettings | None = None
86105
) -> list[float] | list[list[float]]:
87-
input_is_string = isinstance(documents, str)
88-
if input_is_string:
89-
documents = [documents]
106+
documents, is_single_document, settings = self.prepare_embed(documents, settings)
107+
embeddings = await self._embed(documents, cast(CohereEmbeddingSettings, settings))
108+
return embeddings[0] if is_single_document else embeddings
109+
110+
async def _embed(self, documents: Sequence[str], settings: CohereEmbeddingSettings) -> list[list[float]]:
111+
request_options = RequestOptions()
112+
if extra_headers := settings.get('extra_headers'):
113+
request_options['additional_headers'] = extra_headers
114+
if extra_body := settings.get('extra_body'):
115+
request_options['additional_body_parameters'] = cast(dict[str, Any], extra_body)
90116

91-
settings = merge_embedding_settings(self._settings, settings) or {}
92117
response = await self._client.embed(
93118
model=self.model_name,
94-
input_type=settings.get('input_type', 'search_document'),
95-
texts=cast(Sequence[str], documents),
96-
output_dimension=settings.get('output_dimension'),
119+
texts=documents,
120+
output_dimension=settings.get('dimensions'),
121+
input_type=settings.get('cohere_input_type', 'search_document'),
122+
max_tokens=settings.get('cohere_max_tokens'),
123+
request_options=request_options,
97124
)
98125
embeddings = response.embeddings.float_
99126
assert embeddings is not None, 'This is a bug in cohere?'
100127

101-
if input_is_string:
102-
return embeddings[0]
103-
104128
return embeddings

pydantic_ai_slim/pydantic_ai/embeddings/openai.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,11 @@
22
from dataclasses import dataclass, field
33
from typing import Literal, overload
44

5-
from pydantic_ai.embeddings.embedding_model import EmbeddingModel
5+
from pydantic_ai.embeddings.base import EmbeddingModel
66
from pydantic_ai.embeddings.settings import EmbeddingSettings
77
from pydantic_ai.providers import Provider, infer_provider
88

9-
from .settings import merge_embedding_settings
9+
from . import OpenAIEmbeddingsCompatibleProvider
1010

1111
try:
1212
from openai import NOT_GIVEN, AsyncOpenAI
@@ -21,27 +21,34 @@
2121
"""Possible OpenAI embeddings model names."""
2222

2323

24+
class OpenAIEmbeddingSettings(EmbeddingSettings, total=False):
25+
"""Settings used for an OpenAI embedding model request."""
26+
27+
# ALL FIELDS MUST BE `openai_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS.
28+
29+
2430
@dataclass(init=False)
2531
class OpenAIEmbeddingModel(EmbeddingModel):
32+
"""OpenAI embedding model."""
33+
2634
_model_name: OpenAIEmbeddingModelName = field(repr=False)
2735
_provider: Provider[AsyncOpenAI] = field(repr=False)
2836

2937
def __init__(
3038
self,
3139
model_name: OpenAIEmbeddingModelName,
3240
*,
33-
provider: Literal['openai'] | Provider[AsyncOpenAI] = 'openai',
41+
provider: OpenAIEmbeddingsCompatibleProvider | Literal['openai'] | Provider[AsyncOpenAI] = 'openai',
3442
settings: EmbeddingSettings | None = None,
3543
):
3644
"""Initialize an OpenAI model.
3745
3846
Args:
3947
model_name: The name of the OpenAI model to use. List of model names
40-
available [here](https://docs.OpenAI.com/docs/models#command).
48+
available [here](https://platform.openai.com/docs/guides/embeddings#embedding-models).
4149
provider: The provider to use for authentication and API access. Can be either the string
4250
'OpenAI' or an instance of `Provider[AsyncClientV2]`. If not provided, a new provider will be
4351
created using the other parameters.
44-
profile: The model profile to use. Defaults to a profile picked by the provider based on the model name.
4552
settings: Model-specific settings that will be used as defaults for this model.
4653
"""
4754
self._model_name = model_name
@@ -78,18 +85,16 @@ async def embed(self, documents: Sequence[str], *, settings: EmbeddingSettings |
7885
async def embed(
7986
self, documents: str | Sequence[str], *, settings: EmbeddingSettings | None = None
8087
) -> list[float] | list[list[float]]:
81-
input_is_string = isinstance(documents, str)
82-
if input_is_string:
83-
documents = [documents]
88+
documents, is_single_document, settings = self.prepare_embed(documents, settings)
89+
embeddings = await self._embed(documents, settings)
90+
return embeddings[0] if is_single_document else embeddings
8491

85-
settings = merge_embedding_settings(self._settings, settings) or {}
92+
async def _embed(self, documents: Sequence[str], settings: OpenAIEmbeddingSettings) -> list[list[float]]:
8693
response = await self._client.embeddings.create(
8794
input=documents, # pyright: ignore[reportArgumentType] # Sequence[str] not compatible with SequenceNotStr[str] :/
8895
model=self.model_name,
89-
dimensions=settings.get('output_dimension') or NOT_GIVEN,
96+
dimensions=settings.get('dimensions') or NOT_GIVEN,
97+
extra_headers=settings.get('extra_headers'),
98+
extra_body=settings.get('extra_body'),
9099
)
91-
result = [item.embedding for item in response.data]
92-
93-
if input_is_string:
94-
return result[0]
95-
return result
100+
return [item.embedding for item in response.data]

pydantic_ai_slim/pydantic_ai/embeddings/settings.py

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,46 +1,47 @@
1-
from typing import Literal, TypedDict
1+
from typing_extensions import TypedDict
22

33

44
class EmbeddingSettings(TypedDict, total=False):
5-
# TODO: May want to add extra_headres, extra_query, extra_body, timeout, etc.
5+
"""Settings to configure an embedding model.
66
7-
output_dimension: int
8-
"""The maximum number of tokens to generate before stopping.
7+
Here we include only settings which apply to multiple models / model providers,
8+
though not all of these settings are supported by all models.
9+
"""
10+
11+
dimensions: int
12+
"""The number of dimensions the resulting output embeddings should have.
913
1014
Supported by:
1115
12-
* Cohere
1316
* OpenAI
17+
* Cohere
1418
"""
1519

16-
max_tokens: int
17-
"""The maximum number of tokens to generate before stopping.
20+
extra_headers: dict[str, str]
21+
"""Extra headers to send to the model.
1822
1923
Supported by:
2024
25+
* OpenAI
2126
* Cohere
2227
"""
2328

24-
# We don't support embedding_types for now because it doesn't affect the user-facing API today..
25-
# embedding_types: Literal["float", "int8", "uint8", "binary", "ubinary", "base64"]
26-
27-
input_type: Literal['search_document', 'search_query', 'classification', 'clustering', 'image']
28-
"""The input type of the embedding.
29+
extra_body: object
30+
"""Extra body to send to the model.
2931
3032
Supported by:
3133
32-
* Cohere (See `cohere.EmbedInputType`)
34+
* OpenAI
35+
* Cohere
3336
"""
3437

35-
# TODO: Add more?
36-
3738

3839
def merge_embedding_settings(
3940
base: EmbeddingSettings | None, overrides: EmbeddingSettings | None
4041
) -> EmbeddingSettings | None:
4142
"""Merge two sets of embedding settings, preferring the overrides.
4243
43-
A common use case is: merge_embedding_settings(<agent settings>, <run settings>)
44+
A common use case is: merge_embedding_settings(<embedder settings>, <run settings>)
4445
"""
4546
# Note: we may want merge recursively if/when we add non-primitive values
4647
if base and overrides:

0 commit comments

Comments
 (0)