Commit 34cc741

feat: add language detection

1 parent: 5e23d37

1 file changed: 46 additions, 7 deletions


src/memos/mem_reader/multi_modal_struct.py

@@ -249,7 +249,11 @@ def _build_window_from_items(
         return aggregated_item
 
     def _get_llm_response(
-        self, mem_str: str, custom_tags: list[str] | None = None, sources: list | None = None
+        self,
+        mem_str: str,
+        custom_tags: list[str] | None = None,
+        sources: list | None = None,
+        prompt_type: str = "chat",
     ) -> dict:
         """
         Override parent method to improve language detection by using actual text content
@@ -259,6 +263,7 @@ def _get_llm_response(
             mem_str: Memory string (may contain JSON structures)
             custom_tags: Optional custom tags
             sources: Optional list of SourceMessage objects to extract text content from
+            prompt_type: Type of prompt to use ("chat" or "doc")
 
         Returns:
             LLM response dictionary
@@ -279,18 +284,30 @@ def _get_llm_response(
 
         # Use the extracted text for language detection
         lang = detect_lang(text_for_lang_detection)
-        template = PROMPT_DICT["chat"][lang]
-        examples = PROMPT_DICT["chat"][f"{lang}_example"]
-        prompt = template.replace("${conversation}", mem_str)
+
+        # Select prompt template based on prompt_type
+        if prompt_type == "doc":
+            template = PROMPT_DICT["doc"][lang]
+            examples = ""  # doc prompts don't have examples
+            prompt = template.replace("{chunk_text}", mem_str)
+        else:
+            template = PROMPT_DICT["chat"][lang]
+            examples = PROMPT_DICT["chat"][f"{lang}_example"]
+            prompt = template.replace("${conversation}", mem_str)
 
         custom_tags_prompt = (
             PROMPT_DICT["custom_tags"][lang].replace("{custom_tags}", str(custom_tags))
             if custom_tags
             else ""
         )
-        prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt)
 
-        if self.config.remove_prompt_example:
+        # Replace custom_tags_prompt placeholder (different for doc vs chat)
+        if prompt_type == "doc":
+            prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt)
+        else:
+            prompt = prompt.replace("${custom_tags_prompt}", custom_tags_prompt)
+
+        if self.config.remove_prompt_example and examples:
            prompt = prompt.replace(examples, "")
         messages = [{"role": "user", "content": prompt}]
         try:
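Note that the two branches use different placeholder conventions: doc templates use plain braces ({chunk_text}, {custom_tags_prompt}), while chat templates use the dollar-prefixed form (${conversation}, ${custom_tags_prompt}). A minimal sketch of the PROMPT_DICT shape these lookups assume; the template text here is illustrative only, not the project's actual prompt definitions:

# Illustrative-only sketch of the PROMPT_DICT shape implied by the lookups
# above; the real prompt text lives elsewhere in the project.
PROMPT_DICT = {
    "chat": {
        "en": "Extract memories from this conversation:\n${conversation}\n${custom_tags_prompt}",
        "en_example": "Example:\nuser: ...\nassistant: ...",
    },
    "doc": {
        # doc templates use plain-brace placeholders and carry no examples
        "en": "Extract memories from this document chunk:\n{chunk_text}\n{custom_tags_prompt}",
    },
    "custom_tags": {
        "en": "Prefer these custom tags when labeling: {custom_tags}",
    },
}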
@@ -311,6 +328,24 @@ def _get_llm_response(
             }
         return response_json
 
+    def _determine_prompt_type(self, sources: list) -> str:
+        """
+        Determine prompt type based on sources.
+        """
+        if not sources:
+            return "chat"
+        prompt_type = "doc"
+        for source in sources:
+            source_role = None
+            if hasattr(source, "role"):
+                source_role = source.role
+            elif isinstance(source, dict):
+                source_role = source.get("role")
+            if source_role in {"user", "assistant", "system", "tool"}:
+                prompt_type = "chat"
+
+        return prompt_type
+
     def _process_string_fine(
         self,
         fast_memory_items: list[TextualMemoryItem],
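The helper's rule: with no sources the default is "chat"; otherwise it starts from "doc" and flips to "chat" as soon as any source carries a conversational role. A standalone, runnable restatement of that rule for illustration (it mirrors _determine_prompt_type above but is not the project's code):

# Standalone restatement of the role-based rule from _determine_prompt_type.
def determine_prompt_type(sources: list) -> str:
    if not sources:
        return "chat"
    prompt_type = "doc"
    for source in sources:
        # dict sources expose "role" as a key; object sources as an attribute
        role = source.get("role") if isinstance(source, dict) else getattr(source, "role", None)
        if role in {"user", "assistant", "system", "tool"}:
            prompt_type = "chat"
    return prompt_type

print(determine_prompt_type([{"role": "user", "content": "hello"}]))       # chat
print(determine_prompt_type([{"type": "file", "doc_path": "report.pdf"}]))  # doc
print(determine_prompt_type([]))                                            # chat (empty default)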
@@ -333,8 +368,12 @@ def _process_string_fine(
             sources = fast_item.metadata.sources or []
             if not isinstance(sources, list):
                 sources = [sources]
+
+            # Determine prompt type based on sources
+            prompt_type = self._determine_prompt_type(sources)
+
             try:
-                resp = self._get_llm_response(mem_str, custom_tags, sources)
+                resp = self._get_llm_response(mem_str, custom_tags, sources, prompt_type)
             except Exception as e:
                 logger.error(f"[MultiModalFine] Error calling LLM: {e}")
                 continue
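Putting the pieces together for a doc-classified item: the doc template gets {chunk_text} and {custom_tags_prompt} substituted, and the examples-removal step is skipped because examples is empty. A runnable toy walkthrough of that branch (the template and values are illustrative, not the project's prompts):

# Toy walkthrough of the doc branch's substitution mechanics.
template = "Extract memories from this document chunk:\n{chunk_text}\n{custom_tags_prompt}"
custom_tags_prompt = "Prefer these custom tags when labeling: ['finance']"

prompt = template.replace("{chunk_text}", "Q3 revenue grew 12% year over year.")
prompt = prompt.replace("{custom_tags_prompt}", custom_tags_prompt)

# examples is "" for doc prompts, so the `and examples` guard skips the
# prompt.replace(examples, "") step entirely.
examples = ""
remove_prompt_example = True
if remove_prompt_example and examples:
    prompt = prompt.replace(examples, "")

print(prompt)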
