Skip to content

Commit bc7162e

Browse files
committed
feat: add ockam support
1 parent dc877fd commit bc7162e

File tree

14 files changed

+429
-87
lines changed

14 files changed

+429
-87
lines changed

mem0/configs/embeddings/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ def __init__(
3737
aws_access_key_id: Optional[str] = None,
3838
aws_secret_access_key: Optional[str] = None,
3939
aws_region: Optional[str] = "us-west-2",
40+
# Ockam Model specific
41+
ockam_model: Optional = None,
4042
):
4143
"""
4244
Initializes a configuration class instance for the Embeddings.
@@ -101,3 +103,6 @@ def __init__(
101103
self.aws_access_key_id = aws_access_key_id
102104
self.aws_secret_access_key = aws_secret_access_key
103105
self.aws_region = aws_region
106+
107+
# Ockam Model specific
108+
self.ockam_model = ockam_model

mem0/configs/llms/base.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ def __init__(
5050
aws_access_key_id: Optional[str] = None,
5151
aws_secret_access_key: Optional[str] = None,
5252
aws_region: Optional[str] = "us-west-2",
53+
# Ockam Model specific
54+
ockam_model: Optional = None,
5355
):
5456
"""
5557
Initializes a configuration class instance for the LLM.
@@ -150,3 +152,6 @@ def __init__(
150152
self.aws_access_key_id = aws_access_key_id
151153
self.aws_secret_access_key = aws_secret_access_key
152154
self.aws_region = aws_region
155+
156+
# Ockam Model specific
157+
self.ockam_model = ockam_model
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
from typing import Any, Dict, Optional
2+
3+
from pydantic import BaseModel, Field, model_validator
4+
5+
6+
class InMemoryConfig(BaseModel):
    """Configuration schema for the in-memory vector store."""

    collection_name: str = Field("mem0", description="Default name for the collection")

    @model_validator(mode="before")
    @classmethod
    def validate_extra_fields(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Fail validation when the raw input contains keys that are not declared fields."""
        allowed_fields = set(cls.model_fields)
        extra_fields = {key for key in values if key not in allowed_fields}
        if extra_fields:
            raise ValueError(
                f"Extra fields not allowed: {', '.join(extra_fields)}. "
                f"Please input only the following fields: {', '.join(allowed_fields)}"
            )
        return values

mem0/embeddings/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ def __init__(self, config: Optional[BaseEmbedderConfig] = None):
1818
self.config = config
1919

2020
@abstractmethod
21-
def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]]):
21+
async def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]]):
2222
"""
2323
Get the embedding for the given text.
2424

mem0/embeddings/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def validate_config(cls, v, values):
2424
"lmstudio",
2525
"langchain",
2626
"aws_bedrock",
27+
"ockam",
2728
]:
2829
return v
2930
else:

mem0/embeddings/ockam.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from typing import Literal, Optional
2+
3+
from mem0.configs.embeddings.base import BaseEmbedderConfig
4+
from mem0.embeddings.base import EmbeddingBase
5+
6+
try:
7+
import litellm
8+
except ImportError:
9+
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
10+
11+
12+
class OckamEmbedding(EmbeddingBase):
    """Embedder that routes embedding requests through an Ockam model's litellm router."""

    def __init__(self, config: Optional[BaseEmbedderConfig] = None):
        super().__init__(config)

        # The ockam_model object supplies the router, model name, and extra call kwargs.
        if not self.config.ockam_model:
            raise ValueError("'ockam_model' is required for 'OckamEmbedding'.")

    async def embed(self, text, memory_action: Optional[Literal["add", "search", "update"]] = None):
        """
        Get the embedding for the given text using Ockam.

        Args:
            text (str): The text to embed.
            memory_action (optional): The type of embedding to use. Must be one of "add", "search", or "update". Defaults to None.

        Returns:
            list: The embedding vector.
        """
        ockam_model = self.config.ockam_model
        router = ockam_model.router()
        # Per-model kwargs are forwarded to every embedding call.
        kwargs = ockam_model.kwargs
        response = await router.aembedding(model=ockam_model.name, input=text, **kwargs)
        return response.data[0]["embedding"]

mem0/llms/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
1717
self.config = config
1818

1919
@abstractmethod
20-
def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
20+
async def generate_response(self, messages, tools: Optional[List[Dict]] = None, tool_choice: str = "auto"):
2121
"""
2222
Generate a response based on the given messages.
2323

mem0/llms/configs.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ def validate_config(cls, v, values):
2828
"lmstudio",
2929
"vllm",
3030
"langchain",
31+
"ockam",
3132
):
3233
return v
3334
else:

mem0/llms/ockam.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
import json
2+
from typing import Dict, List, Optional
3+
4+
try:
5+
import litellm
6+
except ImportError:
7+
raise ImportError("The 'litellm' library is required. Please install it using 'pip install litellm'.")
8+
9+
from mem0.configs.llms.base import BaseLlmConfig
10+
from mem0.llms.base import LLMBase
11+
12+
13+
class OckamLLM(LLMBase):
    """LLM that routes chat-completion requests through an Ockam model's litellm router."""

    def __init__(self, config: Optional[BaseLlmConfig] = None):
        super().__init__(config)

        # The ockam_model object supplies the router, model name, and extra call kwargs.
        if not self.config.ockam_model:
            raise ValueError("'ockam_model' is required for 'OckamLLM'.")

    def _parse_response(self, response, tools):
        """
        Process the response based on whether tools are used or not.

        Args:
            response: The raw response from API.
            tools: The list of tools provided in the request.

        Returns:
            str or dict: The processed response. When tools were supplied, a dict
            with "content" and a "tool_calls" list; otherwise the plain content string.
        """
        if tools:
            processed_response = {
                "content": response.choices[0].message.content,
                "tool_calls": [],
            }

            if response.choices[0].message.tool_calls:
                for tool_call in response.choices[0].message.tool_calls:
                    processed_response["tool_calls"].append(
                        {
                            "name": tool_call.function.name,
                            # Tool-call arguments arrive JSON-encoded from the API.
                            "arguments": json.loads(tool_call.function.arguments),
                        }
                    )

            return processed_response
        else:
            return response.choices[0].message.content

    async def generate_response(
        self,
        messages: List[Dict[str, str]],
        response_format=None,
        tools: Optional[List[Dict]] = None,
        tool_choice: str = "auto",
    ):
        """
        Generate a response based on the given messages using litellm.

        Args:
            messages (list): List of message dicts containing 'role' and 'content'.
            response_format (str or object, optional): Format of the response. Defaults to None.
            tools (list, optional): List of tools that the model can call. Defaults to None.
            tool_choice (str, optional): Tool choice method. Defaults to "auto".

        Returns:
            str or dict: The generated response (see _parse_response).
        """

        # FIXME
        # if not litellm.supports_function_calling(self.config.model):
        #     raise ValueError(f"Model '{self.config.model}' in litellm does not support function calling.")

        ockam_model = self.config.ockam_model

        params = {
            "model": ockam_model.name,
            "messages": messages,
            "temperature": self.config.temperature,
            "max_tokens": self.config.max_tokens,
            "top_p": self.config.top_p,
        }
        if response_format:
            params["response_format"] = response_format
        if tools:  # TODO: Remove tools if no issues found with new memory addition logic
            params["tools"] = tools
            params["tool_choice"] = tool_choice

        router = ockam_model.router()
        # Model-level kwargs act as defaults; explicit request params take precedence.
        kwargs = {**ockam_model.kwargs, **params}

        response = await router.acompletion(**kwargs)
        return self._parse_response(response, tools)

0 commit comments

Comments
 (0)