Skip to content

Commit fc78818

Browse files
author
yuan.wang
committed
merge api refactor to dev
2 parents 1043377 + b1053c4 commit fc78818

38 files changed

+1519
-360
lines changed

src/memos/api/handlers/chat_handler.py

Lines changed: 384 additions & 60 deletions
Large diffs are not rendered by default.

src/memos/api/handlers/component_init.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from memos.api.config import APIConfig
1313
from memos.api.handlers.config_builders import (
14+
build_chat_llm_config,
1415
build_embedder_config,
1516
build_graph_db_config,
1617
build_internet_retriever_config,
@@ -77,6 +78,38 @@ def _get_default_memory_size(cube_config: Any) -> dict[str, int]:
7778
}
7879

7980

81+
def _init_chat_llms(chat_llm_configs: list[dict]) -> dict[str, Any]:
    """
    Initialize chat language models from configuration.

    Args:
        chat_llm_configs: List of chat LLM configuration dictionaries,
            each with a ``config_class`` (validated LLM config) and an
            optional ``support_models`` list of model names.

    Returns:
        Dictionary mapping model names to initialized LLM instances.
        A single LLM instance is shared by all model names it supports.
    """

    def _list_models(client):
        # NOTE(review): currently unused by this function — confirm whether
        # it was meant to validate support_models against the live endpoint.
        try:
            # Query the models endpoint once and reuse the response; the
            # original called client.models.list() twice per branch.
            listing = client.models.list()
            models = (
                [model.id for model in listing.data]
                if listing.data
                else listing.models
            )
        except Exception as e:
            logger.error(f"Error listing models: {e}")
            models = []
        return models

    # Fixed typo in local name (was: model_name_instrance_maping).
    model_name_instance_mapping = {}
    for cfg in chat_llm_configs:
        llm = LLMFactory.from_config(cfg["config_class"])
        if cfg["support_models"]:
            for model_name in cfg["support_models"]:
                model_name_instance_mapping[model_name] = llm
    return model_name_instance_mapping
111+
112+
80113
def init_server() -> dict[str, Any]:
81114
"""
82115
Initialize all server components and configurations.
@@ -104,6 +137,7 @@ def init_server() -> dict[str, Any]:
104137
# Build component configurations
105138
graph_db_config = build_graph_db_config()
106139
llm_config = build_llm_config()
140+
chat_llm_config = build_chat_llm_config()
107141
embedder_config = build_embedder_config()
108142
mem_reader_config = build_mem_reader_config()
109143
reranker_config = build_reranker_config()
@@ -123,13 +157,16 @@ def init_server() -> dict[str, Any]:
123157
else None
124158
)
125159
llm = LLMFactory.from_config(llm_config)
160+
chat_llms = _init_chat_llms(chat_llm_config)
126161
embedder = EmbedderFactory.from_config(embedder_config)
127162
mem_reader = MemReaderFactory.from_config(mem_reader_config)
128163
reranker = RerankerFactory.from_config(reranker_config)
129164
internet_retriever = InternetRetrieverFactory.from_config(
130165
internet_retriever_config, embedder=embedder
131166
)
132167

168+
# Initialize chat llms
169+
133170
logger.debug("Core components instantiated")
134171

135172
# Initialize memory manager
@@ -276,6 +313,7 @@ def init_server() -> dict[str, Any]:
276313
"graph_db": graph_db,
277314
"mem_reader": mem_reader,
278315
"llm": llm,
316+
"chat_llms": chat_llms,
279317
"embedder": embedder,
280318
"reranker": reranker,
281319
"internet_retriever": internet_retriever,

src/memos/api/handlers/config_builders.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
a configuration dictionary using the appropriate ConfigFactory.
77
"""
88

9+
import json
910
import os
1011

1112
from typing import Any
@@ -81,6 +82,32 @@ def build_llm_config() -> dict[str, Any]:
8182
)
8283

8384

85+
def build_chat_llm_config() -> list[dict[str, Any]]:
    """
    Build chat LLM configurations from the ``CHAT_MODEL_LIST`` env var.

    Returns:
        List of dicts, one per configured backend, each with:
            - ``config_class``: validated ``LLMConfigFactory`` instance
            - ``support_models``: optional list of model names it serves

    Raises:
        json.JSONDecodeError: If ``CHAT_MODEL_LIST`` is not valid JSON.
    """
    # Default to an empty list when the env var is unset; the original
    # passed None to json.loads and raised an opaque TypeError.
    configs = json.loads(os.getenv("CHAT_MODEL_LIST", "[]"))
    return [
        {
            "config_class": LLMConfigFactory.model_validate(
                {
                    "backend": cfg.get("backend", "openai"),
                    # Everything except routing metadata is backend config;
                    # empty entries fall back to the default OpenAI config.
                    "config": (
                        {k: v for k, v in cfg.items() if k not in ["backend", "support_models"]}
                        if cfg
                        else APIConfig.get_openai_config()
                    ),
                }
            ),
            "support_models": cfg.get("support_models", None),
        }
        for cfg in configs
    ]
