Commit 386d8b8

Fix: Migrate Gemini Embeddings (mem0ai#3002)
Co-authored-by: Dev-Khant <[email protected]>
1 parent: c173ec3

File tree

5 files changed: +124 additions, -71 deletions

docs/components/embedders/models/gemini.mdx

Lines changed: 1 addition & 1 deletion

@@ -39,5 +39,5 @@ Here are the parameters available for configuring Gemini embedder:
 | Parameter | Description | Default Value |
 | --- | --- | --- |
 | `model` | The name of the embedding model to use | `models/text-embedding-004` |
-| `embedding_dims` | Dimensions of the embedding model | `768` |
+| `embedding_dims` | Dimensions of the embedding model (passed to the API as `output_dimensionality`, so set `embedding_dims` accordingly) | `768` |
 | `api_key` | The Gemini API key | `None` |
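
For context, a minimal sketch of how these documented parameters are typically wired into a mem0 config; the values here are illustrative, and the `Memory.from_config` pattern follows mem0's usual embedder setup:

```python
from mem0 import Memory

# GOOGLE_API_KEY is read from the environment when api_key is not supplied
# (see the fallback in mem0/embeddings/gemini.py below).
config = {
    "embedder": {
        "provider": "gemini",
        "config": {
            "model": "models/text-embedding-004",
            # Forwarded to the API as output_dimensionality, per the table above.
            "embedding_dims": 768,
        },
    }
}

m = Memory.from_config(config)
m.add("I prefer window seats on long flights.", user_id="alice")
```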

mem0/embeddings/gemini.py

Lines changed: 10 additions & 5 deletions

@@ -1,7 +1,7 @@
 import os
 from typing import Literal, Optional
 
-import google.generativeai as genai
+import google.genai as genai
 
 from mem0.configs.embeddings.base import BaseEmbedderConfig
 from mem0.embeddings.base import EmbeddingBase
@@ -12,23 +12,28 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
         super().__init__(config)
 
         self.config.model = self.config.model or "models/text-embedding-004"
-        self.config.embedding_dims = self.config.embedding_dims or 768
+        self.config.embedding_dims = self.config.embedding_dims or self.config.output_dimensionality or 768
 
         api_key = self.config.api_key or os.getenv("GOOGLE_API_KEY")
 
-        genai.configure(api_key=api_key)
+        if api_key:
+            self.client = genai.Client(api_key=api_key)
+        else:
+            self.client = genai.Client()
 
     def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
         """
         Get the embedding for the given text using Google Generative AI.
         Args:
             text (str): The text to embed.
-            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.
+            memory_action (optional): The type of embedding to use. (Currently not used by Gemini for task_type)
         Returns:
             list: The embedding vector.
         """
         text = text.replace("\n", " ")
-        response = genai.embed_content(
+
+        response = self.client.models.embed_content(
             model=self.config.model, content=text, output_dimensionality=self.config.embedding_dims
         )
+
         return response["embedding"]
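
The migration replaces module-level `google.generativeai` calls with a `google.genai` `Client`. As a reference point, a standalone embedding call against the new SDK looks roughly like this; it is a sketch, and note that current google-genai releases take `contents` plus an `EmbedContentConfig`, so keyword shapes may differ slightly from the call in the diff above:

```python
from google import genai
from google.genai import types

# With no api_key argument, the client reads GOOGLE_API_KEY from the
# environment, mirroring the fallback branch in __init__ above.
client = genai.Client()

response = client.models.embed_content(
    model="models/text-embedding-004",
    contents="Hello, world!",
    config=types.EmbedContentConfig(output_dimensionality=768),
)

vector = response.embeddings[0].values  # list of 768 floats
```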

mem0/llms/gemini.py

Lines changed: 21 additions & 34 deletions

@@ -4,7 +4,7 @@
 try:
     from google import genai
     from google.genai import types
-
+
 except ImportError:
     raise ImportError(
         "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
@@ -49,16 +49,17 @@ def _parse_response(self, response, tools):
             for part in candidate.content.parts:
                 fn = getattr(part, "function_call", None)
                 if fn:
-                    processed_response["tool_calls"].append({
-                        "name": fn.name,
-                        "arguments": fn.args,
-                    })
+                    processed_response["tool_calls"].append(
+                        {
+                            "name": fn.name,
+                            "arguments": fn.args,
+                        }
+                    )
 
             return processed_response
 
         return content
 
-
     def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
@@ -78,15 +79,11 @@ def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
             content = message["content"]
 
             new_messages.append(
-                types.Content(
-                    role="model" if message["role"] == "model" else "user",
-                    parts=[types.Part(text=content)]
-                )
+                types.Content(role="model" if message["role"] == "model" else "user", parts=[types.Part(text=content)])
             )
 
         return new_messages
 
-
     def _reformat_tools(self, tools: Optional[List[Dict]]):
@@ -131,7 +128,6 @@ def generate_response(
         tools: Optional[List[Dict]] = None,
         tool_choice: str = "auto",
     ):
-
        """
        Generate a response based on the given messages using Gemini.
 
@@ -161,31 +157,22 @@
             tool_config = types.ToolConfig(
                 function_calling_config=types.FunctionCallingConfig(
                     mode=tool_choice.upper(),  # Assuming 'any' should become 'ANY', etc.
-                    allowed_function_names=[
-                        tool["function"]["name"] for tool in tools
-                    ] if tool_choice == "any" else None
+                    allowed_function_names=[tool["function"]["name"] for tool in tools]
+                    if tool_choice == "any"
+                    else None,
                 )
             )
 
-        print(f"Tool config: {tool_config}")
-        print(f"Params: {params}" )
-        print(f"Messages: {messages}")
-        print(f"Tools: {tools}")
-        print(f"Reformatted messages: {self._reformat_messages(messages)}")
-        print(f"Reformatted tools: {self._reformat_tools(tools)}")
-
         response = self.client_gemini.models.generate_content(
-            model=self.config.model,
-            contents=self._reformat_messages(messages),
-            config=types.GenerateContentConfig(
-                temperature= self.config.temperature,
-                max_output_tokens= self.config.max_tokens,
-                top_p= self.config.top_p,
-                tools=self._reformat_tools(tools),
-                tool_config=tool_config,
-
-            ),
-        )
-        print(f"Response test: {response}")
+            model=self.config.model,
+            contents=self._reformat_messages(messages),
+            config=types.GenerateContentConfig(
+                temperature=self.config.temperature,
+                max_output_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
+                tools=self._reformat_tools(tools),
+                tool_config=tool_config,
+            ),
+        )
 
         return self._parse_response(response, tools)
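
To illustrate the call shape the refactored `generate_response` now builds, here is a hedged standalone sketch of google-genai function calling; the model name and schema are illustrative stand-ins, since mem0 derives both from its own config and tool definitions:

```python
from google import genai
from google.genai import types

client = genai.Client()  # reads GOOGLE_API_KEY from the environment

add_memory_tool = types.Tool(
    function_declarations=[
        types.FunctionDeclaration(
            name="add_memory",
            description="Add a memory",
            parameters={
                "type": "object",
                "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
                "required": ["data"],
            },
        )
    ]
)

response = client.models.generate_content(
    model="gemini-2.0-flash",  # illustrative model name
    contents="Add a new memory: Today is a sunny day.",
    config=types.GenerateContentConfig(
        temperature=0.7,
        tools=[add_memory_tool],
        tool_config=types.ToolConfig(
            function_calling_config=types.FunctionCallingConfig(mode="AUTO")
        ),
    ),
)

# _parse_response above walks candidates[0].content.parts for function_call parts.
for part in response.candidates[0].content.parts:
    if getattr(part, "function_call", None):
        print(part.function_call.name, dict(part.function_call.args))
```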

tests/embeddings/test_gemini.py renamed to tests/embeddings/test_gemini_emeddings.py

Lines changed: 26 additions & 0 deletions

@@ -28,3 +28,29 @@ def test_embed_query(mock_genai, config):
 
     assert embedding == [0.1, 0.2, 0.3, 0.4]
     mock_genai.assert_called_once_with(model="test_model", content="Hello, world!", output_dimensionality=786)
+
+def test_embed_returns_empty_list_if_none(mock_genai, config):
+    mock_genai.return_value = None
+
+    embedder = GoogleGenAIEmbedding(config)
+    result = embedder.embed("test")
+
+    assert result == []
+    mock_genai.assert_called_once()
+
+
+def test_embed_raises_on_error(mock_genai, config):
+    mock_genai.side_effect = RuntimeError("Embedding failed")
+
+    embedder = GoogleGenAIEmbedding(config)
+
+    with pytest.raises(RuntimeError, match="Embedding failed"):
+        embedder.embed("some input")
+
+def test_config_initialization(config):
+    embedder = GoogleGenAIEmbedding(config)
+
+    assert embedder.config.api_key == "dummy_api_key"
+    assert embedder.config.model == "test_model"
+    assert embedder.config.embedding_dims == 786
+
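
These tests lean on two fixtures, `mock_genai` and `config`, defined outside this hunk. One plausible shape for them, consistent with the assertions above; the patch target and return value here are assumptions, since the diff does not show the fixtures themselves:

```python
from unittest.mock import patch

import pytest

from mem0.configs.embeddings.base import BaseEmbedderConfig


@pytest.fixture
def mock_genai():
    # Patch the Client class inside the embedder module so that
    # self.client.models.embed_content resolves to this mock.
    with patch("mem0.embeddings.gemini.genai.Client") as mock_client:
        mock_embed = mock_client.return_value.models.embed_content
        # Shape matches `return response["embedding"]` in embed() above.
        mock_embed.return_value = {"embedding": [0.1, 0.2, 0.3, 0.4]}
        yield mock_embed


@pytest.fixture
def config():
    return BaseEmbedderConfig(api_key="dummy_api_key", model="test_model", embedding_dims=786)
```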

tests/llms/test_gemini_llm.py renamed to tests/llms/test_gemini.py

Lines changed: 66 additions & 31 deletions

@@ -1,30 +1,38 @@
 from unittest.mock import Mock, patch
 
 import pytest
-from google.generativeai import GenerationConfig
-from google.generativeai.types import content_types
+from google.genai import types
 
 from mem0.configs.llms.base import BaseLlmConfig
 from mem0.llms.gemini import GeminiLLM
 
 
 @pytest.fixture
 def mock_gemini_client():
-    with patch("mem0.llms.gemini.GenerativeModel") as mock_gemini:
+    with patch("mem0.llms.gemini.genai") as mock_client_class:
         mock_client = Mock()
-        mock_gemini.return_value = mock_client
+        mock_client_class.return_value = mock_client
         yield mock_client
 
 
 def test_generate_response_without_tools(mock_gemini_client: Mock):
-    config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
+    config = BaseLlmConfig(model="gemini-2.0-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GeminiLLM(config)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": "Hello, how are you?"},
     ]
 
     mock_part = Mock(text="I'm doing well, thank you for asking!")
+    mock_embedding = Mock()
+    mock_embedding.values = [0.1, 0.2, 0.3]
+
+    mock_response = Mock()
+    mock_response.candidates = [Mock()]
+    mock_response.candidates[0].content.parts = [Mock()]
+    mock_response.candidates[0].content.parts[0].text = "I'm doing well, thank you for asking!"
+
+    mock_gemini_client.models.generate_content.return_value = mock_response
     mock_content = Mock(parts=[mock_part])
     mock_message = Mock(content=mock_content)
     mock_response = Mock(candidates=[mock_message])
@@ -37,15 +45,24 @@ def test_generate_response_without_tools(mock_gemini_client: Mock):
             {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
             {"parts": "Hello, how are you?", "role": "user"},
         ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=None,
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
-    )
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=None,
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto",
+                )
+            ),
+        ),
+    )
+
     assert response == "I'm doing well, thank you for asking!"
 
 
+
 def test_generate_response_with_tools(mock_gemini_client: Mock):
     config = BaseLlmConfig(model="gemini-1.5-flash-latest", temperature=0.7, max_tokens=100, top_p=1.0)
     llm = GeminiLLM(config)
@@ -89,28 +106,46 @@ def test_generate_response_with_tools(mock_gemini_client: Mock):
 
     mock_gemini_client.generate_content.assert_called_once_with(
         contents=[
-            {"parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.", "role": "user"},
-            {"parts": "Add a new memory: Today is a sunny day.", "role": "user"},
-        ],
-        generation_config=GenerationConfig(temperature=0.7, max_output_tokens=100, top_p=1.0),
-        tools=[
             {
-                "function_declarations": [
-                    {
-                        "name": "add_memory",
-                        "description": "Add a memory",
-                        "parameters": {
-                            "type": "object",
-                            "properties": {"data": {"type": "string", "description": "Data to add to memory"}},
-                            "required": ["data"],
-                        },
-                    }
-                ]
-            }
+                "parts": "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: You are a helpful assistant.",
+                "role": "user",
+            },
+            {
+                "parts": "Add a new memory: Today is a sunny day.",
+                "role": "user",
+            },
         ],
-        tool_config=content_types.to_tool_config(
-            {"function_calling_config": {"mode": "auto", "allowed_function_names": None}}
-        ),
+        config=types.GenerateContentConfig(
+            temperature=0.7,
+            max_output_tokens=100,
+            top_p=1.0,
+            tools=[
+                types.Tool(
+                    function_declarations=[
+                        types.FunctionDeclaration(
+                            name="add_memory",
+                            description="Add a memory",
+                            parameters={
+                                "type": "object",
+                                "properties": {
+                                    "data": {
+                                        "type": "string",
+                                        "description": "Data to add to memory",
+                                    }
+                                },
+                                "required": ["data"],
+                            },
+                        )
+                    ]
+                )
+            ],
+            tool_config=types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    allowed_function_names=None,
+                    mode="auto",
+                )
+            ),
+        ),
     )
 
     assert response["content"] == "I've added the memory for you."
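
For reference, the dict shape `_parse_response` returns when tools are supplied, inferred from the test assertions and the `tool_calls` append in mem0/llms/gemini.py above; the argument values are illustrative for this test scenario:

```python
# Inferred result of llm.generate_response(...) with tools, per the mocked test:
processed_response = {
    "content": "I've added the memory for you.",
    "tool_calls": [
        {"name": "add_memory", "arguments": {"data": "Today is a sunny day."}},
    ],
}
```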
