Skip to content

Commit 64414ea

Browse files
authored
Merge branch 'dev' into feat/deep-search
2 parents 94dba83 + 8374f96 commit 64414ea

File tree

25 files changed

+3038
-195
lines changed

25 files changed

+3038
-195
lines changed

examples/mem_reader/multimodel_struct_reader.py

Lines changed: 831 additions & 0 deletions
Large diffs are not rendered by default.

examples/mem_reader/reader.py renamed to examples/mem_reader/simple_struct_reader.py

Lines changed: 119 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
import argparse
22
import json
3+
import os
34
import time
45

6+
from typing import Any
7+
8+
from dotenv import load_dotenv
9+
510
from memos.configs.mem_reader import SimpleStructMemReaderConfig
611
from memos.mem_reader.simple_struct import SimpleStructMemReader
712
from memos.memories.textual.item import (
@@ -11,6 +16,10 @@
1116
)
1217

1318

19+
# Load environment variables from .env file
20+
load_dotenv()
21+
22+
1423
def print_textual_memory_item(
1524
item: TextualMemoryItem, max_memory_length: int = 200, indent: int = 0
1625
):
@@ -98,6 +107,104 @@ def print_textual_memory_item_json(item: TextualMemoryItem, indent: int = 2):
98107
print(json.dumps(data, indent=indent, ensure_ascii=False))
99108

100109

110+
def get_reader_config() -> dict[str, Any]:
    """
    Assemble the mem-reader configuration from environment variables.

    The result is a plain dictionary suitable for building a
    SimpleStructMemReaderConfig (the counterpart of
    APIConfig.get_reader_config() in server_router_api.py).

    Environment variables consulted:
        OPENAI_API_KEY, OPENAI_API_BASE, OLLAMA_API_BASE,
        MEMRADER_API_KEY, MEM_READER_LLM_* and
        MEM_READER_EMBEDDER_* (each embedder variable falls back to its
        MOS_EMBEDDER_* twin before the built-in default is used).

    Returns:
        Dictionary with "llm", "embedder" and "chunker" sections.

    Raises:
        ValueError: If MEM_READER_LLM_TEMPERATURE is not a valid float or
            MEM_READER_LLM_MAX_TOKENS is not a valid integer.
    """

    def env_with_fallback(primary: str, secondary: str, default: str) -> str:
        # Reader-specific variable wins, then the MOS-wide one, then the default.
        return os.getenv(primary, os.getenv(secondary, default))

    api_key_from_env = os.getenv("OPENAI_API_KEY")
    openai_endpoint = os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1")
    ollama_endpoint = os.getenv("OLLAMA_API_BASE", "http://localhost:11434")

    # ---- LLM section -------------------------------------------------
    # Only the defaults and the connection fields differ between the two
    # backends; any value other than "ollama" is treated as "openai".
    connection: dict[str, Any]
    if os.getenv("MEM_READER_LLM_BACKEND", "openai") == "ollama":
        llm_backend_name = "ollama"
        default_model = "qwen3:0.6b"
        default_temperature = "0.0"
        connection = {"api_base": ollama_endpoint}
    else:
        llm_backend_name = "openai"
        default_model = "gpt-4o-mini"
        default_temperature = "0.5"
        connection = {
            "api_key": api_key_from_env or os.getenv("MEMRADER_API_KEY", "EMPTY"),
            "api_base": openai_endpoint,
        }

    think_prefix_flag = os.getenv("MEM_READER_LLM_REMOVE_THINK_PREFIX", "true")
    llm_section = {
        "backend": llm_backend_name,
        "config": {
            "model_name_or_path": os.getenv("MEM_READER_LLM_MODEL", default_model),
            **connection,
            "temperature": float(
                os.getenv("MEM_READER_LLM_TEMPERATURE", default_temperature)
            ),
            # Only the literal string "true" (any casing) enables the flag.
            "remove_think_prefix": think_prefix_flag.lower() == "true",
            "max_tokens": int(os.getenv("MEM_READER_LLM_MAX_TOKENS", "8192")),
        },
    }

    # ---- Embedder section --------------------------------------------
    embedder_kind = env_with_fallback(
        "MEM_READER_EMBEDDER_BACKEND", "MOS_EMBEDDER_BACKEND", "ollama"
    )
    if embedder_kind == "universal_api":
        embedder_section = {
            "backend": "universal_api",
            "config": {
                "provider": env_with_fallback(
                    "MEM_READER_EMBEDDER_PROVIDER", "MOS_EMBEDDER_PROVIDER", "openai"
                ),
                "api_key": env_with_fallback(
                    "MEM_READER_EMBEDDER_API_KEY",
                    "MOS_EMBEDDER_API_KEY",
                    api_key_from_env or "sk-xxxx",
                ),
                "model_name_or_path": env_with_fallback(
                    "MEM_READER_EMBEDDER_MODEL",
                    "MOS_EMBEDDER_MODEL",
                    "text-embedding-3-large",
                ),
                "base_url": env_with_fallback(
                    "MEM_READER_EMBEDDER_API_BASE",
                    "MOS_EMBEDDER_API_BASE",
                    openai_endpoint,
                ),
            },
        }
    else:
        # Every other value is handled as the ollama embedder.
        embedder_section = {
            "backend": "ollama",
            "config": {
                "model_name_or_path": env_with_fallback(
                    "MEM_READER_EMBEDDER_MODEL",
                    "MOS_EMBEDDER_MODEL",
                    "nomic-embed-text:latest",
                ),
                "api_base": ollama_endpoint,
            },
        }

    # ---- Chunker section (fixed, not environment-driven) -------------
    chunker_section = {
        "backend": "sentence",
        "config": {
            "tokenizer_or_token_counter": "gpt2",
            "chunk_size": 512,
            "chunk_overlap": 128,
            "min_sentences_per_chunk": 1,
        },
    }

    return {
        "llm": llm_section,
        "embedder": embedder_section,
        "chunker": chunker_section,
    }
