
Commit c70dc76

Fix: Add Google Genai library support (mem0ai#2941)
1 parent e000324 commit c70dc76

File tree: 7 files changed, +588 −275 lines

docs/components/llms/models/gemini.mdx

Lines changed: 15 additions & 7 deletions

````diff
@@ -4,36 +4,44 @@ title: Gemini
 
 <Snippet file="paper-release.mdx" />
 
-To use Gemini model, you have to set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from the [Google AI Studio](https://aistudio.google.com/app/apikey)
+To use the Gemini model, set the `GEMINI_API_KEY` environment variable. You can obtain the Gemini API key from [Google AI Studio](https://aistudio.google.com/app/apikey).
+
+> **Note:** As of the latest release, Mem0 uses the new `google.genai` SDK instead of the deprecated `google.generativeai`. All message formatting and model interaction now use the updated `types` module from `google.genai`.
+
+> **Note:** Some Gemini models are being deprecated and will retire soon. It is recommended to migrate to the latest stable models like `"gemini-2.0-flash-001"` or `"gemini-2.0-flash-lite-001"` to ensure ongoing support and improvements.
 
 ## Usage
 
 ```python
 import os
 from mem0 import Memory
 
-os.environ["OPENAI_API_KEY"] = "your-api-key" # used for embedding model
-os.environ["GEMINI_API_KEY"] = "your-api-key"
+os.environ["OPENAI_API_KEY"] = "your-openai-api-key" # Used for embedding model
+os.environ["GEMINI_API_KEY"] = "your-gemini-api-key"
 
 config = {
     "llm": {
         "provider": "gemini",
         "config": {
-            "model": "gemini-1.5-flash-latest",
+            "model": "gemini-2.0-flash-001",
             "temperature": 0.2,
             "max_tokens": 2000,
+            "top_p": 1.0
         }
     }
 }
 
 m = Memory.from_config(config)
+
 messages = [
     {"role": "user", "content": "I'm planning to watch a movie tonight. Any recommendations?"},
-    {"role": "assistant", "content": "How about a thriller movies? They can be quite engaging."},
-    {"role": "user", "content": "I’m not a big fan of thriller movies but I love sci-fi movies."},
-    {"role": "assistant", "content": "Got it! I'll avoid thriller recommendations and suggest sci-fi movies in the future."}
+    {"role": "assistant", "content": "How about thriller movies? They can be quite engaging."},
+    {"role": "user", "content": "I’m not a big fan of thrillers, but I love sci-fi movies."},
+    {"role": "assistant", "content": "Got it! I'll avoid thrillers and suggest sci-fi movies instead."}
 ]
+
 m.add(messages, user_id="alice", metadata={"category": "movies"})
+
 ```
 
 ## Config
````
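For context on the SDK note in this file: the call pattern that `google.genai` introduces looks roughly like the sketch below. This is a minimal illustration, not part of the commit; it assumes `pip install google-genai` and a valid `GEMINI_API_KEY`, and the model name and prompt are only examples.

```python
# Minimal sketch (not part of the commit) of the google.genai pattern
# the docs note describes: a Client plus the types module for messages.
import os

from google import genai
from google.genai import types

client = genai.Client(api_key=os.environ["GEMINI_API_KEY"])

response = client.models.generate_content(
    model="gemini-2.0-flash-001",  # example model from the docs note
    contents=[types.Content(role="user", parts=[types.Part(text="Say hello.")])],
    config=types.GenerateContentConfig(temperature=0.2, max_output_tokens=100),
)

# The first candidate's first part carries the generated text.
print(response.candidates[0].content.parts[0].text)
```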

docs/open-source/graph_memory/overview.mdx

Lines changed: 74 additions & 4 deletions

````diff
@@ -238,16 +238,24 @@ The Mem0's graph supports the following operations:
 ### Add Memories
 
 <Note>
-If you are using Mem0 with Graph Memory, it is recommended to pass `user_id`. Use `userId` in NodeSDK.
+Mem0 with Graph Memory supports both `user_id` and `agent_id` parameters. You can use either or both to organize your memories. Use `userId` and `agentId` in NodeSDK.
 </Note>
 
 <CodeGroup>
 ```python Python
+# Using only user_id
 m.add("I like pizza", user_id="alice")
+
+# Using both user_id and agent_id
+m.add("I like pizza", user_id="alice", agent_id="food-assistant")
 ```
 
 ```typescript TypeScript
+// Using only userId
 memory.add("I like pizza", { userId: "alice" });
+
+// Using both userId and agentId
+memory.add("I like pizza", { userId: "alice", agentId: "food-assistant" });
 ```
 
 ```json Output
@@ -260,11 +268,19 @@ memory.add("I like pizza", { userId: "alice" });
 
 <CodeGroup>
 ```python Python
+# Get all memories for a user
 m.get_all(user_id="alice")
+
+# Get all memories for a specific agent belonging to a user
+m.get_all(user_id="alice", agent_id="food-assistant")
 ```
 
 ```typescript TypeScript
+// Get all memories for a user
 memory.getAll({ userId: "alice" });
+
+// Get all memories for a specific agent belonging to a user
+memory.getAll({ userId: "alice", agentId: "food-assistant" });
 ```
 
 ```json Output
@@ -277,7 +293,8 @@ memory.getAll({ userId: "alice" });
          'metadata': None,
          'created_at': '2024-08-20T14:09:27.588719-07:00',
          'updated_at': None,
-         'user_id': 'alice'
+         'user_id': 'alice',
+         'agent_id': 'food-assistant'
       }
    ],
    'entities': [
@@ -295,11 +312,19 @@ memory.getAll({ userId: "alice" });
 
 <CodeGroup>
 ```python Python
+# Search memories for a user
 m.search("tell me my name.", user_id="alice")
+
+# Search memories for a specific agent belonging to a user
+m.search("tell me my name.", user_id="alice", agent_id="food-assistant")
 ```
 
 ```typescript TypeScript
+// Search memories for a user
 memory.search("tell me my name.", { userId: "alice" });
+
+// Search memories for a specific agent belonging to a user
+memory.search("tell me my name.", { userId: "alice", agentId: "food-assistant" });
 ```
 
 ```json Output
@@ -312,7 +337,8 @@ memory.search("tell me my name.", { userId: "alice" });
          'metadata': None,
          'created_at': '2024-08-20T14:09:27.588719-07:00',
          'updated_at': None,
-         'user_id': 'alice'
+         'user_id': 'alice',
+         'agent_id': 'food-assistant'
      }
   ],
   'entities': [
@@ -331,11 +357,19 @@ memory.search("tell me my name.", { userId: "alice" });
 
 <CodeGroup>
 ```python Python
+# Delete all memories for a user
 m.delete_all(user_id="alice")
+
+# Delete all memories for a specific agent belonging to a user
+m.delete_all(user_id="alice", agent_id="food-assistant")
 ```
 
 ```typescript TypeScript
+// Delete all memories for a user
 memory.deleteAll({ userId: "alice" });
+
+// Delete all memories for a specific agent belonging to a user
+memory.deleteAll({ userId: "alice", agentId: "food-assistant" });
 ```
 </CodeGroup>
 
@@ -516,6 +550,42 @@ memory.search("Who is spiderman?", { userId: "alice123" });
 
 > **Note:** The Graph Memory implementation is not standalone. You will be adding/retrieving memories to the vector store and the graph store simultaneously.
 
+## Using Multiple Agents with Graph Memory
+
+When working with multiple agents, you can use the `agent_id` parameter to organize memories by both user and agent. This allows you to:
+
+1. Create agent-specific knowledge graphs
+2. Share common knowledge between agents
+3. Isolate sensitive or specialized information to specific agents
+
+### Example: Multi-Agent Setup
+
+<CodeGroup>
+```python Python
+# Add memories for different agents
+m.add("I prefer Italian cuisine", user_id="bob", agent_id="food-assistant")
+m.add("I'm allergic to peanuts", user_id="bob", agent_id="health-assistant")
+m.add("I live in Seattle", user_id="bob") # Shared across all agents
+
+# Search within specific agent context
+food_preferences = m.search("What food do I like?", user_id="bob", agent_id="food-assistant")
+health_info = m.search("What are my allergies?", user_id="bob", agent_id="health-assistant")
+location = m.search("Where do I live?", user_id="bob") # Searches across all agents
+```
+
+```typescript TypeScript
+// Add memories for different agents
+memory.add("I prefer Italian cuisine", { userId: "bob", agentId: "food-assistant" });
+memory.add("I'm allergic to peanuts", { userId: "bob", agentId: "health-assistant" });
+memory.add("I live in Seattle", { userId: "bob" }); // Shared across all agents
+
+// Search within specific agent context
+const foodPreferences = memory.search("What food do I like?", { userId: "bob", agentId: "food-assistant" });
+const healthInfo = memory.search("What are my allergies?", { userId: "bob", agentId: "health-assistant" });
+const location = memory.search("Where do I live?", { userId: "bob" }); // Searches across all agents
+```
+</CodeGroup>
+
 If you want to use a managed version of Mem0, please check out [Mem0](https://mem0.dev/pd). If you have any questions, please feel free to reach out to us using one of the following methods:
 
-<Snippet file="get-help.mdx" />
+<Snippet file="get-help.mdx" />
````
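The note in this file that Graph Memory is not standalone can be made concrete with a small sketch (not part of the commit): enabling a `graph_store` alongside the default vector store means a single `add` writes to both. The connection values below are placeholders, and the Neo4j provider is just the example these docs use elsewhere.

```python
# Sketch (not part of the commit): Graph Memory runs alongside the vector
# store, so one add() call writes to both. Credentials are placeholders.
from mem0 import Memory

config = {
    "graph_store": {
        "provider": "neo4j",
        "config": {
            "url": "neo4j+s://your-instance",
            "username": "neo4j",
            "password": "your-password",
        },
    },
}

m = Memory.from_config(config_dict=config)

# Stored in the vector store and the graph store; agent_id scopes both.
m.add("I like pizza", user_id="alice", agent_id="food-assistant")
```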

embedchain/poetry.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default.

mem0/llms/gemini.py

Lines changed: 59 additions & 39 deletions

````diff
@@ -2,9 +2,9 @@
 from typing import Dict, List, Optional
 
 try:
-    import google.generativeai as genai
-    from google.generativeai import GenerativeModel, protos
-    from google.generativeai.types import content_types
+    from google import genai
+    from google.genai import types
+
 except ImportError:
     raise ImportError(
         "The 'google-generativeai' library is required. Please install it using 'pip install google-generativeai'."
@@ -22,66 +22,71 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
             self.config.model = "gemini-1.5-flash-latest"
 
         api_key = self.config.api_key or os.getenv("GEMINI_API_KEY")
-        genai.configure(api_key=api_key)
-        self.client = GenerativeModel(model_name=self.config.model)
+        self.client_gemini = genai.Client(
+            api_key=api_key,
+        )
 
     def _parse_response(self, response, tools):
         """
         Process the response based on whether tools are used or not.
 
         Args:
-            response: The raw response from API.
+            response: The raw response from the API.
             tools: The list of tools provided in the request.
 
         Returns:
             str or dict: The processed response.
         """
+        candidate = response.candidates[0]
+        content = candidate.content.parts[0].text if candidate.content.parts else None
+
         if tools:
             processed_response = {
-                "content": (content if (content := response.candidates[0].content.parts[0].text) else None),
+                "content": content,
                 "tool_calls": [],
             }
 
-            for part in response.candidates[0].content.parts:
-                if fn := part.function_call:
-                    if isinstance(fn, protos.FunctionCall):
-                        fn_call = type(fn).to_dict(fn)
-                        processed_response["tool_calls"].append({"name": fn_call["name"], "arguments": fn_call["args"]})
-                        continue
-                    processed_response["tool_calls"].append({"name": fn.name, "arguments": fn.args})
+            for part in candidate.content.parts:
+                fn = getattr(part, "function_call", None)
+                if fn:
+                    processed_response["tool_calls"].append({
+                        "name": fn.name,
+                        "arguments": fn.args,
+                    })
 
             return processed_response
-        else:
-            return response.candidates[0].content.parts[0].text
 
-    def _reformat_messages(self, messages: List[Dict[str, str]]):
+        return content
+
+
+    def _reformat_messages(self, messages: List[Dict[str, str]]) -> List[types.Content]:
         """
-        Reformat messages for Gemini.
+        Reformat messages for Gemini using google.genai.types.
 
         Args:
             messages: The list of messages provided in the request.
 
         Returns:
-            list: The list of messages in the required format.
+            list: A list of types.Content objects with proper role and parts.
         """
         new_messages = []
 
         for message in messages:
             if message["role"] == "system":
                 content = "THIS IS A SYSTEM PROMPT. YOU MUST OBEY THIS: " + message["content"]
-
             else:
                 content = message["content"]
 
             new_messages.append(
-                {
-                    "parts": content,
-                    "role": "model" if message["role"] == "model" else "user",
-                }
+                types.Content(
+                    role="model" if message["role"] == "model" else "user",
+                    parts=[types.Part(text=content)]
+                )
             )
 
         return new_messages
 
+
     def _reformat_tools(self, tools: Optional[List[Dict]]):
         """
         Reformat tools for Gemini.
@@ -126,6 +131,7 @@ def generate_response(
         tools: Optional[List[Dict]] = None,
         tool_choice: str = "auto",
     ):
+
         """
         Generate a response based on the given messages using Gemini.
 
@@ -149,23 +155,37 @@ def generate_response(
                 params["response_mime_type"] = "application/json"
             if "schema" in response_format:
                 params["response_schema"] = response_format["schema"]
+
+        tool_config = None
         if tool_choice:
-            tool_config = content_types.to_tool_config(
-                {
-                    "function_calling_config": {
-                        "mode": tool_choice,
-                        "allowed_function_names": (
-                            [tool["function"]["name"] for tool in tools] if tool_choice == "any" else None
-                        ),
-                    }
-                }
+            tool_config = types.ToolConfig(
+                function_calling_config=types.FunctionCallingConfig(
+                    mode=tool_choice.upper(),  # Assuming 'any' should become 'ANY', etc.
+                    allowed_function_names=[
+                        tool["function"]["name"] for tool in tools
+                    ] if tool_choice == "any" else None
+                )
             )
 
-        response = self.client.generate_content(
-            contents=self._reformat_messages(messages),
-            tools=self._reformat_tools(tools),
-            generation_config=genai.GenerationConfig(**params),
-            tool_config=tool_config,
-        )
+        print(f"Tool config: {tool_config}")
+        print(f"Params: {params}")
+        print(f"Messages: {messages}")
+        print(f"Tools: {tools}")
+        print(f"Reformatted messages: {self._reformat_messages(messages)}")
+        print(f"Reformatted tools: {self._reformat_tools(tools)}")
+
+        response = self.client_gemini.models.generate_content(
+            model=self.config.model,
+            contents=self._reformat_messages(messages),
+            config=types.GenerateContentConfig(
+                temperature=self.config.temperature,
+                max_output_tokens=self.config.max_tokens,
+                top_p=self.config.top_p,
+                tools=self._reformat_tools(tools),
+                tool_config=tool_config,
+
+            ),
+        )
+        print(f"Response test: {response}")
 
         return self._parse_response(response, tools)
````
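A hypothetical direct exercise of the refactored class (normally Mem0 constructs it through `Memory.from_config` and its LLM factory) could look like the sketch below. The config values are illustrative, and the `BaseLlmConfig` import path is the one this module itself references.

```python
# Hypothetical direct use of the refactored GeminiLLM (illustrative only;
# Memory.from_config normally builds this via the LLM factory).
import os

from mem0.configs.llms.base import BaseLlmConfig
from mem0.llms.gemini import GeminiLLM

os.environ["GEMINI_API_KEY"] = "your-gemini-api-key"

config = BaseLlmConfig(
    model="gemini-2.0-flash-001",
    temperature=0.2,
    max_tokens=2000,
    top_p=1.0,
)
llm = GeminiLLM(config)

# With no tools supplied, _parse_response returns the plain text content.
reply = llm.generate_response(
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Recommend a sci-fi movie."},
    ]
)
print(reply)
```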
