Skip to content

Commit 07a8994

Browse files
authored
fix: input Pydantic bug (#602)
* fix: input Pydantic bug
* feat: add image parser
* feat: back to MessagesType
1 parent 53aa48c commit 07a8994

File tree

13 files changed

+386
-39
lines changed

13 files changed

+386
-39
lines changed

src/memos/mem_feedback/feedback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from memos.mem_feedback.base import BaseMemFeedback
1818
from memos.mem_feedback.utils import should_keep_update, split_into_chunks
1919
from memos.mem_reader.factory import MemReaderFactory
20-
from memos.mem_reader.simple_struct import detect_lang
20+
from memos.mem_reader.read_multi_modal import detect_lang
2121
from memos.memories.textual.item import TextualMemoryItem, TreeNodeTextualMemoryMetadata
2222
from memos.memories.textual.tree_text_memory.organize.manager import (
2323
MemoryManager,

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,8 @@
77
from memos import log
88
from memos.configs.mem_reader import MultiModalStructMemReaderConfig
99
from memos.context.context import ContextThreadPoolExecutor
10-
from memos.mem_reader.read_multi_modal import MultiModalParser
11-
from memos.mem_reader.simple_struct import SimpleStructMemReader, detect_lang
10+
from memos.mem_reader.read_multi_modal import MultiModalParser, detect_lang
11+
from memos.mem_reader.simple_struct import SimpleStructMemReader
1212
from memos.memories.textual.item import TextualMemoryItem
1313
from memos.templates.tool_mem_prompts import TOOL_TRAJECTORY_PROMPT_EN, TOOL_TRAJECTORY_PROMPT_ZH
1414
from memos.types import MessagesType

src/memos/mem_reader/read_multi_modal/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from .text_content_parser import TextContentParser
2424
from .tool_parser import ToolParser
2525
from .user_parser import UserParser
26-
from .utils import coerce_scene_data, extract_role
26+
from .utils import coerce_scene_data, detect_lang, extract_role
2727

2828

2929
__all__ = [
@@ -38,5 +38,6 @@
3838
"ToolParser",
3939
"UserParser",
4040
"coerce_scene_data",
41+
"detect_lang",
4142
"extract_role",
4243
]

src/memos/mem_reader/read_multi_modal/image_parser.py

Lines changed: 271 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
"""Parser for image_url content parts."""
22

3+
import json
4+
import re
5+
36
from typing import Any
47

58
from memos.embedders.base import BaseEmbedder
69
from memos.llms.base import BaseLLM
710
from memos.log import get_logger
8-
from memos.memories.textual.item import SourceMessage, TextualMemoryItem
11+
from memos.memories.textual.item import (
12+
SourceMessage,
13+
TextualMemoryItem,
14+
TreeNodeTextualMemoryMetadata,
15+
)
16+
from memos.templates.mem_reader_prompts import IMAGE_ANALYSIS_PROMPT_EN, IMAGE_ANALYSIS_PROMPT_ZH
917
from memos.types.openai_chat_completion_types import ChatCompletionContentPartImageParam
1018

11-
from .base import BaseMessageParser
19+
from .base import BaseMessageParser, _derive_key
20+
from .utils import detect_lang
1221

1322

1423
logger = get_logger(__name__)
@@ -43,7 +52,7 @@ def create_source(
4352
detail = "auto"
4453
return SourceMessage(
4554
type="image",
46-
content=f"[image_url]: {url}",
55+
content=url,
4756
original_part=message,
4857
url=url,
4958
detail=detail,
@@ -87,7 +96,262 @@ def parse_fine(
8796
info: dict[str, Any],
8897
**kwargs,
8998
) -> list[TextualMemoryItem]:
90-
"""Parse image_url in fine mode - placeholder for future vision model integration."""
91-
# Fine mode processing would use vision models to extract text from images
92-
# For now, return empty list
93-
return []
99+
"""
100+
Parse image_url in fine mode using vision models to extract information from images.
101+
102+
Args:
103+
message: Image message to parse
104+
info: Dictionary containing user_id and session_id
105+
**kwargs: Additional parameters (e.g., context_items, custom_tags)
106+
107+
Returns:
108+
List of TextualMemoryItem objects extracted from the image
109+
"""
110+
if not self.llm:
111+
logger.warning("[ImageParser] LLM not available for fine mode processing")
112+
return []
113+
114+
# Extract image information
115+
if not isinstance(message, dict):
116+
logger.warning(f"[ImageParser] Expected dict, got {type(message)}")
117+
return []
118+
119+
image_url = message.get("image_url", {})
120+
if isinstance(image_url, dict):
121+
url = image_url.get("url", "")
122+
detail = image_url.get("detail", "auto")
123+
else:
124+
url = str(image_url)
125+
detail = "auto"
126+
127+
if not url:
128+
logger.warning("[ImageParser] No image URL found in message")
129+
return []
130+
131+
# Create source for this image
132+
source = self.create_source(message, info)
133+
134+
# Get context items if available
135+
context_items = kwargs.get("context_items")
136+
137+
# Determine language from context if available
138+
lang = "en"
139+
if context_items:
140+
for item in context_items:
141+
if hasattr(item, "memory") and item.memory:
142+
lang = detect_lang(item.memory)
143+
break
144+
145+
# Select prompt based on language
146+
image_analysis_prompt = (
147+
IMAGE_ANALYSIS_PROMPT_ZH if lang == "zh" else IMAGE_ANALYSIS_PROMPT_EN
148+
)
149+
150+
# Build messages with image content
151+
messages = [
152+
{
153+
"role": "user",
154+
"content": [
155+
{"type": "text", "text": image_analysis_prompt},
156+
{
157+
"type": "image_url",
158+
"image_url": {
159+
"url": url,
160+
"detail": detail,
161+
},
162+
},
163+
],
164+
}
165+
]
166+
167+
# Add context if available
168+
if context_items:
169+
context_text = ""
170+
for item in context_items:
171+
if hasattr(item, "memory") and item.memory:
172+
context_text += f"{item.memory}\n"
173+
if context_text:
174+
messages.insert(
175+
0,
176+
{
177+
"role": "system",
178+
"content": f"Context from previous conversation:\n{context_text}",
179+
},
180+
)
181+
182+
try:
183+
# Call LLM with vision model
184+
response_text = self.llm.generate(messages)
185+
if not response_text:
186+
logger.warning("[ImageParser] Empty response from LLM")
187+
return []
188+
189+
# Parse JSON response
190+
response_json = self._parse_json_result(response_text)
191+
192+
# Extract memory items from response
193+
memory_items = []
194+
memory_list = response_json.get("memory list", [])
195+
196+
if not memory_list:
197+
logger.warning("[ImageParser] No memory items extracted from image")
198+
# Fallback: create a simple memory item with the summary
199+
summary = response_json.get(
200+
"summary", "Image analyzed but no specific memories extracted."
201+
)
202+
if summary:
203+
memory_items.append(
204+
self._create_memory_item(
205+
value=summary,
206+
info=info,
207+
memory_type="LongTermMemory",
208+
tags=["image", "visual"],
209+
key=_derive_key(summary),
210+
sources=[source],
211+
background=summary,
212+
)
213+
)
214+
return memory_items
215+
216+
# Create memory items from parsed response
217+
for mem_data in memory_list:
218+
try:
219+
# Normalize memory_type
220+
memory_type = (
221+
mem_data.get("memory_type", "LongTermMemory")
222+
.replace("长期记忆", "LongTermMemory")
223+
.replace("用户记忆", "UserMemory")
224+
)
225+
if memory_type not in ["LongTermMemory", "UserMemory"]:
226+
memory_type = "LongTermMemory"
227+
228+
value = mem_data.get("value", "").strip()
229+
if not value:
230+
continue
231+
232+
tags = mem_data.get("tags", [])
233+
if not isinstance(tags, list):
234+
tags = []
235+
# Add image-related tags
236+
if "image" not in [t.lower() for t in tags]:
237+
tags.append("image")
238+
if "visual" not in [t.lower() for t in tags]:
239+
tags.append("visual")
240+
241+
key = mem_data.get("key", "")
242+
background = response_json.get("summary", "")
243+
244+
memory_item = self._create_memory_item(
245+
value=value,
246+
info=info,
247+
memory_type=memory_type,
248+
tags=tags,
249+
key=key if key else _derive_key(value),
250+
sources=[source],
251+
background=background,
252+
)
253+
memory_items.append(memory_item)
254+
except Exception as e:
255+
logger.error(f"[ImageParser] Error creating memory item: {e}")
256+
continue
257+
258+
return memory_items
259+
260+
except Exception as e:
261+
logger.error(f"[ImageParser] Error processing image in fine mode: {e}")
262+
# Fallback: create a simple memory item
263+
fallback_value = f"Image analyzed: {url}"
264+
return [
265+
self._create_memory_item(
266+
value=fallback_value,
267+
info=info,
268+
memory_type="LongTermMemory",
269+
tags=["image", "visual"],
270+
key=_derive_key(fallback_value),
271+
sources=[source],
272+
background="Image processing encountered an error.",
273+
)
274+
]
275+
276+
def _parse_json_result(self, response_text: str) -> dict:
277+
"""
278+
Parse JSON result from LLM response.
279+
Similar to SimpleStructMemReader.parse_json_result.
280+
"""
281+
s = (response_text or "").strip()
282+
283+
# Try to extract JSON from code blocks
284+
m = re.search(r"```(?:json)?\s*([\s\S]*?)```", s, flags=re.I)
285+
s = (m.group(1) if m else s.replace("```", "")).strip()
286+
287+
# Find first {
288+
i = s.find("{")
289+
if i == -1:
290+
return {}
291+
s = s[i:].strip()
292+
293+
try:
294+
return json.loads(s)
295+
except json.JSONDecodeError:
296+
pass
297+
298+
# Try to find the last } or ]
299+
j = max(s.rfind("}"), s.rfind("]"))
300+
if j != -1:
301+
try:
302+
return json.loads(s[: j + 1])
303+
except json.JSONDecodeError:
304+
pass
305+
306+
# Try to close brackets
307+
def _cheap_close(t: str) -> str:
308+
t += "}" * max(0, t.count("{") - t.count("}"))
309+
t += "]" * max(0, t.count("[") - t.count("]"))
310+
return t
311+
312+
t = _cheap_close(s)
313+
try:
314+
return json.loads(t)
315+
except json.JSONDecodeError as e:
316+
if "Invalid \\escape" in str(e):
317+
s = s.replace("\\", "\\\\")
318+
try:
319+
return json.loads(s)
320+
except json.JSONDecodeError:
321+
pass
322+
logger.error(f"[ImageParser] Failed to parse JSON: {e}\nResponse: {response_text}")
323+
return {}
324+
325+
def _create_memory_item(
    self,
    value: str,
    info: dict[str, Any],
    memory_type: str,
    tags: list[str],
    key: str,
    sources: list[SourceMessage],
    background: str = "",
) -> TextualMemoryItem:
    """
    Assemble a TextualMemoryItem for a piece of text extracted from an image.

    Args:
        value: Memory text; also the text that gets embedded.
        info: Request info. "user_id" and "session_id" are lifted into the
            metadata; every other entry is stored under metadata.info.
        memory_type: Memory type stored in the metadata.
        tags: Tags stored in the metadata.
        key: Key stored in the metadata.
        sources: Source messages the memory was derived from.
        background: Optional background text stored in the metadata.

    Returns:
        The newly built TextualMemoryItem.
    """
    # Pop from a copy so the caller's dict is left untouched
    extra_info = info.copy()
    owner_id = extra_info.pop("user_id", "")
    conversation_id = extra_info.pop("session_id", "")

    # embed() takes a batch; a single value goes in and the first result is used
    vector = self.embedder.embed([value])[0]

    metadata = TreeNodeTextualMemoryMetadata(
        user_id=owner_id,
        session_id=conversation_id,
        memory_type=memory_type,
        status="activated",
        tags=tags,
        key=key,
        embedding=vector,
        usage=[],
        sources=sources,
        background=background,
        confidence=0.99,
        type="fact",
        info=extra_info,
    )
    return TextualMemoryItem(memory=value, metadata=metadata)

src/memos/mem_reader/read_multi_modal/multi_modal_parser.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,6 +226,8 @@ def process_transfer(
226226
parser = self.file_content_parser
227227
elif source.type == "text":
228228
parser = self.text_content_parser
229+
elif source.type in ["image", "image_url"]:
230+
parser = self.image_parser
229231
elif source.role:
230232
# Chat message, use role parser
231233
parser = self.role_parsers.get(source.role)

src/memos/mem_reader/read_multi_modal/user_parser.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,8 +85,20 @@ def create_source(
8585
original_part=part,
8686
)
8787
)
88+
elif part_type == "image_url":
89+
image_info = part.get("image_url", {})
90+
sources.append(
91+
SourceMessage(
92+
type="image",
93+
role=role,
94+
chat_time=chat_time,
95+
message_id=message_id,
96+
image_path=image_info.get("url"),
97+
original_part=part,
98+
)
99+
)
88100
else:
89-
# image_url, input_audio, etc.
101+
# input_audio, etc.
90102
sources.append(
91103
SourceMessage(
92104
type=part_type,

0 commit comments

Comments
 (0)