Skip to content

Commit ef02140

Browse files
committed
feat: add image parser in file
2 parents ac38046 + 1923128 commit ef02140

File tree

14 files changed

+237
-30
lines changed

14 files changed

+237
-30
lines changed

examples/mem_scheduler/try_schedule_modules.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,26 @@ def init_task():
4141
"role": "user",
4242
"content": "Max needs joint supplements. Actually, we're moving to Chicago next month.",
4343
},
44+
{
45+
"role": "assistant",
46+
"content": "Got it — Max is on joint supplements, and you’re relocating to Chicago soon. That’s a big move! Have you looked into how the change in climate or vet access might affect his needs?",
47+
},
4448
{
4549
"role": "user",
4650
"content": "Correction: Bella is 6, not 5. And she's allergic to chicken.",
4751
},
52+
{
53+
"role": "assistant",
54+
"content": "Thanks for the update! So Bella is 6 years old and has a chicken allergy — good to know. You’ll want to double-check her food and treats, especially during the move. Has she had any reactions recently?",
55+
},
4856
{
4957
"role": "user",
5058
"content": "My partner's cat Whiskers visits weekends. Bella chases her sometimes.",
5159
},
60+
{
61+
"role": "assistant",
62+
"content": "Ah, the classic dog-and-cat dynamic! Since Bella chases Whiskers, it might help to give them gradual supervised interactions or create safe zones for the cat—especially important as you settle into a new home in Chicago. Keeping Bella’s routine stable during the move could also reduce her urge to chase. How do they usually get along when Whiskers visits?",
63+
},
5264
]
5365

