Commit 0f5f2ef

Feat/sources (#616)
* fix: input Pydantic bug
* feat: add image parser
* feat: back to MessagesType
* fix: other-reader bug
* feat: update language detection in string-fine of multi-modal-struct
* feat: add language detection
1 parent a72384b commit 0f5f2ef

File tree

2 files changed: +106 -3 lines changed

src/memos/configs/mem_reader.py

Lines changed: 2 additions & 1 deletion

@@ -44,7 +44,6 @@ def parse_datetime(cls, value):
 class SimpleStructMemReaderConfig(BaseMemReaderConfig):
     """SimpleStruct MemReader configuration class."""
 
-    # Allow passing additional fields without raising validation errors
     model_config = ConfigDict(extra="allow", strict=True)
 
 
@@ -61,6 +60,8 @@ class MultiModalStructMemReaderConfig(BaseMemReaderConfig):
 class StrategyStructMemReaderConfig(BaseMemReaderConfig):
     """StrategyStruct MemReader configuration class."""
 
+    model_config = ConfigDict(extra="allow", strict=True)
+
 
 class MemReaderConfigFactory(BaseConfig):
     """Factory class for creating MemReader configurations."""

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 104 additions & 2 deletions

@@ -8,7 +8,7 @@
 from memos.configs.mem_reader import MultiModalStructMemReaderConfig
 from memos.context.context import ContextThreadPoolExecutor
 from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
-from memos.mem_reader.simple_struct import SimpleStructMemReader
+from memos.mem_reader.simple_struct import PROMPT_DICT, SimpleStructMemReader
 from memos.memories.textual.item import TextualMemoryItem
 from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
 from memos.types import MessagesType
@@ -248,6 +248,104 @@ def _build_window_from_items(
 
         return aggregated_item
 
+    def _get_llm_response(
+        self,
+        mem_str: str,
+        custom_tags: list[str] | None = None,
+        sources: list | None = None,
+        prompt_type: str = "chat",
+    ) -> dict:
+        """
+        Override parent method to improve language detection by using actual text content
+        from sources instead of JSON-structured memory string.
+
+        Args:
+            mem_str: Memory string (may contain JSON structures)
+            custom_tags: Optional custom tags
+            sources: Optional list of SourceMessage objects to extract text content from
+            prompt_type: Type of prompt to use ("chat" or "doc")
+
+        Returns:
+            LLM response dictionary
+        """
+        # Try to extract actual text content from sources for better language detection
+        text_for_lang_detection = mem_str
+        if sources:
+            source_texts = []
+            for source in sources:
+                if hasattr(source, "content") and source.content:
+                    source_texts.append(source.content)
+                elif isinstance(source, dict) and source.get("content"):
+                    source_texts.append(source.get("content"))
+
+            # If we have text content from sources, use it for language detection
+            if source_texts:
+                text_for_lang_detection = " ".join(source_texts)
+
+        # Use the extracted text for language detection
+        lang = detect_lang(text_for_lang_detection)
+
+        # Select prompt template based on prompt_type
+        if prompt_type == "doc":
+            template = PROMPT_DICT["doc"][lang]
+            examples = ""  # doc prompts don't have examples
+            prompt = template.replace("{chunk_text}", mem_str)
+        else:
+            template = PROMPT_DICT["chat"][lang]
+            examples = PROMPT_DICT["chat"][f"{lang}_example"]
+            prompt = template.replace("${conversation}", mem_str)
+
+        custom_tags_prompt = (
+            PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
+            if custom_tags
+            else ""
+        )
+
+        # Replace custom_tags_prompt placeholder (different for doc vs chat)
+        if prompt_type == "doc":
+            prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt)
+        else:
+            prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt)
+
+        if self.config.remove_prompt_example and examples:
+            prompt = prompt.replace(examples, "")
+        messages = [{"role": "user", "content": prompt}]
+        try:
+            response_text = self.llm.generate(messages)
+            response_json = self.parse_json_result(response_text)
+        except Exception as e:
+            logger.error(f"[LLM] Exception during chat generation: {e}")
+            response_json = {
+                "memory list": [
+                    {
+                        "key": mem_str[:10],
+                        "memory_type": "UserMemory",
+                        "value": mem_str,
+                        "tags": [],
+                    }
+                ],
+                "summary": mem_str,
+            }
+        return response_json
+
+    def _determine_prompt_type(self, sources: list) -> str:
+        """
+        Determine prompt type based on sources.
+        """
+        if not sources:
+            return "chat"
+        prompt_type = "doc"
+        for source in sources:
+            source_role = None
+            if hasattr(source, "role"):
+                source_role = source.role
+            elif isinstance(source, dict):
+                source_role = source.get("role")
+            if source_role in {"user", "assistant", "system", "tool"}:
+                prompt_type = "chat"
+
+        return prompt_type
+
     def _process_string_fine(
         self,
         fast_memory_items: list[TextualMemoryItem],
@@ -270,8 +368,12 @@ def _process_string_fine(
             sources = fast_item.metadata.sources or []
             if not isinstance(sources, list):
                 sources = [sources]
+
+            # Determine prompt type based on sources
+            prompt_type = self._determine_prompt_type(sources)
+
             try:
-                resp = self._get_llm_response(mem_str, custom_tags)
+                resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
             except Exception as e:
                 logger.error(f"[MultiModalFine] Error calling LLM: {e}")
                 continue
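Taken together, the two new helpers route each fast memory item to the right prompt and detect its language from raw source text rather than from the JSON-shaped mem_str, whose ASCII keys and punctuation can skew detection toward English. Below is a self-contained sketch of the same logic; determine_prompt_type, text_for_lang_detection, and CHAT_ROLES are illustrative stand-ins (the real versions live as methods on the reader class and also accept SourceMessage objects, not just dicts).

CHAT_ROLES = {"user", "assistant", "system", "tool"}

def determine_prompt_type(sources: list) -> str:
    # Mirrors _determine_prompt_type: any source carrying a conversational
    # role selects the chat prompt; role-less sources (e.g. parsed files)
    # select the doc prompt; no sources at all defaults to chat.
    if not sources:
        return "chat"
    prompt_type = "doc"
    for source in sources:
        role = source.get("role") if isinstance(source, dict) else getattr(source, "role", None)
        if role in CHAT_ROLES:
            prompt_type = "chat"
    return prompt_type

def text_for_lang_detection(mem_str: str, sources: list | None) -> str:
    # Mirrors the override in _get_llm_response: prefer joined source
    # content over the JSON-structured memory string when any is available.
    texts = [
        s.get("content")
        for s in (sources or [])
        if isinstance(s, dict) and s.get("content")
    ]
    return " ".join(texts) if texts else mem_str

sources = [{"role": "user", "content": "你好,请帮我总结这份文档"}]
print(determine_prompt_type(sources))                            # chat
print(text_for_lang_detection('{"memory list": []}', sources))   # 你好,请帮我总结这份文档

Note that the chat default when sources are absent preserves the parent class's behavior for plain conversation input.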
