
Commit 4ec0b32

Fix Cohere, SentenceTransformers
1 parent 0e015fa commit 4ec0b32

6 files changed: +184 -110 lines changed


pydantic_ai_slim/pydantic_ai/embeddings/cohere.py
Lines changed: 42 additions & 28 deletions

@@ -1,15 +1,20 @@
 from collections.abc import Sequence
 from dataclasses import dataclass, field
-from typing import Any, Literal, cast, overload
+from typing import Any, Literal, cast
 
-from pydantic_ai.embeddings.base import EmbeddingModel, EmbedInputType
-from pydantic_ai.embeddings.settings import EmbeddingSettings
 from pydantic_ai.exceptions import UnexpectedModelBehavior
-from pydantic_ai.providers import infer_provider
+from pydantic_ai.providers import Provider, infer_provider
+from pydantic_ai.usage import RequestUsage
+
+from .base import EmbeddingModel, EmbedInputType
+from .result import EmbeddingResult
+from .settings import EmbeddingSettings
 
 try:
+    from cohere import AsyncClientV2
     from cohere.core.request_options import RequestOptions
-    from cohere.v2.client import EmbedInputType as CohereEmbedInputType
+    from cohere.types.embed_by_type_response import EmbedByTypeResponse
+    from cohere.types.embed_input_type import EmbedInputType as CohereEmbedInputType
     from cohere.v2.types.v2embed_request_truncate import V2EmbedRequestTruncate
 
     from pydantic_ai.providers.cohere import CohereProvider

@@ -73,7 +78,7 @@ def __init__(
         self,
         model_name: CohereEmbeddingModelName,
         *,
-        provider: Literal['cohere'] | CohereProvider = 'cohere',
+        provider: Literal['cohere'] | Provider[AsyncClientV2] | CohereProvider = 'cohere',
         settings: EmbeddingSettings | None = None,
     ):
         """Initialize an Cohere model.

@@ -92,7 +97,7 @@ def __init__(
             provider = infer_provider(provider)
         self._provider = provider
         self._client = provider.client
-        self._v1_client = provider.v1_client
+        self._v1_client = provider.v1_client if isinstance(provider, CohereProvider) else None
 
         super().__init__(settings=settings)
 

@@ -111,28 +116,15 @@ def system(self) -> str:
         """The embedding model provider."""
         return self._provider.name
 
-    @overload
-    async def embed(
-        self, documents: str, *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[float]:
-        pass
-
-    @overload
-    async def embed(
-        self, documents: Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[list[float]]:
-        pass
-
     async def embed(
-        self, documents: Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[float] | list[list[float]]:
-        documents, is_single_document, settings = self.prepare_embed(documents, settings)
-        embeddings = await self._embed(documents, input_type, cast(CohereEmbeddingSettings, settings))
-        return embeddings[0] if is_single_document else embeddings
+        self, documents: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
+    ) -> EmbeddingResult:
+        documents, settings = self.prepare_embed(documents, settings)
+        return await self._embed(documents, input_type, cast(CohereEmbeddingSettings, settings))
 
     async def _embed(
-        self, documents: Sequence[str], input_type: EmbedInputType, settings: CohereEmbeddingSettings
-    ) -> list[list[float]]:
+        self, documents: str | Sequence[str], input_type: EmbedInputType, settings: CohereEmbeddingSettings
+    ) -> EmbeddingResult:
         request_options = RequestOptions()
         if extra_headers := settings.get('extra_headers'):
             request_options['additional_headers'] = extra_headers

@@ -156,10 +148,18 @@ async def _embed(
         if embeddings is None:
             raise UnexpectedModelBehavior(
                 'The Cohere embeddings response did not have an `embeddings` field holding a list of floats',
-                str(response.data),
+                response,
             )
 
-        return embeddings
+        return EmbeddingResult(
+            embeddings=embeddings,
+            inputs=documents,
+            input_type=input_type,
+            usage=_map_usage(response),
+            model_name=self.model_name,
+            provider_name=self.system,
+            provider_response_id=response.id,
+        )
 
     async def max_input_tokens(self) -> int | None:
         return _MAX_INPUT_TOKENS.get(self.model_name)

@@ -173,3 +173,17 @@ async def count_tokens(self, text: str) -> int:
             offline=False,
         )
         return len(result.tokens)
+
+
+def _map_usage(response: EmbedByTypeResponse) -> RequestUsage:
+    u = response.meta
+    if u is None or u.billed_units is None:
+        return RequestUsage()
+    usage_data = u.billed_units.model_dump(exclude_none=True)
+    details = {k: int(v) for k, v in usage_data.items() if k != 'input_tokens' and isinstance(v, int | float) and v > 0}
+
+    # TODO (DouweM): Use RequestUsage.extract() once https://github.com/pydantic/genai-prices/blob/main/prices/providers/cohere.yml has been updated
+    return RequestUsage(
+        input_tokens=int(u.billed_units.input_tokens or 0),
+        details=details,
+    )
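
With this change, `embed()` returns an `EmbeddingResult` carrying the vectors plus usage and provider metadata instead of bare float lists. A minimal sketch of calling the updated API, assuming the class is exported as `CohereEmbeddingModel` and that `'search_document'` is a valid `EmbedInputType` value (neither is shown in this diff; the model name is illustrative):

    import asyncio

    from pydantic_ai.embeddings.cohere import CohereEmbeddingModel  # assumed export name

    async def main() -> None:
        model = CohereEmbeddingModel('embed-v4.0')  # illustrative model name
        result = await model.embed(
            ['first document', 'second document'],
            input_type='search_document',  # assumed EmbedInputType literal
        )
        # One vector per input, plus usage mapped from Cohere's billed_units by _map_usage().
        print(len(result.embeddings), result.usage.input_tokens, result.provider_response_id)

    asyncio.run(main())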

pydantic_ai_slim/pydantic_ai/embeddings/instrumented.py
Lines changed: 39 additions & 32 deletions

@@ -1,17 +1,19 @@
 from __future__ import annotations
 
 import json
+import warnings
 from collections.abc import Callable, Iterator, Sequence
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, overload
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse
 
 from opentelemetry.util.types import AttributeValue
 
-from pydantic_ai.models.instrumented import ANY_ADAPTER, InstrumentationSettings
+from pydantic_ai.models.instrumented import ANY_ADAPTER, CostCalculationFailedWarning, InstrumentationSettings
 
 from .base import EmbeddingModel, EmbedInputType
+from .result import EmbeddingResult
 from .settings import EmbeddingSettings
 from .wrapper import WrapperEmbeddingModel
 

@@ -50,19 +52,9 @@ def __init__(
         super().__init__(wrapped)
         self.instrumentation_settings = options or InstrumentationSettings()
 
-    @overload
-    async def embed(
-        self, documents: str, *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[float]: ...
-
-    @overload
-    async def embed(
-        self, documents: Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[list[float]]: ...
-
     async def embed(
         self, documents: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[float] | list[list[float]]:
+    ) -> EmbeddingResult:
         with self._instrument(documents, input_type, settings) as finish:
             result = await self.wrapped.embed(documents, input_type=input_type, settings=settings)
             finish(result)

@@ -74,7 +66,7 @@ def _instrument(
         documents: str | Sequence[str],
         input_type: EmbedInputType,
         settings: EmbeddingSettings | None,
-    ) -> Iterator[Callable[[list[float] | list[list[float]]], None]]:
+    ) -> Iterator[Callable[[EmbeddingResult], None]]:
         operation = 'embed'
         span_name = f'{operation} {self.model_name}'
 

@@ -111,31 +103,46 @@ def _instrument(
         try:
             with self.instrumentation_settings.tracer.start_as_current_span(span_name, attributes=attributes) as span:
 
-                def finish(result: list[float] | list[list[float]]):
+                def finish(result: EmbeddingResult):
                     if not span.is_recording():
                         return
 
-                    # Calculate output dimension
-                    if isinstance(result, list) and result:
-                        if isinstance(result[0], list):
-                            # Multiple embeddings
-                            output_dim = len(result[0]) if result[0] else 0
-                            num_outputs = len(result)
-                        else:
-                            # Single embedding
-                            output_dim = len(result)
-                            num_outputs = 1
+                    attributes_to_set: dict[str, AttributeValue] = {
+                        **result.usage.opentelemetry_attributes(),
+                        'gen_ai.response.model': result.model_name or self.model_name,
+                    }
+
+                    try:
+                        price_calculation = result.cost()
+                    except LookupError:
+                        # The cost of this provider/model is unknown, which is common.
+                        pass
+                    except Exception as e:
+                        warnings.warn(
+                            f'Failed to get cost from response: {type(e).__name__}: {e}', CostCalculationFailedWarning
+                        )
                     else:
-                        output_dim = 0
-                        num_outputs = 0
+                        attributes_to_set['operation.cost'] = float(price_calculation.total_price)
+
+                    # Calculate output dimension
+                    embeddings = result.embeddings
+                    if embeddings:
+                        output_dim = len(embeddings[0]) if embeddings[0] else 0
+                        num_outputs = len(embeddings)
+
+                        attributes_to_set.update(
+                            {
+                                'gen_ai.embedding.dimension': output_dim,
+                                'gen_ai.embedding.num_outputs': num_outputs,
+                            }
+                        )
+
+                    if result.provider_response_id is not None:
+                        attributes_to_set['gen_ai.response.id'] = result.provider_response_id
 
-                    attributes_to_set = {
-                        'gen_ai.embedding.dimension': output_dim,
-                        'gen_ai.embedding.num_outputs': num_outputs,
-                    }
                     span.set_attributes(attributes_to_set)
 
-                # TODO (DouweM): Include cost as metric etc, just like on InstrumentedModel
+                # TODO (DouweM): Record cost metric
 
                 yield finish
         finally:
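
For reference, after this change the `finish()` callback sets span attributes along these lines. The sketch below assumes a successful call with two 1024-dimensional vectors and known pricing; the token attributes contributed by `RequestUsage.opentelemetry_attributes()` are omitted because their exact keys are not shown in this diff, and all values are illustrative:

    # Illustrative only: the keys come from the diff above, the values are made up.
    expected_span_attributes = {
        'gen_ai.response.model': 'embed-v4.0',   # result.model_name, falling back to self.model_name
        'operation.cost': 0.0001,                # only set when result.cost() succeeds
        'gen_ai.embedding.dimension': 1024,      # len(embeddings[0])
        'gen_ai.embedding.num_outputs': 2,       # len(embeddings)
        'gen_ai.response.id': 'resp_123',        # only set when provider_response_id is not None
    }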

pydantic_ai_slim/pydantic_ai/embeddings/result.py
Lines changed: 2 additions & 0 deletions

@@ -36,6 +36,8 @@ class EmbeddingResult:
 
     provider_response_id: str | None = None
 
+    # TODO (DouweM): Support `result[idx: int]` and `result[document: str]`
+
     def cost(self) -> genai_types.PriceCalculation:
         """Calculate the cost of the usage.
 
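`EmbeddingResult.cost()` is priced via genai-prices and, as the instrumented wrapper above shows, raises `LookupError` when the provider/model has no pricing entry. A small caller-side sketch of the same defensive pattern, assuming `result` is an `EmbeddingResult`:

    try:
        price = result.cost()
    except LookupError:
        total_cost = None  # pricing for this provider/model isn't known
    else:
        total_cost = float(price.total_price)
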
pydantic_ai_slim/pydantic_ai/embeddings/sentence_transformers.py
Lines changed: 21 additions & 19 deletions

@@ -2,13 +2,15 @@
 
 from collections.abc import Sequence
 from dataclasses import dataclass, field
-from typing import Any, cast, overload
+from typing import Any, cast
 
 import pydantic_ai._utils as _utils
-from pydantic_ai.embeddings.base import EmbeddingModel, EmbedInputType
-from pydantic_ai.embeddings.settings import EmbeddingSettings
 from pydantic_ai.exceptions import UnexpectedModelBehavior
 
+from .base import EmbeddingModel, EmbedInputType
+from .result import EmbeddingResult
+from .settings import EmbeddingSettings
+
 try:
     import numpy as np
     import torch

@@ -73,26 +75,18 @@ def system(self) -> str:
         """The embedding model provider/system identifier."""
         return 'sentence-transformers'
 
-    @overload
-    async def embed(
-        self, documents: str, *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[float]: ...
-
-    @overload
-    async def embed(
-        self, documents: Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
-    ) -> list[list[float]]: ...
-
     async def embed(
         self, documents: str | Sequence[str], *, input_type: EmbedInputType, settings: EmbeddingSettings | None = None
     ) -> list[float] | list[list[float]]:
-        docs, is_single_document, settings = self.prepare_embed(documents, settings)
-        embeddings = await self._embed(docs, input_type, cast(SentenceTransformersEmbeddingSettings, settings))
-        return embeddings[0] if is_single_document else embeddings
+        docs, settings = self.prepare_embed(documents, settings)
+        return await self._embed(docs, input_type, cast(SentenceTransformersEmbeddingSettings, settings))
 
     async def _embed(
-        self, documents: Sequence[str], input_type: EmbedInputType, settings: SentenceTransformersEmbeddingSettings
-    ) -> list[list[float]]:
+        self,
+        documents: str | Sequence[str],
+        input_type: EmbedInputType,
+        settings: SentenceTransformersEmbeddingSettings,
+    ) -> EmbeddingResult:
         device = settings.get('sentence_transformers_device', None)
         normalize = settings.get('sentence_transformers_normalize_embeddings', False)
         batch_size = settings.get('sentence_transformers_batch_size', None)

@@ -111,7 +105,15 @@ async def _embed(
                 normalize_embeddings=normalize,
                 **{'batch_size': batch_size} if batch_size is not None else {},  # type: ignore[reportArgumentType]
             )
-        return np_embeddings.tolist()  # type: ignore[reportUnknownReturnType]
+        embeddings = np_embeddings.tolist()  # type: ignore[reportAttributeAccessIssue]
+
+        return EmbeddingResult(
+            embeddings=embeddings,  # type: ignore[reportUnknownArgumentType]
+            inputs=documents,
+            input_type=input_type,
+            model_name=self.model_name,
+            provider_name=self.system,
+        )
 
     async def max_input_tokens(self) -> int | None:
         model = await self._get_model()
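
The `_embed()` change above keeps the SentenceTransformers numpy output and wraps it in an `EmbeddingResult`. A standalone sketch of the `tolist()` conversion it relies on (numpy only; the `SentenceTransformer.encode()` call itself is elided):

    import numpy as np

    # encode() yields a 2-D float array of shape (num_documents, dimension);
    # tolist() turns it into list[list[float]], which is the shape EmbeddingResult.embeddings is given here.
    np_embeddings = np.array([[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]], dtype=np.float32)
    embeddings = np_embeddings.tolist()
    assert isinstance(embeddings, list) and isinstance(embeddings[0], list)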

tests/conftest.py
Lines changed: 3 additions & 1 deletion

@@ -34,6 +34,7 @@
     'IsBytes',
     'IsInt',
     'IsInstance',
+    'IsList',
     'TestEnv',
     'ClientWithHandler',
     'try_import',

@@ -62,8 +63,9 @@ def IsNow(*args: Any, **kwargs: Any) -> datetime: ...
     def IsStr(*args: Any, **kwargs: Any) -> str: ...
     def IsSameStr(*args: Any, **kwargs: Any) -> str: ...
     def IsBytes(*args: Any, **kwargs: Any) -> bytes: ...
+    def IsList(*args: T, **kwargs: Any) -> list[T]: ...
 else:
-    from dirty_equals import IsBytes, IsDatetime, IsFloat, IsInstance, IsInt, IsNow as _IsNow, IsStr
+    from dirty_equals import IsBytes, IsDatetime, IsFloat, IsInstance, IsInt, IsList, IsNow as _IsNow, IsStr
 
     def IsNow(*args: Any, **kwargs: Any):
         # Increase the default value of `delta` to 10 to reduce test flakiness on overburdened machines
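
`IsList` is re-exported here so tests can make declarative assertions about lists, such as embedding vectors, with dirty_equals. A minimal usage sketch (illustrative, not taken from this commit's tests):

    from dirty_equals import IsList

    # The positional form matches element by element; length= matches any list of that size.
    assert [1, 2, 3] == IsList(1, 2, 3)
    assert [0.1, 0.2, 0.3] == IsList(length=3)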

0 commit comments
