Skip to content

Commit 9555678

Browse files
Enhance Google Gemini integration
- Added support for Google Gemini embeddings and LLMs in the README. - Updated example scripts to assert the presence of the GOOGLE_API_KEY environment variable. - Modified the default embedding model and dimension to use environment variables for better configurability. - Improved error handling in the GeminiEmbedder class to ensure valid responses from the API.
1 parent 5dfc9d5 commit 9555678

File tree

4 files changed

+15
-9
lines changed

4 files changed

+15
-9
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ the extra dependencies described below:
5050
- **ollama**: LLMs from Ollama
5151
- **openai**: LLMs from OpenAI (including AzureOpenAI)
5252
- **google**: LLMs from Vertex AI
53+
- **google-genai**: LLMs and embeddings from Google Gemini
5354
- **cohere**: LLMs from Cohere
5455
- **anthropic**: LLMs from Anthropic
5556
- **mistralai**: LLMs from MistralAI

examples/customize/embeddings/google_genai_embeddings.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import os
2+
13
from neo4j_graphrag.embeddings import GeminiEmbedder
24

3-
# set api key here on in the GOOGLE_API_KEY env var
4-
api_key = None
5+
api_key = os.getenv("GOOGLE_API_KEY")
6+
assert api_key is not None, "you must set GOOGLE_API_KEY to run this experiment"
57

68
embedder = GeminiEmbedder(
79
model="gemini-embedding-001",

examples/customize/llms/google_genai_llm.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import os
2+
13
from neo4j_graphrag.llm import GeminiLLM
24

3-
# set api key here on in the GOOGLE_API_KEY env var
4-
api_key = None
5+
api_key = os.getenv("GOOGLE_API_KEY")
6+
assert api_key is not None, "you must set GOOGLE_API_KEY to run this experiment"
57

68
llm = GeminiLLM(
79
model_name="gemini-2.5-flash",

src/neo4j_graphrag/embeddings/google_genai.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
# built-in dependencies
18+
import os
1819
from typing import Any, Optional
1920

2021
# project dependencies
@@ -33,15 +34,15 @@
3334
genai = None
3435
types = None
3536

36-
DEFAULT_EMBEDDING_MODEL = "text-embedding-004"
37-
DEFAULT_EMBEDDING_DIM = 768
37+
DEFAULT_EMBEDDING_MODEL = os.getenv("GOOGLE_GENAI_EMB_MODEL", "gemini-embedding-001")
38+
DEFAULT_EMBEDDING_DIM = int(os.getenv("GOOGLE_GENAI_EMB_DIM", "768"))
3839

3940

4041
class GeminiEmbedder(Embedder):
4142
"""Embedder that uses Google's Gemini API via the google.genai SDK.
4243
4344
Args:
44-
model: Embedding model name. Defaults to "text-embedding-004".
45+
model: Embedding model name. Defaults to "gemini-embedding-001".
4546
embedding_dim: Output dimensionality. Defaults to 768.
4647
rate_limit_handler: Optional rate limit handler.
4748
**kwargs: Arguments passed to the genai.Client.
@@ -75,7 +76,7 @@ def embed_query(self, text: str, **kwargs: Any) -> list[float]:
7576
),
7677
**kwargs,
7778
)
78-
if not result.embeddings or not result.embeddings[0].values:
79+
if not result or not result.embeddings or not result.embeddings[0].values:
7980
raise ValueError("No embeddings returned from Gemini API")
8081
return list(result.embeddings[0].values)
8182
except Exception as e:
@@ -94,7 +95,7 @@ async def async_embed_query(self, text: str, **kwargs: Any) -> list[float]:
9495
),
9596
**kwargs,
9697
)
97-
if not result.embeddings or not result.embeddings[0].values:
98+
if not result or not result.embeddings or not result.embeddings[0].values:
9899
raise ValueError("No embeddings returned from Gemini API")
99100
return list(result.embeddings[0].values)
100101
except Exception as e:

0 commit comments

Comments
 (0)