206+
207+
101208
def main():
102209
# Parse command line arguments
103210
parser = argparse.ArgumentParser(description="Test Mem-Reader with structured output")
@@ -115,10 +222,18 @@ def main():
115222
)
116223
args = parser.parse_args()
117224

118-
# 1. Create Configuration
119-
reader_config = SimpleStructMemReaderConfig.from_json_file(
120-
"examples/data/config/simple_struct_reader_config.json"
121-
)
225+
# 1. Create Configuration from environment variables or JSON file
226+
# Try to get config from environment variables first
227+
openai_api_key = os.getenv("OPENAI_API_KEY")
228+
if openai_api_key:
229+
# Use environment variables (similar to server_router_api.py)
230+
config_dict = get_reader_config()
231+
reader_config = SimpleStructMemReaderConfig.model_validate(config_dict)
232+
else:
233+
# Fall back to JSON file
234+
reader_config = SimpleStructMemReaderConfig.from_json_file(
235+
"examples/data/config/simple_struct_reader_config.json"
236+
)
122237
reader = SimpleStructMemReader(reader_config)
123238

124239
# 2. Define scene data

src/memos/api/handlers/chat_handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -894,7 +894,7 @@ def _send_message_to_scheduler(
894894
content=query,
895895
timestamp=datetime.utcnow(),
896896
)
897-
self.mem_scheduler.memos_message_queue.submit_messages(messages=[message_item])
897+
self.mem_scheduler.submit_messages(messages=[message_item])
898898
self.logger.info(f"Sent message to scheduler with label: {label}")
899899
except Exception as e:
900900
self.logger.error(f"Failed to send message to scheduler: {e}", exc_info=True)

