Skip to content

Commit 3e00424

Browse files
committed
Add long-term memory feature to Friday
1 parent d043294 commit 3e00424

File tree

10 files changed

+553
-43
lines changed

10 files changed

+553
-43
lines changed

packages/app/friday/args.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,52 @@ def get_args() -> Namespace:
6363
type=bool,
6464
required=True,
6565
)
66+
parser.add_argument(
67+
"--longTermMemory",
68+
type=bool,
69+
default=False,
70+
help="Whether to enable long-term memory and embedding support.",
71+
)
72+
parser.add_argument(
73+
"--embeddingProvider",
74+
type=str,
75+
choices=["dashscope", "openai", "gemini", "ollama"],
76+
help="Embedding provider name, e.g. openai/dashscope/gemini/ollama.",
77+
)
78+
parser.add_argument(
79+
"--embeddingModelName",
80+
type=str,
81+
help="Embedding model name, e.g. text-embedding-3-small or text-embedding-v1.",
82+
)
83+
parser.add_argument(
84+
"--embeddingApiKey",
85+
type=str,
86+
help="API key for the embedding provider; falls back to --apiKey if not set.",
87+
)
88+
parser.add_argument(
89+
"--embeddingKwargs",
90+
type=json_type,
91+
default={},
92+
help="A JSON string for extra kwargs passed to the embedding model (e.g. host, dimensions).",
93+
)
94+
parser.add_argument(
95+
"--saveToLocal",
96+
type=bool,
97+
default=False,
98+
help="Whether to save long-term memory to local disk.",
99+
)
100+
parser.add_argument(
101+
"--localStoragePath",
102+
type=str,
103+
default="",
104+
help="Local storage path for long-term memory.",
105+
)
106+
parser.add_argument(
107+
"--vectorStoreProvider",
108+
type=str,
109+
default="qdrant",
110+
help="Vector store provider for long-term memory (e.g. qdrant, chroma, faiss).",
111+
)
66112
parser.add_argument(
67113
"--clientKwargs",
68114
type=json_type,

packages/app/friday/main.py

Lines changed: 68 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
import json5
1010
from agentscope.agent import ReActAgent
11-
from agentscope.memory import InMemoryMemory
11+
from agentscope.memory import InMemoryMemory, Mem0LongTermMemory
1212
from agentscope.message import Msg
1313
from agentscope.session import JSONSession
1414
from agentscope.tool import (
@@ -25,7 +25,7 @@
2525
studio_post_reply_hook,
2626
)
2727
from args import get_args
28-
from model import get_model, get_formatter
28+
from model import get_model, get_formatter, get_embedding_model, get_memory_model
2929
from tool.utils import (
3030
view_agentscope_library,
3131
view_agentscope_readme,
@@ -36,6 +36,8 @@
3636
from utils.constants import FRIDAY_SESSION_ID
3737

3838
from mcp_manager import connect_mcp_servers, close_mcp_connections
39+
from mem0.vector_stores.configs import VectorStoreConfig
40+
3941

4042
async def main():
4143
args = get_args()
@@ -94,13 +96,55 @@ async def main():
9496
local_mcp_clients = await connect_mcp_servers(mcp_servers, toolkit)
9597

9698
# get model from args
97-
model = get_model(args.llmProvider, args.modelName, args.apiKey, args.clientKwargs, args.generateKwargs)
99+
model = get_model(
100+
args.llmProvider,
101+
args.modelName,
102+
args.apiKey,
103+
args.clientKwargs,
104+
args.generateKwargs,
105+
)
98106
formatter = get_formatter(args.llmProvider)
99107

100-
# Create the ReAct agent
101-
agent = ReActAgent(
102-
name="Friday",
103-
sys_prompt="""You're Friday, a helpful assistant specialized in daily task management and AgentScope framework support.
108+
# Initialize long-term memory if enabled
109+
long_term_memory = None
110+
if args.longTermMemory:
111+
memory_model = get_memory_model(
112+
args.llmProvider,
113+
args.modelName,
114+
args.apiKey,
115+
args.clientKwargs,
116+
args.generateKwargs,
117+
)
118+
embedding_model = get_embedding_model(
119+
embeddingProvider=args.embeddingProvider,
120+
embeddingModelName=args.embeddingModelName,
121+
embeddingApiKey=args.embeddingApiKey,
122+
embedding_kwargs=args.embeddingKwargs,
123+
)
124+
125+
# Prepare vector store config with embedding dimensions
126+
vector_store_config_dict = {
127+
"on_disk": args.saveToLocal,
128+
"path": args.localStoragePath,
129+
}
130+
131+
# Add embedding_model_dims for vector stores that need it (e.g., Qdrant)
132+
if hasattr(embedding_model, 'dimensions'):
133+
vector_store_config_dict["embedding_model_dims"] = embedding_model.dimensions
134+
135+
long_term_memory = Mem0LongTermMemory(
136+
agent_name="Friday",
137+
user_name="Studio",
138+
model=memory_model,
139+
embedding_model=embedding_model,
140+
vector_store_config=VectorStoreConfig(
141+
provider=args.vectorStoreProvider,
142+
config=vector_store_config_dict,
143+
),
144+
)
145+
146+
# Build system prompt
147+
sys_prompt = """You're Friday, a helpful assistant specialized in daily task management and AgentScope framework support.
104148
105149
# Core Objectives
106150
- Help users manage and complete daily tasks efficiently
@@ -130,7 +174,21 @@ async def main():
130174
- Never guess or make up implementations
131175
132176
# Available Context
133-
- Current date and time: {current_time}""".format(
177+
- Current date and time: {current_time}"""
178+
179+
# Add long-term memory information if enabled
180+
if args.longTermMemory:
181+
sys_prompt += """
182+
183+
# Long-term Memory
184+
- You have long-term memory enabled, which allows you to remember information across conversations
185+
- Use this capability to provide more personalized and context-aware assistance
186+
- You can reference past interactions and learned preferences to better serve the user"""
187+
188+
# Create the ReAct agent
189+
agent = ReActAgent(
190+
name="Friday",
191+
sys_prompt=sys_prompt.format(
134192
current_time=datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
135193
max_turns=20,
136194
),
@@ -140,6 +198,8 @@ async def main():
140198
memory=InMemoryMemory(),
141199
max_iters=50,
142200
enable_meta_tool=True,
201+
long_term_memory=long_term_memory,
202+
long_term_memory_mode='both'
143203
)
144204

145205
path_dialog_history = get_local_file_path("")

packages/app/friday/model.py

Lines changed: 126 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
# -*- coding: utf-8 -*-
22
"""Get the formatter and model based on the model provider."""
33
import re
4+
45
import agentscope
56
from agentscope.formatter import (
67
DashScopeChatFormatter,
@@ -18,24 +19,24 @@
1819
GeminiChatModel,
1920
AnthropicChatModel,
2021
)
22+
from agentscope.embedding import (
23+
EmbeddingModelBase,
24+
EmbeddingUsage,
25+
EmbeddingResponse,
26+
DashScopeTextEmbedding,
27+
DashScopeMultiModalEmbedding,
28+
OpenAITextEmbedding,
29+
GeminiTextEmbedding,
30+
OllamaTextEmbedding,
31+
EmbeddingCacheBase,
32+
FileEmbeddingCache,
33+
)
2134

2235

2336
def is_agentscope_version_ge(target_version: tuple) -> bool:
24-
"""
25-
Check if the current agentscope version is greater than or equal to the target version.
26-
27-
Args:
28-
target_version: A tuple of (major, minor, patch) version numbers.
29-
30-
Returns:
31-
True if current version >= target version, False otherwise.
32-
33-
Example:
34-
>>> is_agentscope_version_ge((1, 0, 9)) # Works with "1.0.9" or "1.0.9dev"
35-
True
36-
"""
37+
"""Check whether the current agentscope version is >= target_version."""
3738
version_str = agentscope.__version__
38-
version_match = re.match(r'^(\d+)\.(\d+)\.(\d+)', version_str)
39+
version_match = re.match(r"^(\d+)\.(\d+)\.(\d+)", version_str)
3940
if version_match:
4041
major, minor, patch = map(int, version_match.groups())
4142
current_version = (major, minor, patch)
@@ -61,14 +62,15 @@ def get_formatter(llmProvider: str) -> FormatterBase:
6162
f"Unsupported model provider: {llmProvider}. "
6263
)
6364

65+
6466
def get_model(
6567
llmProvider: str,
6668
modelName: str,
6769
apiKey: str,
6870
client_kwargs: dict = {},
6971
generate_kwargs: dict = {},
7072
) -> ChatModelBase:
71-
"""Get the model instance based on the input arguments."""
73+
"""Get the chat model instance based on the input arguments."""
7274

7375
match llmProvider.lower():
7476
case "dashscope":
@@ -122,3 +124,112 @@ def get_model(
122124
raise ValueError(
123125
f"Unsupported model provider: {llmProvider}. "
124126
)
127+
128+
def get_memory_model(
129+
llmProvider: str,
130+
modelName: str,
131+
apiKey: str,
132+
client_kwargs: dict = {},
133+
generate_kwargs: dict = {},
134+
) -> ChatModelBase:
135+
"""Get the chat model instance based on the input arguments."""
136+
137+
match llmProvider.lower():
138+
case "dashscope":
139+
return DashScopeChatModel(
140+
model_name=modelName,
141+
api_key=apiKey,
142+
stream=False,
143+
generate_kwargs=generate_kwargs,
144+
)
145+
case "openai":
146+
return OpenAIChatModel(
147+
model_name=modelName,
148+
api_key=apiKey,
149+
stream=False,
150+
client_kwargs=client_kwargs,
151+
generate_kwargs=generate_kwargs,
152+
)
153+
case "ollama":
154+
if is_agentscope_version_ge((1, 0, 9)):
155+
# For agentscope >= 1.0.9
156+
return OllamaChatModel(
157+
model_name=modelName,
158+
stream=False,
159+
client_kwargs=client_kwargs,
160+
generate_kwargs=generate_kwargs,
161+
)
162+
else:
163+
# For agentscope < 1.0.9
164+
return OllamaChatModel(
165+
model_name=modelName,
166+
stream=False,
167+
**client_kwargs,
168+
)
169+
case "gemini":
170+
return GeminiChatModel(
171+
model_name=modelName,
172+
api_key=apiKey,
173+
stream=False,
174+
client_kwargs=client_kwargs,
175+
generate_kwargs=generate_kwargs,
176+
)
177+
case "anthropic":
178+
return AnthropicChatModel(
179+
model_name=modelName,
180+
api_key=apiKey,
181+
stream=False,
182+
client_kwargs=client_kwargs,
183+
generate_kwargs=generate_kwargs,
184+
)
185+
case _:
186+
raise ValueError(
187+
f"Unsupported model provider: {llmProvider}. "
188+
)
189+
190+
191+
def get_embedding_model(
192+
embeddingProvider: str,
193+
embeddingModelName: str,
194+
embeddingApiKey: str,
195+
embedding_kwargs: dict = {},
196+
) -> EmbeddingModelBase:
197+
"""Get the embedding model instance based on the input arguments.
198+
199+
The signature follows the corresponding classes in ``agentscope.embedding``.
200+
``embedding_kwargs`` can be used to pass extra provider-specific keyword
201+
arguments, such as ``host`` or ``dimensions``.
202+
"""
203+
204+
match embeddingProvider.lower():
205+
case "dashscope":
206+
# DashScopeTextEmbedding(model_name: str, api_key: str, **kwargs)
207+
return DashScopeTextEmbedding(
208+
model_name=embeddingModelName,
209+
api_key=embeddingApiKey,
210+
**embedding_kwargs,
211+
)
212+
case "openai":
213+
# OpenAITextEmbedding(model_name: str, api_key: str, **kwargs)
214+
return OpenAITextEmbedding(
215+
model_name=embeddingModelName,
216+
api_key=embeddingApiKey,
217+
**embedding_kwargs,
218+
)
219+
case "gemini":
220+
# GeminiTextEmbedding(model_name: str, api_key: str, **kwargs)
221+
return GeminiTextEmbedding(
222+
model_name=embeddingModelName,
223+
api_key=embeddingApiKey,
224+
**embedding_kwargs,
225+
)
226+
case "ollama":
227+
# OllamaTextEmbedding(model_name: str, **kwargs)
228+
return OllamaTextEmbedding(
229+
model_name=embeddingModelName,
230+
**embedding_kwargs,
231+
)
232+
case _:
233+
raise ValueError(
234+
f"Unsupported embedding provider: {embeddingProvider}. "
235+
)

packages/app/friday/utils/common.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Utility functions for file path management in AgentScope Studio."""
33
import platform
44
import os
5+
import shutil
56

67
from utils.constants import NAME_STUDIO, NAME_APP
78

@@ -26,3 +27,31 @@ def get_local_file_path(filename: str) -> str:
2627
os.makedirs(os.path.join(local_path, NAME_APP), exist_ok=True)
2728

2829
return os.path.join(local_path, NAME_APP, filename)
30+
31+
32+
def clear_vector_store(storage_path: str, provider: str = "qdrant") -> None:
    """Delete persisted vector-store data for the given provider.

    Removing the on-disk store resolves embedding-dimension mismatches
    (e.g. after switching embedding models). A no-op when no store exists.

    Args:
        storage_path: Base path where vector store data is saved; when empty,
            the app-default local path is used.
        provider: Vector store provider name (default: qdrant).
    """
    base = storage_path or get_local_file_path("")

    # Each provider persists under its own well-known sub-directory;
    # unknown providers fall back to "<provider>_storage".
    known_dirs = {"qdrant": "qdrant_storage", "chroma": "chroma_db"}
    target = os.path.join(
        base,
        known_dirs.get(provider.lower(), f"{provider}_storage"),
    )

    if not os.path.exists(target):
        return

    print(f"Clearing vector store at: {target}")
    shutil.rmtree(target)
    print("Vector store cleared successfully")
57+

0 commit comments

Comments
 (0)