Skip to content

Commit cc3ed0c

Browse files
committed
feat: use schema to structure mem_reader output
1 parent e069928 commit cc3ed0c

File tree

4 files changed

+55
-5
lines changed

4 files changed

+55
-5
lines changed

src/memos/llms/openai.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ def clear_cache(cls):
5656
cls._instances.clear()
5757
logger.info("OpenAI LLM instance cache cleared")
5858

59-
def generate(self, messages: MessageList) -> str:
59+
def generate(self, messages: MessageList, **kwargs) -> str:
6060
"""Generate a response from OpenAI LLM."""
6161
response = self.client.chat.completions.create(
6262
model=self.config.model_name_or_path,
@@ -65,6 +65,7 @@ def generate(self, messages: MessageList) -> str:
6565
temperature=self.config.temperature,
6666
max_tokens=self.config.max_tokens,
6767
top_p=self.config.top_p,
68+
**kwargs
6869
)
6970
logger.info(f"Response from OpenAI: {response.model_dump_json()}")
7071
response_content = response.choices[0].message.content

src/memos/llms/vllm.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -85,16 +85,16 @@ def build_vllm_kv_cache(self, messages: Any) -> str:
8585

8686
return prompt
8787

88-
def generate(self, messages: list[MessageDict]) -> str:
88+
def generate(self, messages: list[MessageDict], **kwargs) -> str:
8989
"""
9090
Generate a response from the model.
9191
"""
9292
if self.client:
93-
return self._generate_with_api_client(messages)
93+
return self._generate_with_api_client(messages, **kwargs)
9494
else:
9595
raise RuntimeError("API client is not available")
9696

97-
def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
97+
def _generate_with_api_client(self, messages: list[MessageDict], **kwargs) -> str:
9898
"""
9999
Generate response using vLLM API client.
100100
"""
@@ -106,6 +106,7 @@ def _generate_with_api_client(self, messages: list[MessageDict]) -> str:
106106
"max_tokens": int(getattr(self.config, "max_tokens", 1024)),
107107
"top_p": float(getattr(self.config, "top_p", 0.9)),
108108
"extra_body": {"chat_template_kwargs": {"enable_thinking": False}},
109+
**kwargs
109110
}
110111

111112
response = self.client.chat.completions.create(**completion_kwargs)

src/memos/mem_reader/simple_struct.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
SIMPLE_STRUCT_MEM_READER_EXAMPLE_ZH,
2828
SIMPLE_STRUCT_MEM_READER_PROMPT,
2929
SIMPLE_STRUCT_MEM_READER_PROMPT_ZH,
30+
reader_output_schema
3031
)
3132
from memos.utils import timed
3233

@@ -200,7 +201,9 @@ def _get_llm_response(self, mem_str: str) -> dict:
200201
prompt = prompt.replace(examples, "")
201202
messages = [{"role": "user", "content": prompt}]
202203
try:
203-
response_text = self.llm.generate(messages)
204+
response_text = self.llm.generate(messages,
205+
response_format={"type": "json_object",
206+
"schema": reader_output_schema})
204207
response_json = self.parse_json_result(response_text)
205208
except Exception as e:
206209
logger.error(f"[LLM] Exception during chat generation: {e}")

src/memos/templates/mem_reader_prompts.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -417,3 +417,48 @@
417417
}
418418
419419
"""
420+
421+
# JSON Schema (draft 2020-12) describing the structured output expected from
# the mem_reader LLM call: a list of memory entries plus an overall summary.
# Passed as the ``schema`` of a ``response_format`` argument so the LLM is
# constrained to emit parseable, well-shaped JSON.
# NOTE(review): the top-level property name "memory list" (with a space) must
# match the key produced by the reader prompt — do not "normalize" it.
reader_output_schema = {
    "$schema": "https://json-schema.org/draft/2020-12/schema",
    "type": "object",
    "properties": {
        "memory list": {
            "type": "array",
            "items": {
                "type": "object",
                "properties": {
                    "key": {
                        "type": "string",
                        "description": "A brief title or identifier for the memory.",
                    },
                    "memory_type": {
                        "type": "string",
                        "enum": ["LongTermMemory", "ShortTermMemory", "WorkingMemory"],
                        "description": "The type of memory, expected to be 'LongTermMemory' in this context.",
                    },
                    "value": {
                        "type": "string",
                        "description": "Detailed description of the memory, including viewpoint, time, and content.",
                    },
                    "tags": {
                        "type": "array",
                        "items": {"type": "string"},
                        "description": "List of keywords or categories associated with the memory.",
                    },
                },
                "required": ["key", "memory_type", "value", "tags"],
                "additionalProperties": False,
            },
            "description": "List of memory entries.",
        },
        "summary": {
            "type": "string",
            "description": "A synthesized summary of the overall situation based on all memories.",
        },
    },
    "required": ["memory list", "summary"],
    "additionalProperties": False,
    "description": "Structured output containing a list of memories and a summary.",
}

0 commit comments

Comments
 (0)