Skip to content

Commit b6efc87

Browse files
author
yuan.wang
committed
modify multi-modal code
1 parent d0cea34 commit b6efc87

File tree

7 files changed

+34
-22
lines changed

7 files changed

+34
-22
lines changed

src/memos/mem_reader/read_multi_modal/system_parser.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,21 @@ def __init__(self, embedder: BaseEmbedder, llm: BaseLLM | None = None):
3737

3838
def create_source(
3939
self,
40-
message: str,
40+
message: ChatCompletionSystemMessageParam,
4141
info: dict[str, Any],
4242
) -> SourceMessage:
4343
"""Create SourceMessage from system message."""
44+
content = message["content"]
45+
if isinstance(content, dict):
46+
content = content["text"]
47+
4448
content_wo_tool_schema = re.sub(
4549
r"<tool_schema>(.*?)</tool_schema>",
4650
r"<tool_schema>omitted</tool_schema>",
47-
message,
51+
content,
4852
flags=re.DOTALL,
4953
)
50-
tool_schema_match = re.search(r"<tool_schema>(.*?)</tool_schema>", message, re.DOTALL)
54+
tool_schema_match = re.search(r"<tool_schema>(.*?)</tool_schema>", content, re.DOTALL)
5155
tool_schema_content = tool_schema_match.group(1) if tool_schema_match else ""
5256

5357
return SourceMessage(
@@ -90,7 +94,7 @@ def parse_fast(
9094
flags=re.DOTALL,
9195
)
9296

93-
source = self.create_source(content, info)
97+
source = self.create_source(message, info)
9498
return [
9599
TextualMemoryItem(
96100
memory=content_wo_tool_schema,
@@ -125,9 +129,10 @@ def parse_fine(
125129
return [
126130
TextualMemoryItem(
127131
id=str(uuid.uuid4()),
128-
memory=json.dumps(tool_schema),
132+
memory=json.dumps(schema),
129133
metadata=TreeNodeTextualMemoryMetadata(
130134
memory_type="ToolSchemaMemory",
131135
),
132136
)
137+
for schema in tool_schema
133138
]

src/memos/mem_reader/read_multi_modal/tool_parser.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -230,26 +230,28 @@ def parse_fast(
230230
content = message.get("content", "")
231231
chat_time = message.get("chat_time", None)
232232

233-
if role != "user":
234-
logger.warning(f"[ToolParser] Expected role is `user`, got {role}")
233+
if role != "tool":
234+
logger.warning(f"[ToolParser] Expected role is `tool`, got {role}")
235235
return []
236236
parts = [f"{role}: "]
237237
if chat_time:
238238
parts.append(f"[{chat_time}]: ")
239239
prefix = "".join(parts)
240-
content = json.dumps(content) if isinstance(content, list) else content
240+
content = json.dumps(content) if isinstance(content, list | dict) else content
241241
line = f"{prefix}{content}\n"
242242
if not line:
243243
return []
244-
memory_type = (
245-
"LongTermMemory" # only choce long term memory for tool messages as a placeholder
246-
)
247244

248245
sources = self.create_source(message, info)
249246
return [
250247
TextualMemoryItem(
251248
memory=line,
252-
metadata=TreeNodeTextualMemoryMetadata(memory_type=memory_type, sources=sources),
249+
metadata=TreeNodeTextualMemoryMetadata(
250+
memory_type="LongTermMemory", # only choce long term memory for tool messages as a placeholder
251+
status="activated",
252+
tags=["mode:fast"],
253+
sources=sources,
254+
),
253255
)
254256
]
255257

src/memos/memories/textual/tree_text_memory/organize/manager.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,12 @@ def _process_memory(self, memory: TextualMemoryItem, user_name: str | None = Non
186186
)
187187
futures.append(("working", f_working))
188188

189-
if memory.metadata.memory_type in ("LongTermMemory", "UserMemory"):
189+
if memory.metadata.memory_type in (
190+
"LongTermMemory",
191+
"UserMemory",
192+
"ToolSchemaMemory",
193+
"ToolTrajectoryMemory",
194+
):
190195
f_graph = ex.submit(
191196
self._add_to_graph_memory,
192197
memory=memory,

src/memos/types/openai_chat_completion_types/chat_completion_assistant_message_param.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterable
65
from typing import Literal, TypeAlias
76

87
from typing_extensions import Required, TypedDict
@@ -35,7 +34,7 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False):
3534
[Learn more](https://platform.openai.com/docs/guides/audio).
3635
"""
3736

38-
content: str | Iterable[ContentArrayOfContentPart] | None
37+
content: str | list[ContentArrayOfContentPart] | ContentArrayOfContentPart | None
3938
"""The contents of the assistant message.
4039
4140
Required unless `tool_calls` or `function_call` is specified.
@@ -44,7 +43,9 @@ class ChatCompletionAssistantMessageParam(TypedDict, total=False):
4443
refusal: str | None
4544
"""The refusal message by the assistant."""
4645

47-
tool_calls: Iterable[ChatCompletionMessageToolCallUnionParam]
46+
tool_calls: (
47+
list[ChatCompletionMessageToolCallUnionParam] | ChatCompletionMessageToolCallUnionParam
48+
)
4849
"""The tool calls generated by the model, such as function calls."""
4950

5051
chat_time: str | None

src/memos/types/openai_chat_completion_types/chat_completion_system_message_param.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterable
65
from typing import Literal
76

87
from typing_extensions import Required, TypedDict
@@ -14,7 +13,9 @@
1413

1514

1615
class ChatCompletionSystemMessageParam(TypedDict, total=False):
17-
content: Required[str | Iterable[ChatCompletionContentPartTextParam]]
16+
content: Required[
17+
str | list[ChatCompletionContentPartTextParam] | ChatCompletionContentPartTextParam
18+
]
1819
"""The contents of the system message."""
1920

2021
role: Required[Literal["system"]]

src/memos/types/openai_chat_completion_types/chat_completion_tool_message_param.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterable
65
from typing import Literal
76

87
from typing_extensions import Required, TypedDict
@@ -14,7 +13,7 @@
1413

1514

1615
class ChatCompletionToolMessageParam(TypedDict, total=False):
17-
content: Required[str | Iterable[ChatCompletionContentPartParam]]
16+
content: Required[str | list[ChatCompletionContentPartParam] | ChatCompletionContentPartParam]
1817
"""The contents of the tool message."""
1918

2019
role: Required[Literal["tool"]]

src/memos/types/openai_chat_completion_types/chat_completion_user_message_param.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from __future__ import annotations
44

5-
from collections.abc import Iterable
65
from typing import Literal
76

87
from typing_extensions import Required, TypedDict
@@ -14,7 +13,7 @@
1413

1514

1615
class ChatCompletionUserMessageParam(TypedDict, total=False):
17-
content: Required[str | Iterable[ChatCompletionContentPartParam]]
16+
content: Required[str | list[ChatCompletionContentPartParam] | ChatCompletionContentPartParam]
1817
"""The contents of the user message."""
1918

2019
role: Required[Literal["user"]]

0 commit comments

Comments
 (0)