Skip to content

Commit 69508ec

Browse files
authored
feat(google): Add embedding support. (#68)
1 parent 210531b commit 69508ec

File tree

4 files changed

+56
-6
lines changed

4 files changed

+56
-6
lines changed

src/any_llm/providers/google/google.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,12 +14,18 @@
1414
from openai._streaming import Stream
1515
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
1616
from openai.types.chat.chat_completion import ChatCompletion
17+
from openai.types import CreateEmbeddingResponse
1718
from any_llm.provider import Provider, ApiConfig
1819
from any_llm.exceptions import MissingApiKeyError, UnsupportedParameterError
1920
from any_llm.providers.helpers import (
2021
create_completion_from_response,
2122
)
22-
from any_llm.providers.google.utils import _convert_tool_spec, _convert_messages, _create_openai_chunk_from_google_chunk
23+
from any_llm.providers.google.utils import (
24+
_convert_tool_spec,
25+
_convert_messages,
26+
_create_openai_chunk_from_google_chunk,
27+
_create_openai_embedding_response_from_google,
28+
)
2329

2430

2531
class GoogleProvider(Provider):
@@ -29,7 +35,7 @@ class GoogleProvider(Provider):
2935
PROVIDER_DOCUMENTATION_URL = "https://cloud.google.com/vertex-ai/docs"
3036

3137
SUPPORTS_STREAMING = True
32-
SUPPORTS_EMBEDDING = False
38+
SUPPORTS_EMBEDDING = True
3339

3440
def __init__(self, config: ApiConfig) -> None:
3541
"""Initialize Google GenAI provider."""
@@ -51,6 +57,20 @@ def __init__(self, config: ApiConfig) -> None:
5157

5258
self.client = genai.Client(api_key=api_key)
5359

60+
def embedding(
    self,
    model: str,
    inputs: str | list[str],
    **kwargs: Any,
) -> CreateEmbeddingResponse:
    """Create embeddings for ``inputs`` with a Google embedding model.

    Args:
        model: Name of the Google embedding model to use.
        inputs: A single text or a list of texts to embed.
        **kwargs: Extra options forwarded verbatim to ``embed_content``.

    Returns:
        An OpenAI-compatible ``CreateEmbeddingResponse``.
    """
    # The SDK's stubs are narrower than its runtime behavior for
    # `contents`, hence the targeted type-ignore.
    response = self.client.models.embed_content(
        model=model,
        contents=inputs,  # type: ignore[arg-type]
        **kwargs,
    )
    return _create_openai_embedding_response_from_google(model, response)
73+
5474
def verify_kwargs(self, kwargs: dict[str, Any]) -> None:
5575
"""Verify the kwargs for the Google provider."""
5676
if kwargs.get("stream", False) and kwargs.get("response_format", None) is not None:

src/any_llm/providers/google/utils.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@
22
from time import time
33
from typing import Any
44

5+
from openai.types import CreateEmbeddingResponse
6+
from openai.types.embedding import Embedding
7+
from openai.types.create_embedding_response import Usage
58
from openai.types.chat.chat_completion_chunk import (
69
ChatCompletionChunk,
710
Choice,
@@ -92,6 +95,32 @@ def _convert_messages(messages: list[dict[str, Any]]) -> list[types.Content]:
9295
return formatted_messages
9396

9497

98+
def _create_openai_embedding_response_from_google(
    model: str, result: types.EmbedContentResponse
) -> CreateEmbeddingResponse:
    """Convert a Google embedding response to an OpenAI-compatible format."""
    embeddings: list[Embedding] = []
    for index, item in enumerate(result.embeddings or []):
        # Skip entries that carry no vector data; the original input
        # position is kept as `index` so callers can map results back.
        if not item.values:
            continue
        embeddings.append(
            Embedding(
                embedding=item.values,
                index=index,
                object="embedding",
            )
        )

    # Google does not provide usage data in the embedding response,
    # so token counts are reported as zero.
    return CreateEmbeddingResponse(
        data=embeddings,
        model=model,
        object="list",
        usage=Usage(prompt_tokens=0, total_tokens=0),
    )
122+
123+
95124
def _create_openai_chunk_from_google_chunk(
96125
response: types.GenerateContentResponse,
97126
) -> ChatCompletionChunk:

tests/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def embedding_provider_model_map() -> dict[ProviderName, str]:
4646
ProviderName.AWS: "amazon.titan-embed-text-v2:0",
4747
ProviderName.OLLAMA: "llama3.2:1b",
4848
ProviderName.LMSTUDIO: "text-embedding-nomic-embed-text-v1.5",
49+
ProviderName.GOOGLE: "gemini-embedding-001",
4950
}
5051

5152

tests/integration/test_embedding.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,11 +27,11 @@ def test_embedding_providers(provider: ProviderName, embedding_provider_model_ma
2727
if "model" in str(e).lower() or "embedding" in str(e).lower():
2828
pytest.skip(f"{provider.value} embedding model not available: {e}")
2929
raise
30-
# Verify result is a list of floats
3130
assert isinstance(result, CreateEmbeddingResponse)
3231
assert len(result.data) > 0
33-
assert all(isinstance(x.embedding, list) for x in result.data)
34-
# LM Studio follows OpenAI Spec but doesn't output token use
35-
if provider not in ProviderName.LMSTUDIO:
32+
for entry in result.data:
33+
assert all(isinstance(v, float) for v in entry.embedding)
34+
# These providers don't output token use
35+
if provider not in (ProviderName.GOOGLE, ProviderName.LMSTUDIO):
3636
assert result.usage.prompt_tokens > 0
3737
assert result.usage.total_tokens > 0

0 commit comments

Comments (0)