109+
110+
84111
def build_embedder_config() -> dict[str, Any]:
85112
"""
86113
Build embedder configuration.

src/memos/api/handlers/memory_handler.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66

77
from typing import Any, Literal
88

9-
from memos.api.product_models import MemoryResponse
9+
from memos.api.handlers.formatters_handler import format_memory_item
10+
from memos.api.product_models import (
11+
DeleteMemoryRequest,
12+
DeleteMemoryResponse,
13+
GetMemoryRequest,
14+
GetMemoryResponse,
15+
MemoryResponse,
16+
)
1017
from memos.log import get_logger
1118
from memos.mem_os.utils.format_utils import (
1219
convert_graph_to_tree_forworkmem,
@@ -149,3 +156,37 @@ def handle_get_subgraph(
149156
except Exception as e:
150157
logger.error(f"Failed to get subgraph: {e}", exc_info=True)
151158
raise
159+
160+
161+
def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse:
    """
    Retrieve all text memories for a cube, plus preference memories.

    Args:
        get_mem_req: Request carrying mem_cube_id, optional user_id, and
            the include_preference flag.
        naive_mem_cube: Memory cube exposing text_mem and pref_mem stores.

    Returns:
        GetMemoryResponse whose data holds "text_mem" (raw nodes) and
        "pref_mem" (formatted preference items; empty when preferences
        are excluded).
    """
    # TODO: Implement get memory with filter
    memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"]
    preferences = []
    # Respect the request's include_preference flag; the original always
    # queried the preference store even when the caller opted out.
    if get_mem_req.include_preference:
        filter_params: dict[str, Any] = {}
        if get_mem_req.user_id is not None:
            filter_params["user_id"] = get_mem_req.user_id
        if get_mem_req.mem_cube_id is not None:
            filter_params["mem_cube_id"] = get_mem_req.mem_cube_id
        preferences = naive_mem_cube.pref_mem.get_memory_by_filter(filter_params)
    return GetMemoryResponse(
        message="Memories retrieved successfully",
        data={
            "text_mem": memories,
            "pref_mem": [format_memory_item(mem) for mem in preferences],
        },
    )
177+
178+
179+
def handle_delete_memories(
    delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any
) -> DeleteMemoryResponse:
    """
    Delete memories by ID from both the text and preference stores.

    Args:
        delete_mem_req: Request carrying the list of memory_ids to delete.
        naive_mem_cube: Memory cube exposing text_mem and pref_mem stores.

    Returns:
        DeleteMemoryResponse whose data is {"status": "success"} on
        success or {"status": "failure"} when deletion raised.
    """
    try:
        naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids)
        naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
    except Exception as e:
        logger.error(f"Failed to delete memories: {e}", exc_info=True)
        # data must be a dict to match DeleteMemoryResponse
        # (BaseResponse[dict]); the original returned the bare string
        # "failure" here while the success path returned a dict.
        return DeleteMemoryResponse(
            message="Failed to delete memories",
            data={"status": "failure"},
        )
    return DeleteMemoryResponse(
        message="Memories deleted successfully",
        data={"status": "success"},
    )

src/memos/api/handlers/scheduler_handler.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323

2424
def handle_scheduler_status(
25-
user_name: str | None = None,
25+
mem_cube_id: str | None = None,
2626
mem_scheduler: Any | None = None,
2727
instance_id: str = "",
2828
) -> dict[str, Any]:
@@ -43,17 +43,17 @@ def handle_scheduler_status(
4343
HTTPException: If status retrieval fails
4444
"""
4545
try:
46-
if user_name:
46+
if mem_cube_id:
4747
running = mem_scheduler.dispatcher.get_running_tasks(
48-
lambda task: getattr(task, "mem_cube_id", None) == user_name
48+
lambda task: getattr(task, "mem_cube_id", None) == mem_cube_id
4949
)
5050
tasks_iter = to_iter(running)
5151
running_count = len(tasks_iter)
5252
return {
5353
"message": "ok",
5454
"data": {
5555
"scope": "user",
56-
"user_name": user_name,
56+
"mem_cube_id": mem_cube_id,
5757
"running_tasks": running_count,
5858
"timestamp": time.time(),
5959
"instance_id": instance_id,

src/memos/api/product_models.py

Lines changed: 76 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
1-
import os
21
import uuid
32

4-
from typing import Generic, Literal, TypeVar
3+
from typing import Any, Generic, Literal, TypeVar
54

65
from pydantic import BaseModel, Field
76

@@ -37,7 +36,7 @@ class UserRegisterRequest(BaseRequest):
3736
interests: str | None = Field(None, description="User interests")
3837

3938

40-
class GetMemoryRequest(BaseRequest):
39+
class GetMemoryPlaygroundRequest(BaseRequest):
4140
"""Request model for getting memories."""
4241

4342
user_id: str = Field(..., description="User ID")
@@ -80,9 +79,20 @@ class ChatRequest(BaseRequest):
8079
None, description="List of cube IDs user can write for multi-cube chat"
8180
)
8281
history: list[MessageDict] | None = Field(None, description="Chat history")
82+
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
8383
internet_search: bool = Field(True, description="Whether to use internet search")
84-
moscube: bool = Field(False, description="Whether to use MemOSCube")
84+
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
85+
top_k: int = Field(10, description="Number of results to return")
86+
threshold: float = Field(0.5, description="Threshold for filtering references")
8587
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
88+
include_preference: bool = Field(True, description="Whether to handle preference memory")
89+
pref_top_k: int = Field(6, description="Number of preference results to return")
90+
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
91+
model_name_or_path: str | None = Field(None, description="Model name to use for chat")
92+
max_tokens: int | None = Field(None, description="Max tokens to generate")
93+
temperature: float | None = Field(None, description="Temperature for sampling")
94+
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
95+
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
8696

8797

8898
class ChatCompleteRequest(BaseRequest):
@@ -93,11 +103,18 @@ class ChatCompleteRequest(BaseRequest):
93103
mem_cube_id: str | None = Field(None, description="Cube ID to use for chat")
94104
history: list[MessageDict] | None = Field(None, description="Chat history")
95105
internet_search: bool = Field(False, description="Whether to use internet search")
96-
moscube: bool = Field(False, description="Whether to use MemOSCube")
97-
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
106+
system_prompt: str | None = Field(None, description="Base prompt to use for chat")
98107
top_k: int = Field(10, description="Number of results to return")
99108
threshold: float = Field(0.5, description="Threshold for filtering references")
100109
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
110+
include_preference: bool = Field(True, description="Whether to handle preference memory")
111+
pref_top_k: int = Field(6, description="Number of preference results to return")
112+
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
113+
model_name_or_path: str | None = Field(None, description="Model name to use for chat")
114+
max_tokens: int | None = Field(None, description="Max tokens to generate")
115+
temperature: float | None = Field(None, description="Temperature for sampling")
116+
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
117+
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
101118

102119

103120
class UserCreate(BaseRequest):
@@ -129,6 +146,10 @@ class SuggestionResponse(BaseResponse[list]):
129146
data: dict[str, list[str]] | None = Field(None, description="Response data")
130147

131148

149+
class AddStatusResponse(BaseResponse[dict]):
150+
"""Response model for add status operations."""
151+
152+
132153
class ConfigResponse(BaseResponse[None]):
133154
"""Response model for configuration endpoint."""
134155

@@ -141,6 +162,14 @@ class ChatResponse(BaseResponse[str]):
141162
"""Response model for chat operations."""
142163

143164

165+
class GetMemoryResponse(BaseResponse[dict]):
166+
"""Response model for getting memories."""
167+
168+
169+
class DeleteMemoryResponse(BaseResponse[dict]):
170+
"""Response model for deleting memories."""
171+
172+
144173
class UserResponse(BaseResponse[dict]):
145174
"""Response model for user operations."""
146175

@@ -181,11 +210,8 @@ class APISearchRequest(BaseRequest):
181210
readable_cube_ids: list[str] | None = Field(
182211
None, description="List of cube IDs user can read for multi-cube search"
183212
)
184-
mode: SearchMode = Field(
185-
os.getenv("SEARCH_MODE", SearchMode.FAST), description="search mode: fast, fine, or mixture"
186-
)
213+
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
187214
internet_search: bool = Field(False, description="Whether to use internet search")
188-
moscube: bool = Field(False, description="Whether to use MemOSCube")
189215
top_k: int = Field(10, description="Number of results to return")
190216
chat_history: list[MessageDict] | None = Field(None, description="Chat history")
191217
session_id: str | None = Field(None, description="Session ID for soft-filtering memories")
@@ -194,6 +220,7 @@ class APISearchRequest(BaseRequest):
194220
)
195221
include_preference: bool = Field(True, description="Whether to handle preference memory")
196222
pref_top_k: int = Field(6, description="Number of preference results to return")
223+
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
197224