src/memos/api/handlers/memory_handler.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
MemoryResponse,
1616
)
1717
from memos.log import get_logger
18+
from memos.mem_cube.navie import NaiveMemCube
1819
from memos.mem_os.utils.format_utils import (
1920
convert_graph_to_tree_forworkmem,
2021
ensure_unique_tree_ids,
@@ -162,11 +163,13 @@ def handle_get_subgraph(
162163
raise
163164

164165

165-
def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> GetMemoryResponse:
166+
def handle_get_memories(
167+
get_mem_req: GetMemoryRequest, naive_mem_cube: NaiveMemCube
168+
) -> GetMemoryResponse:
166169
# TODO: Implement get memory with filter
167170
memories = naive_mem_cube.text_mem.get_all(user_name=get_mem_req.mem_cube_id)["nodes"]
168171
preferences: list[TextualMemoryItem] = []
169-
if get_mem_req.include_preference:
172+
if get_mem_req.include_preference and naive_mem_cube.pref_mem is not None:
170173
filter_params: dict[str, Any] = {}
171174
if get_mem_req.user_id is not None:
172175
filter_params["user_id"] = get_mem_req.user_id
@@ -183,10 +186,11 @@ def handle_get_memories(get_mem_req: GetMemoryRequest, naive_mem_cube: Any) -> G
183186
)
184187

185188

186-
def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: Any):
189+
def handle_delete_memories(delete_mem_req: DeleteMemoryRequest, naive_mem_cube: NaiveMemCube):
187190
try:
188191
naive_mem_cube.text_mem.delete(delete_mem_req.memory_ids)
189-
naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
192+
if naive_mem_cube.pref_mem is not None:
193+
naive_mem_cube.pref_mem.delete(delete_mem_req.memory_ids)
190194
except Exception as e:
191195
logger.error(f"Failed to delete memories: {e}", exc_info=True)
192196
return DeleteMemoryResponse(

src/memos/api/handlers/scheduler_handler.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,9 @@ def handle_scheduler_status(
3434
Args:
3535
user_id: User ID to query for.
3636
status_tracker: The TaskStatusTracker instance.
37-
task_id: Optional Task ID to query a specific task.
37+
task_id: Optional Task ID to query. Can be either:
38+
- business_task_id (will aggregate all related item statuses)
39+
- item_id (will return single item status)
3840
3941
Returns:
4042
StatusResponse with a list of task statuses.
@@ -46,12 +48,22 @@ def handle_scheduler_status(
4648

4749
try:
4850
if task_id:
49-
task_data = status_tracker.get_task_status(task_id, user_id)
50-
if not task_data:
51-
raise HTTPException(
52-
status_code=404, detail=f"Task {task_id} not found for user {user_id}"
51+
# First try as business_task_id (aggregated query)
52+
business_task_data = status_tracker.get_task_status_by_business_id(task_id, user_id)
53+
if business_task_data:
54+
response_data.append(
55+
StatusResponseItem(task_id=task_id, status=business_task_data["status"])
56+
)
57+
else:
58+
# Fallback: try as item_id (single item query)
59+
item_task_data = status_tracker.get_task_status(task_id, user_id)
60+
if not item_task_data:
61+
raise HTTPException(
62+
status_code=404, detail=f"Task {task_id} not found for user {user_id}"
63+
)
64+
response_data.append(
65+
StatusResponseItem(task_id=task_id, status=item_task_data["status"])
5366
)
54-
response_data.append(StatusResponseItem(task_id=task_id, status=task_data["status"]))
5567
else:
5668
all_tasks = status_tracker.get_all_tasks_for_user(user_id)
5769
# The plan returns an empty list, which is good.

src/memos/api/product_models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,7 @@ class MemoryCreateRequest(BaseRequest):
258258
source: str | None = Field(None, description="Source of the memory")
259259
user_profile: bool = Field(False, description="User profile memory")
260260
session_id: str | None = Field(None, description="Session id")
261+
task_id: str | None = Field(None, description="Task ID for monitoring async tasks")
261262

262263

263264
class SearchRequest(BaseRequest):

src/memos/api/routers/product_router.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,9 +188,43 @@ def get_all_memories(memory_req: GetMemoryPlaygroundRequest):
188188
@router.post("/add", summary="add a new memory", response_model=SimpleResponse)
189189
def create_memory(memory_req: MemoryCreateRequest):
190190
"""Create a new memory for a specific user."""
191+
# Initialize status_tracker outside try block to avoid NameError in except blocks
192+
status_tracker = None
193+
191194
try:
192195
time_start_add = time.time()
193196
mos_product = get_mos_product_instance()
197+
198+
# Track task if task_id is provided
199+
item_id: str | None = None
200+
if (
201+
memory_req.task_id
202+
and hasattr(mos_product, "mem_scheduler")
203+
and mos_product.mem_scheduler
204+
):
205+
from uuid import uuid4
206+
207+
from memos.mem_scheduler.utils.status_tracker import TaskStatusTracker
208+
209+
item_id = str(uuid4()) # Generate a unique item_id for this submission
210+
211+
# Get Redis client from scheduler
212+
if (
213+
hasattr(mos_product.mem_scheduler, "redis_client")
214+
and mos_product.mem_scheduler.redis_client
215+
):
216+
status_tracker = TaskStatusTracker(mos_product.mem_scheduler.redis_client)
217+
# Submit task with "product_add" type
218+
status_tracker.task_submitted(
219+
task_id=item_id, # Use generated item_id for internal tracking
220+
user_id=memory_req.user_id,
221+
task_type="product_add",
222+
mem_cube_id=memory_req.mem_cube_id or memory_req.user_id,
223+
business_task_id=memory_req.task_id, # Use memory_req.task_id as business_task_id
224+
)
225+
status_tracker.task_started(item_id, memory_req.user_id) # Use item_id here
226+
227+
# Execute the add operation
194228
mos_product.add(
195229
user_id=memory_req.user_id,
196230
memory_content=memory_req.memory_content,
@@ -200,15 +234,27 @@ def create_memory(memory_req: MemoryCreateRequest):
200234
source=memory_req.source,
201235
user_profile=memory_req.user_profile,
202236
session_id=memory_req.session_id,
237+
task_id=memory_req.task_id,
203238
)
239+
240+
# Mark task as completed
241+
if status_tracker and item_id:
242+
status_tracker.task_completed(item_id, memory_req.user_id)
243+
204244
logger.info(
205245
f"time add api : add time user_id: {memory_req.user_id} time is: {time.time() - time_start_add}"
206246
)
207247
return SimpleResponse(message="Memory created successfully")
208248

209249
except ValueError as err:
250+
# Mark task as failed if tracking
251+
if status_tracker and item_id:
252+
status_tracker.task_failed(item_id, memory_req.user_id, str(err))
210253
raise HTTPException(status_code=404, detail=str(traceback.format_exc())) from err
211254
except Exception as err:
255+
# Mark task as failed if tracking
256+
if status_tracker and item_id:
257+
status_tracker.task_failed(item_id, memory_req.user_id, str(err))
212258
logger.error(f"Failed to create memory: {traceback.format_exc()}")
213259
raise HTTPException(status_code=500, detail=str(traceback.format_exc())) from err
214260

0 commit comments

Comments
 (0)