5466
questions = [
@@ -145,18 +157,25 @@ def start_conversation(self, user_id="test_user", mem_cube_id="test_cube", sessi
145157
print(f" User ID: {self.current_user_id}")
146158
print(f" Mem Cube ID: {self.current_mem_cube_id}")
147159

148-
def add_msgs(self, messages: list[dict]):
160+
def add_msgs(
161+
self,
162+
messages: list[dict],
163+
extract_mode: str = "fine",
164+
async_mode: str = "sync",
165+
):
149166
# Create add request
150167
add_req = self.create_test_add_request(
151168
user_id=self.current_user_id,
152169
mem_cube_id=self.current_mem_cube_id,
153170
messages=messages,
154171
session_id=self.current_session_id,
172+
extract_mode=extract_mode,
173+
async_mode=async_mode,
155174
)
156175

157176
# Add to memory
158177
result = self.add_memories(add_req)
159-
print(f" ✅ Added to memory successfully: \n{messages}")
178+
print(f" ✅ Added to memory successfully: \n{result}")
160179

161180
return result
162181

src/memos/mem_reader/read_multi_modal/file_content_parser.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ def create_source(
259259
chunk_index: int | None = None,
260260
chunk_total: int | None = None,
261261
chunk_content: str | None = None,
262+
file_url_flag: bool = False,
262263
) -> SourceMessage:
263264
"""Create SourceMessage from file content part."""
264265
if isinstance(message, dict):
@@ -267,6 +268,7 @@ def create_source(
267268
"type": "file",
268269
"doc_path": file_info.get("filename") or file_info.get("file_id", ""),
269270
"content": chunk_content if chunk_content else file_info.get("file_data", ""),
271+
"file_info": file_info if file_url_flag else {},
270272
}
271273
# Add chunk ordering information if provided
272274
if chunk_index is not None:
@@ -291,10 +293,7 @@ def rebuild_from_source(
291293
# Rebuild from source fields
292294
return {
293295
"type": "file",
294-
"file": {
295-
"filename": source.doc_path or "",
296-
"file_data": source.content or "",
297-
},
296+
"file": source.file_info,
298297
}
299298

300299
def _parse_file(self, file_info: dict[str, Any]) -> str:
@@ -367,7 +366,7 @@ def parse_fast(
367366
file_data = file_info.get("file_data", "")
368367
file_id = file_info.get("file_id", "")
369368
filename = file_info.get("filename", "")
370-
369+
file_url_flag = False
371370
# Build content string based on available information
372371
content_parts = []
373372

@@ -386,6 +385,7 @@ def parse_fast(
386385
content_parts.append(f"[File Data (base64/encoded): {len(file_data)} chars]")
387386
# Check if it looks like a URL
388387
elif file_data.startswith(("http://", "https://", "file://")):
388+
file_url_flag = True
389389
content_parts.append(f"[File URL: {file_data}]")
390390
else:
391391
# TODO: split into multiple memory items
@@ -437,6 +437,7 @@ def parse_fast(
437437
chunk_index=chunk_idx,
438438
chunk_total=total_chunks,
439439
chunk_content=chunk_text,
440+
file_url_flag=file_url_flag,
440441
)
441442

442443
memory_item = TextualMemoryItem(
@@ -473,6 +474,7 @@ def parse_fast(
473474
chunk_index=None,
474475
chunk_total=0,
475476
chunk_content=content,
477+
file_url_flag=file_url_flag,
476478
)
477479
memory_item = TextualMemoryItem(
478480
memory=content,

src/memos/mem_reader/read_multi_modal/user_parser.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def create_source(
8080
message_id=message_id,
8181
doc_path=file_info.get("filename") or file_info.get("file_id", ""),
8282
content=file_info.get("file_data", ""),
83+
file_info=file_info,
8384
)
8485
)
8586
elif part_type == "image_url":

src/memos/mem_reader/simple_struct.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import concurrent.futures
22
import copy
33
import json
4+
import os
45
import re
56
import traceback
67

@@ -25,6 +26,7 @@
2526
from memos.templates.mem_reader_prompts import (
2627
CUSTOM_TAGS_INSTRUCTION,
2728
CUSTOM_TAGS_INSTRUCTION_ZH,
29+
PROMPT_MAPPING,
2830
SIMPLE_STRUCT_DOC_READER_PROMPT,
2931
SIMPLE_STRUCT_DOC_READER_PROMPT_ZH,
3032
SIMPLE_STRUCT_MEM_READER_EXAMPLE,
@@ -80,6 +82,7 @@ def from_config(_config):
8082
"custom_tags": {"en": CUSTOM_TAGS_INSTRUCTION, "zh": CUSTOM_TAGS_INSTRUCTION_ZH},
8183
}
8284

85+
8386
try:
8487
import tiktoken
8588

@@ -448,6 +451,81 @@ def get_memory(
448451
standard_scene_data = coerce_scene_data(scene_data, type)
449452
return self._read_memory(standard_scene_data, type, info, mode)
450453

454+
@staticmethod
455+
def _parse_hallucination_filter_response(text: str) -> tuple[bool, dict[int, dict]]:
456+
"""Parse index-keyed JSON from hallucination filter response.
457+
Expected shape: { "0": {"if_delete": bool, "rewritten memory content": str}, ... }
458+
Returns (success, parsed_dict) with int keys.
459+
"""
460+
try:
461+
data = json.loads(text)
462+
except Exception:
463+
return False, {}
464+
465+
if not isinstance(data, dict):
466+
return False, {}
467+
468+
result: dict[int, dict] = {}
469+
for k, v in data.items():
470+
try:
471+
idx = int(k)
472+
except Exception:
473+
# allow integer keys as-is
474+
if isinstance(k, int):
475+
idx = k
476+
else:
477+
continue
478+
if not isinstance(v, dict):
479+
continue
480+
delete_flag = v.get("delete_flag")
481+
rewritten = v.get("rewritten memory content", "")
482+
if isinstance(delete_flag, bool) and isinstance(rewritten, str):
483+
result[idx] = {"delete_flag": delete_flag, "rewritten memory content": rewritten}
484+
485+
return (len(result) > 0), result
486+
487+
def filter_hallucination_in_memories(
488+
self, user_messages: list[str], memory_list: list[list[TextualMemoryItem]]
489+
):
490+
filtered_memory_list = []
491+
for group in memory_list:
492+
try:
493+
flat_memories = [one.memory for one in group]
494+
template = PROMPT_MAPPING["hallucination_filter"]
495+
prompt_args = {
496+
"user_messages_inline": "\n".join(user_messages),
497+
"memories_inline": json.dumps(flat_memories, ensure_ascii=False, indent=2),
498+
}
499+
prompt = template.format(**prompt_args)
500+
501+
# Optionally run filter and parse the output
502+
try:
503+
raw = self.llm.generate(prompt)
504+
success, parsed = self._parse_hallucination_filter_response(raw)
505+
logger.info(f"Hallucination filter parsed successfully: {success}")
506+
new_mem_list = []
507+
if success:
508+
logger.info(f"Hallucination filter result: {parsed}")
509+
for mem_idx, (delete_flag, rewritten_mem_content) in parsed.items():
510+
if not delete_flag:
511+
group[mem_idx].memory = rewritten_mem_content
512+
new_mem_list.append(group[mem_idx])
513+
filtered_memory_list.append(new_mem_list)
514+
logger.info(
515+
f"Successfully transform origianl memories from {group} to {new_mem_list}."
516+
)
517+
else:
518+
logger.warning(
519+
"Hallucination filter parsing failed or returned empty result."
520+
)
521+
except Exception as e:
522+
logger.error(f"Hallucination filter execution error: {e}", stack_info=True)
523+
filtered_memory_list.append(group)
524+
except Exception:
525+
logger.error("Fail to filter memories", stack_info=True)
526+
filtered_memory_list.append(group)
527+
return filtered_memory_list
528+
451529
def _read_memory(
452530
self, messages: list[MessagesType], type: str, info: dict[str, Any], mode: str = "fine"
453531
) -> list[list[TextualMemoryItem]]:
@@ -492,6 +570,14 @@ def _read_memory(
492570
except Exception as e:
493571
logger.error(f"Task failed with exception: {e}")
494572
logger.error(traceback.format_exc())
573+
574+
if os.getenv("SIMPLE_STRUCT_ADD_FILTER", "false") == "true":
575+
# Build inputs
576+
user_messages = [msg.content for msg in messages if msg.role == "user"]
577+
memory_list = self.filter_hallucination_in_memories(
578+
user_messages=user_messages, memory_list=memory_list
579+
)
580+
495581
return memory_list
496582

497583
def fine_transfer_simple_mem(

src/memos/mem_scheduler/analyzer/api_analyzer.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -599,6 +599,8 @@ def create_test_add_request(
599599
messages=None,
600600
memory_content=None,
601601
session_id=None,
602+
extract_mode=None,
603+
async_mode="sync",
602604
):
603605
"""
604606
Create a test APIADDRequest object with the given parameters.
@@ -637,6 +639,8 @@ def create_test_add_request(
637639
source="api_analyzer_test",
638640
chat_history=None,
639641
operation=None,
642+
mode=extract_mode,
643+
async_mode=async_mode,
640644
)
641645

642646
def run_all_tests(self, mode=SearchMode.MIXTURE):

src/memos/mem_scheduler/base_scheduler.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -140,12 +140,7 @@ def __init__(self, config: BaseSchedulerConfig):
140140
"max_internal_message_queue_size", DEFAULT_MAX_INTERNAL_MESSAGE_QUEUE_SIZE
141141
)
142142
self.orchestrator = SchedulerOrchestrator()
143-
self.memos_message_queue = ScheduleTaskQueue(
144-
use_redis_queue=self.use_redis_queue,
145-
maxsize=self.max_internal_message_queue_size,
146-
disabled_handlers=self.disabled_handlers,
147-
orchestrator=self.orchestrator,
148-
)
143+
149144
self.searcher: Searcher | None = None
150145
self.retriever: SchedulerRetriever | None = None
151146
self.db_engine: Engine | None = None
@@ -155,6 +150,13 @@ def __init__(self, config: BaseSchedulerConfig):
155150
self.status_tracker: TaskStatusTracker | None = None
156151
self.metrics = metrics
157152
self._monitor_thread = None
153+
self.memos_message_queue = ScheduleTaskQueue(
154+
use_redis_queue=self.use_redis_queue,
155+
maxsize=self.max_internal_message_queue_size,
156+
disabled_handlers=self.disabled_handlers,
157+
orchestrator=self.orchestrator,
158+
status_tracker=self.status_tracker,
159+
)
158160
self.dispatcher = SchedulerDispatcher(
159161
config=self.config,
160162
memos_message_queue=self.memos_message_queue,
@@ -228,6 +230,8 @@ def initialize_modules(
228230
self.status_tracker = TaskStatusTracker(redis_client)
229231
if self.dispatcher:
230232
self.dispatcher.status_tracker = self.status_tracker
233+
if self.memos_message_queue:
234+
self.memos_message_queue.status_tracker = self.status_tracker
231235
# initialize submodules
232236
self.chat_llm = chat_llm
233237
self.process_llm = process_llm
@@ -712,7 +716,13 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
712716
# emit enqueue events for consistency
713717
for m in immediate_msgs:
714718
emit_monitor_event(
715-
"enqueue", m, {"enqueue_ts": to_iso(getattr(m, "timestamp", None))}
719+
"enqueue",
720+
m,
721+
{
722+
"enqueue_ts": to_iso(getattr(m, "timestamp", None)),
723+
"event_duration_ms": 0,
724+
"total_duration_ms": 0,
725+
},
716726
)
717727

718728
# simulate dequeue for immediately dispatched messages so monitor logs stay complete
@@ -741,6 +751,8 @@ def submit_messages(self, messages: ScheduleMessageItem | list[ScheduleMessageIt
741751
"enqueue_ts": to_iso(enqueue_ts_obj),
742752
"dequeue_ts": datetime.fromtimestamp(now, tz=timezone.utc).isoformat(),
743753
"queue_wait_ms": queue_wait_ms,
754+
"event_duration_ms": queue_wait_ms,
755+
"total_duration_ms": queue_wait_ms,
744756
},
745757
)
746758
self.metrics.task_dequeued(user_id=m.user_id, task_type=m.label)
@@ -919,6 +931,8 @@ def _message_consumer(self) -> None:
919931
now, tz=timezone.utc
920932
).isoformat(),
921933
"queue_wait_ms": queue_wait_ms,
934+
"event_duration_ms": queue_wait_ms,
935+
"total_duration_ms": queue_wait_ms,
922936
},
923937
)
924938
self.metrics.task_dequeued(user_id=msg.user_id, task_type=msg.label)

src/memos/mem_scheduler/general_modules/misc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def put(self, item: T, block: bool = False, timeout: float | None = None) -> Non
233233

234234
def get(
235235
self, block: bool = True, timeout: float | None = None, batch_size: int | None = None
236-
) -> list[T] | T:
236+
) -> list[T]:
237237
"""Get items from the queue.
238238
239239
Args:

src/memos/mem_scheduler/schemas/task_schemas.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,9 @@ class TaskPriorityLevel(Enum):
6262

6363

6464
# task queue
65-
DEFAULT_STREAM_KEY_PREFIX = "scheduler:messages:stream:v1.7"
66-
exchange_name = os.getenv("MEMSCHEDULER_RABBITMQ_EXCHANGE_NAME", None)
67-
if exchange_name is not None:
68-
DEFAULT_STREAM_KEY_PREFIX += f":{exchange_name}"
65+
DEFAULT_STREAM_KEY_PREFIX = os.getenv(
66+
"MEMSCHEDULER_STREAM_KEY_PREFIX", "scheduler:messages:stream:v2.0"
67+
)
6968

7069

7170
# ============== Running Tasks ==============

src/memos/mem_scheduler/task_schedule_modules/dispatcher.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
185185
if isinstance(dequeue_ts, int | float)
186186
else None
187187
),
188+
"event_duration_ms": start_delay_ms,
189+
"total_duration_ms": self._calc_total_duration_ms(start_time, enq_ts),
188190
},
189191
)
190192

@@ -210,6 +212,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
210212
finish_time, tz=timezone.utc
211213
).isoformat(),
212214
"exec_duration_ms": duration * 1000,
215+
"event_duration_ms": duration * 1000,
213216
"total_duration_ms": self._calc_total_duration_ms(
214217
finish_time, getattr(first_msg, "timestamp", None)
215218
),
@@ -244,6 +247,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
244247
finish_time, tz=timezone.utc
245248
).isoformat(),
246249
"exec_duration_ms": (finish_time - start_time) * 1000,
250+
"event_duration_ms": (finish_time - start_time) * 1000,
247251
"error_type": type(e).__name__,
248252
"error_msg": str(e),
249253
"total_duration_ms": self._calc_total_duration_ms(
@@ -273,6 +277,7 @@ def wrapped_handler(messages: list[ScheduleMessageItem]):
273277
mem_cube_id=msg.mem_cube_id,
274278
task_label=msg.label,
275279
redis_message_id=redis_message_id,
280+
message=msg,
276281
)
277282
except Exception as ack_err:
278283
logger.warning(f"Ack in finally failed: {ack_err}")

0 commit comments

Comments
 (0)