198225

199226
class APIADDRequest(BaseRequest):
@@ -213,8 +240,13 @@ class APIADDRequest(BaseRequest):
213240
operation: list[PermissionDict] | None = Field(
214241
None, description="operation ids for multi cubes"
215242
)
216-
async_mode: Literal["async", "sync"] | None = Field(
217-
None, description="Whether to add memory in async mode"
243+
async_mode: Literal["async", "sync"] = Field(
244+
"async", description="Whether to add memory in async mode"
245+
)
246+
custom_tags: list[str] | None = Field(None, description="Custom tags for the memory")
247+
info: dict[str, str] | None = Field(None, description="Additional information for the memory")
248+
is_feedback: bool = Field(
249+
False, description="Whether this is user feedback in the knowledge base service"
218250
)
219251

220252

@@ -232,13 +264,43 @@ class APIChatCompleteRequest(BaseRequest):
232264
)
233265
history: list[MessageDict] | None = Field(None, description="Chat history")
234266
internet_search: bool = Field(False, description="Whether to use internet search")
235-
moscube: bool = Field(True, description="Whether to use MemOSCube")
236-
base_prompt: str | None = Field(None, description="Base prompt to use for chat")
267+
system_prompt: str | None = Field(None, description="Base system prompt to use for chat")
268+
mode: SearchMode = Field(SearchMode.FAST, description="search mode: fast, fine, or mixture")
237269
top_k: int = Field(10, description="Number of results to return")
238270
threshold: float = Field(0.5, description="Threshold for filtering references")
239271
session_id: str | None = Field(
240272
"default_session", description="Session ID for soft-filtering memories"
241273
)
274+
include_preference: bool = Field(True, description="Whether to handle preference memory")
275+
pref_top_k: int = Field(6, description="Number of preference results to return")
276+
filter: dict[str, Any] | None = Field(None, description="Filter for the memory")
277+
model_name_or_path: str | None = Field(None, description="Model name to use for chat")
278+
max_tokens: int | None = Field(None, description="Max tokens to generate")
279+
temperature: float | None = Field(None, description="Temperature for sampling")
280+
top_p: float | None = Field(None, description="Top-p (nucleus) sampling parameter")
281+
add_message_on_answer: bool = Field(True, description="Add dialogs to memory after chat")
282+
283+
284+
class AddStatusRequest(BaseRequest):
285+
"""Request model for checking add status."""
286+
287+
mem_cube_id: str = Field(..., description="Cube ID")
288+
user_id: str | None = Field(None, description="User ID")
289+
session_id: str | None = Field(None, description="Session ID")
290+
291+
292+
class GetMemoryRequest(BaseRequest):
293+
"""Request model for getting memories."""
294+
295+
mem_cube_id: str = Field(..., description="Cube ID")
296+
user_id: str | None = Field(None, description="User ID")
297+
include_preference: bool = Field(True, description="Whether to handle preference memory")
298+
299+
300+
class DeleteMemoryRequest(BaseRequest):
301+
"""Request model for deleting memories."""
302+
303+
memory_ids: list[str] = Field(..., description="Memory IDs")
242304

243305

244306
class SuggestionRequest(BaseRequest):

src/memos/api/routers/product_router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
BaseResponse,
1111
ChatCompleteRequest,
1212
ChatRequest,
13-
GetMemoryRequest,
13+
GetMemoryPlaygroundRequest,
1414
MemoryCreateRequest,
1515
MemoryResponse,
1616
SearchRequest,
@@ -159,7 +159,7 @@ def get_suggestion_queries_post(suggestion_req: SuggestionRequest):
159159

160160

161161
@router.post("/get_all", summary="Get all memories for user", response_model=MemoryResponse)
162-
def get_all_memories(memory_req: GetMemoryRequest):
162+
def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
163163
"""Get all memories for a specific user."""
164164
try:
165165
mos_product = get_mos_product_instance()

0 commit comments

Comments
 (0)