11from collections .abc import Sequence
22from dataclasses import dataclass , field
3- from typing import Any , Literal , cast , overload
3+ from typing import Any , Literal , cast
44
5- from pydantic_ai .embeddings .base import EmbeddingModel , EmbedInputType
6- from pydantic_ai .embeddings .settings import EmbeddingSettings
75from pydantic_ai .exceptions import UnexpectedModelBehavior
8- from pydantic_ai .providers import infer_provider
6+ from pydantic_ai .providers import Provider , infer_provider
7+ from pydantic_ai .usage import RequestUsage
8+
9+ from .base import EmbeddingModel , EmbedInputType
10+ from .result import EmbeddingResult
11+ from .settings import EmbeddingSettings
912
1013try :
14+ from cohere import AsyncClientV2
1115 from cohere .core .request_options import RequestOptions
12- from cohere .v2 .client import EmbedInputType as CohereEmbedInputType
16+ from cohere .types .embed_by_type_response import EmbedByTypeResponse
17+ from cohere .types .embed_input_type import EmbedInputType as CohereEmbedInputType
1318 from cohere .v2 .types .v2embed_request_truncate import V2EmbedRequestTruncate
1419
1520 from pydantic_ai .providers .cohere import CohereProvider
@@ -73,7 +78,7 @@ def __init__(
7378 self ,
7479 model_name : CohereEmbeddingModelName ,
7580 * ,
76- provider : Literal ['cohere' ] | CohereProvider = 'cohere' ,
81+ provider : Literal ['cohere' ] | Provider [ AsyncClientV2 ] | CohereProvider = 'cohere' ,
7782 settings : EmbeddingSettings | None = None ,
7883 ):
7984 """Initialize an Cohere model.
@@ -92,7 +97,7 @@ def __init__(
9297 provider = infer_provider (provider )
9398 self ._provider = provider
9499 self ._client = provider .client
95- self ._v1_client = provider .v1_client
100+ self ._v1_client = provider .v1_client if isinstance ( provider , CohereProvider ) else None
96101
97102 super ().__init__ (settings = settings )
98103
@@ -111,28 +116,15 @@ def system(self) -> str:
111116 """The embedding model provider."""
112117 return self ._provider .name
113118
114- @overload
115- async def embed (
116- self , documents : str , * , input_type : EmbedInputType , settings : EmbeddingSettings | None = None
117- ) -> list [float ]:
118- pass
119-
120- @overload
121- async def embed (
122- self , documents : Sequence [str ], * , input_type : EmbedInputType , settings : EmbeddingSettings | None = None
123- ) -> list [list [float ]]:
124- pass
125-
126119 async def embed (
127- self , documents : Sequence [str ], * , input_type : EmbedInputType , settings : EmbeddingSettings | None = None
128- ) -> list [float ] | list [list [float ]]:
129- documents , is_single_document , settings = self .prepare_embed (documents , settings )
130- embeddings = await self ._embed (documents , input_type , cast (CohereEmbeddingSettings , settings ))
131- return embeddings [0 ] if is_single_document else embeddings
120+ self , documents : str | Sequence [str ], * , input_type : EmbedInputType , settings : EmbeddingSettings | None = None
121+ ) -> EmbeddingResult :
122+ documents , settings = self .prepare_embed (documents , settings )
123+ return await self ._embed (documents , input_type , cast (CohereEmbeddingSettings , settings ))
132124
133125 async def _embed (
134- self , documents : Sequence [str ], input_type : EmbedInputType , settings : CohereEmbeddingSettings
135- ) -> list [ list [ float ]] :
126+ self , documents : str | Sequence [str ], input_type : EmbedInputType , settings : CohereEmbeddingSettings
127+ ) -> EmbeddingResult :
136128 request_options = RequestOptions ()
137129 if extra_headers := settings .get ('extra_headers' ):
138130 request_options ['additional_headers' ] = extra_headers
@@ -156,10 +148,18 @@ async def _embed(
156148 if embeddings is None :
157149 raise UnexpectedModelBehavior (
158150 'The Cohere embeddings response did not have an `embeddings` field holding a list of floats' ,
159- str ( response . data ) ,
151+ response ,
160152 )
161153
162- return embeddings
154+ return EmbeddingResult (
155+ embeddings = embeddings ,
156+ inputs = documents ,
157+ input_type = input_type ,
158+ usage = _map_usage (response ),
159+ model_name = self .model_name ,
160+ provider_name = self .system ,
161+ provider_response_id = response .id ,
162+ )
163163
164164 async def max_input_tokens (self ) -> int | None :
165165 return _MAX_INPUT_TOKENS .get (self .model_name )
@@ -173,3 +173,17 @@ async def count_tokens(self, text: str) -> int:
173173 offline = False ,
174174 )
175175 return len (result .tokens )
176+
177+
178+ def _map_usage (response : EmbedByTypeResponse ) -> RequestUsage :
179+ u = response .meta
180+ if u is None or u .billed_units is None :
181+ return RequestUsage ()
182+ usage_data = u .billed_units .model_dump (exclude_none = True )
183+ details = {k : int (v ) for k , v in usage_data .items () if k != 'input_tokens' and isinstance (v , int | float ) and v > 0 }
184+
185+ # TODO (DouweM): Use RequestUsage.extract() once https://github.com/pydantic/genai-prices/blob/main/prices/providers/cohere.yml has been updated
186+ return RequestUsage (
187+ input_tokens = int (u .billed_units .input_tokens or 0 ),
188+ details = details ,
189+ )
0 commit comments