Skip to content

Commit 833a982

Browse files
authored
[ENH] Add google genai embedding function (#5836)
## Description of changes _Summarize the changes made by this PR._ - Improvements & Bug fixes - this pr adds google genai embedding function since the other packages are deprecated now fixes #5177 - New functionality - ... ## Test plan _How are these changes tested?_ - [ ] Tests pass locally with `pytest` for python, `yarn test` for js, `cargo test` for rust ## Migration plan _Are there any migrations, or any forwards/backwards compatibility changes needed in order to make sure this change deploys reliably?_ ## Observability plan _What is the plan to instrument and monitor this change?_ ## Documentation Changes _Are all docstrings for user-facing APIs updated if required? Do we need to make documentation changes in the [docs section](https://github.com/chroma-core/chroma/tree/main/docs/docs.trychroma.com)?_
1 parent dfbb6c2 commit 833a982

File tree

4 files changed

+187
-0
lines changed

4 files changed

+187
-0
lines changed

chromadb/test/ef/test_ef.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_get_builtins_holds() -> None:
3333
"GoogleGenerativeAiEmbeddingFunction",
3434
"GooglePalmEmbeddingFunction",
3535
"GoogleVertexEmbeddingFunction",
36+
"GoogleGenaiEmbeddingFunction",
3637
"HuggingFaceEmbeddingFunction",
3738
"HuggingFaceEmbeddingServer",
3839
"InstructorEmbeddingFunction",

chromadb/utils/embedding_functions/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
GooglePalmEmbeddingFunction,
2424
GoogleGenerativeAiEmbeddingFunction,
2525
GoogleVertexEmbeddingFunction,
26+
GoogleGenaiEmbeddingFunction,
2627
)
2728
from chromadb.utils.embedding_functions.ollama_embedding_function import (
2829
OllamaEmbeddingFunction,
@@ -102,6 +103,7 @@
102103
"GooglePalmEmbeddingFunction",
103104
"GoogleGenerativeAiEmbeddingFunction",
104105
"GoogleVertexEmbeddingFunction",
106+
"GoogleGenaiEmbeddingFunction",
105107
"OllamaEmbeddingFunction",
106108
"InstructorEmbeddingFunction",
107109
"JinaEmbeddingFunction",
@@ -142,6 +144,7 @@ def get_builtins() -> Set[str]:
142144
"google_palm": GooglePalmEmbeddingFunction,
143145
"google_generative_ai": GoogleGenerativeAiEmbeddingFunction,
144146
"google_vertex": GoogleVertexEmbeddingFunction,
147+
"google_genai": GoogleGenaiEmbeddingFunction,
145148
"ollama": OllamaEmbeddingFunction,
146149
"instructor": InstructorEmbeddingFunction,
147150
"jina": JinaEmbeddingFunction,
@@ -265,6 +268,7 @@ def config_to_embedding_function(config: Dict[str, Any]) -> EmbeddingFunction:
265268
"GooglePalmEmbeddingFunction",
266269
"GoogleGenerativeAiEmbeddingFunction",
267270
"GoogleVertexEmbeddingFunction",
271+
"GoogleGenaiEmbeddingFunction",
268272
"OllamaEmbeddingFunction",
269273
"InstructorEmbeddingFunction",
270274
"JinaEmbeddingFunction",

chromadb/utils/embedding_functions/google_embedding_function.py

Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,150 @@
77
import warnings
88

99

10+
class GoogleGenaiEmbeddingFunction(EmbeddingFunction[Documents]):
11+
def __init__(
12+
self,
13+
model_name: str,
14+
vertexai: Optional[bool] = None,
15+
project: Optional[str] = None,
16+
location: Optional[str] = None,
17+
api_key_env_var: str = "GOOGLE_API_KEY",
18+
):
19+
"""
20+
Initialize the GoogleGenaiEmbeddingFunction.
21+
22+
Args:
23+
model_name (str): The name of the model to use for text embeddings.
24+
api_key_env_var (str, optional): Environment variable name that contains your API key for the Google GenAI API.
25+
Defaults to "GOOGLE_API_KEY".
26+
"""
27+
try:
28+
import google.genai as genai
29+
except ImportError:
30+
raise ValueError(
31+
"The google-genai python package is not installed. Please install it with `pip install google-genai`"
32+
)
33+
34+
self.model_name = model_name
35+
self.api_key_env_var = api_key_env_var
36+
self.vertexai = vertexai
37+
self.project = project
38+
self.location = location
39+
self.api_key = os.getenv(self.api_key_env_var)
40+
if not self.api_key:
41+
raise ValueError(
42+
f"The {self.api_key_env_var} environment variable is not set."
43+
)
44+
45+
self.client = genai.Client(
46+
api_key=self.api_key, vertexai=vertexai, project=project, location=location
47+
)
48+
49+
def __call__(self, input: Documents) -> Embeddings:
50+
"""
51+
Generate embeddings for the given documents.
52+
53+
Args:
54+
input: Documents or images to generate embeddings for.
55+
56+
Returns:
57+
Embeddings for the documents.
58+
"""
59+
if not input:
60+
raise ValueError("Input documents cannot be empty")
61+
if not isinstance(input, (list, tuple)):
62+
raise ValueError("Input must be a list or tuple of documents")
63+
if not all(isinstance(doc, str) for doc in input):
64+
raise ValueError("All input documents must be strings")
65+
66+
try:
67+
response = self.client.models.embed_content(
68+
model=self.model_name, contents=input
69+
)
70+
except Exception as e:
71+
raise ValueError(f"Failed to generate embeddings: {str(e)}") from e
72+
73+
# Validate response structure
74+
if not hasattr(response, "embeddings") or not response.embeddings:
75+
raise ValueError("No embeddings returned from the API")
76+
77+
embeddings_list = []
78+
for ce in response.embeddings:
79+
if not hasattr(ce, "values"):
80+
raise ValueError("Malformed embedding response: missing 'values'")
81+
embeddings_list.append(np.array(ce.values, dtype=np.float32))
82+
83+
return cast(Embeddings, embeddings_list)
84+
85+
@staticmethod
86+
def name() -> str:
87+
return "google_genai"
88+
89+
def default_space(self) -> Space:
90+
return "cosine"
91+
92+
def supported_spaces(self) -> List[Space]:
93+
return ["cosine", "l2", "ip"]
94+
95+
@staticmethod
96+
def build_from_config(config: Dict[str, Any]) -> "EmbeddingFunction[Documents]":
97+
model_name = config.get("model_name")
98+
vertexai = config.get("vertexai")
99+
project = config.get("project")
100+
location = config.get("location")
101+
102+
if model_name is None:
103+
raise ValueError("The model name is required.")
104+
105+
return GoogleGenaiEmbeddingFunction(
106+
model_name=model_name,
107+
vertexai=vertexai,
108+
project=project,
109+
location=location,
110+
)
111+
112+
def get_config(self) -> Dict[str, Any]:
113+
return {
114+
"model_name": self.model_name,
115+
"vertexai": self.vertexai,
116+
"project": self.project,
117+
"location": self.location,
118+
}
119+
120+
def validate_config_update(
121+
self, old_config: Dict[str, Any], new_config: Dict[str, Any]
122+
) -> None:
123+
if "model_name" in new_config:
124+
raise ValueError(
125+
"The model name cannot be changed after the embedding function has been initialized."
126+
)
127+
if "vertexai" in new_config:
128+
raise ValueError(
129+
"The vertexai cannot be changed after the embedding function has been initialized."
130+
)
131+
if "project" in new_config:
132+
raise ValueError(
133+
"The project cannot be changed after the embedding function has been initialized."
134+
)
135+
if "location" in new_config:
136+
raise ValueError(
137+
"The location cannot be changed after the embedding function has been initialized."
138+
)
139+
140+
@staticmethod
141+
def validate_config(config: Dict[str, Any]) -> None:
142+
"""
143+
Validate the configuration using the JSON schema.
144+
145+
Args:
146+
config: Configuration to validate
147+
148+
Raises:
149+
ValidationError: If the configuration does not match the schema
150+
"""
151+
validate_config_schema(config, "google_genai")
152+
153+
10154
class GooglePalmEmbeddingFunction(EmbeddingFunction[Documents]):
11155
"""To use this EmbeddingFunction, you must have the google.generativeai Python package installed and have a PaLM API key."""
12156

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
{
2+
"$schema": "http://json-schema.org/draft-07/schema#",
3+
"title": "Google GenAI Embedding Function Schema",
4+
"description": "Schema for the Google GenAI embedding function configuration",
5+
"version": "1.0.0",
6+
"type": "object",
7+
"properties": {
8+
"model_name": {
9+
"type": "string",
10+
"description": "The name of the model to use for text embeddings"
11+
},
12+
"vertexai": {
13+
"type": [
14+
"boolean",
15+
"null"
16+
],
17+
"description": "Whether to use Vertex AI"
18+
},
19+
"project": {
20+
"type": [
21+
"string",
22+
"null"
23+
],
24+
"description": "The Google Cloud project ID (required for Vertex AI)"
25+
},
26+
"location": {
27+
"type": [
28+
"string",
29+
"null"
30+
],
31+
"description": "The Google Cloud location/region (required for Vertex AI)"
32+
}
33+
},
34+
"required": [
35+
"model_name"
36+
],
37+
"additionalProperties": false
38+
}

0 commit comments

Comments
 (0)