Skip to content

Commit e56d39a

Browse files
committed
feat: change and add config for multi-memreader
1 parent 2048c8f commit e56d39a

File tree

6 files changed

+186
-11
lines changed

6 files changed

+186
-11
lines changed

src/memos/api/config.py

Lines changed: 111 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ def get_activation_config() -> dict[str, Any]:
321321

322322
@staticmethod
323323
def get_memreader_config() -> dict[str, Any]:
324-
"""Get MemReader configuration."""
324+
"""Get MemReader configuration for chat/doc extraction (fine-tuned 0.6B model)."""
325325
return {
326326
"backend": "openai",
327327
"config": {
@@ -338,6 +338,107 @@ def get_memreader_config() -> dict[str, Any]:
338338
},
339339
}
340340

341+
@staticmethod
342+
def get_memreader_general_llm_config() -> dict[str, Any]:
343+
"""Get general LLM configuration for non-chat/doc tasks.
344+
345+
Used for: hallucination filter, memory rewrite, memory merge,
346+
tool trajectory extraction, skill memory extraction.
347+
348+
This is the fallback for image_parser_llm and preference_extractor_llm.
349+
Fallback chain: MEMREADER_GENERAL_MODEL -> MEMREADER_MODEL (memreader config)
350+
351+
Note: If you have fine-tuned a custom model for chat/doc extraction only,
352+
you should configure MEMREADER_GENERAL_MODEL to use a general-purpose LLM
353+
for other tasks. Otherwise, all tasks will use the same MEMREADER_MODEL.
354+
"""
355+
# Check if specific general model is configured
356+
general_model = os.getenv("MEMREADER_GENERAL_MODEL")
357+
if general_model:
358+
return {
359+
"backend": os.getenv("MEMREADER_GENERAL_BACKEND", "openai"),
360+
"config": {
361+
"model_name_or_path": general_model,
362+
"temperature": 0.6,
363+
"max_tokens": int(os.getenv("MEMREADER_GENERAL_MAX_TOKENS", "8000")),
364+
"top_p": 0.95,
365+
"top_k": 20,
366+
"api_key": os.getenv(
367+
"MEMREADER_GENERAL_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
368+
),
369+
"api_base": os.getenv(
370+
"MEMREADER_GENERAL_API_BASE",
371+
os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
372+
),
373+
"remove_think_prefix": True,
374+
},
375+
}
376+
# Fallback to memreader config (same behavior as before for users who don't customize)
377+
return APIConfig.get_memreader_config()
378+
379+
@staticmethod
380+
def get_image_parser_llm_config() -> dict[str, Any]:
381+
"""Get LLM configuration for image parsing (requires vision model).
382+
383+
Used for: image content extraction and analysis.
384+
Requires a vision-capable model like GPT-4V, GPT-4o, etc.
385+
386+
Fallback chain: IMAGE_PARSER_MODEL -> general_llm -> OpenAI config
387+
"""
388+
image_model = os.getenv("IMAGE_PARSER_MODEL")
389+
if image_model:
390+
return {
391+
"backend": os.getenv("IMAGE_PARSER_BACKEND", "openai"),
392+
"config": {
393+
"model_name_or_path": image_model,
394+
"temperature": 0.6,
395+
"max_tokens": int(os.getenv("IMAGE_PARSER_MAX_TOKENS", "4096")),
396+
"top_p": 0.95,
397+
"top_k": 20,
398+
"api_key": os.getenv(
399+
"IMAGE_PARSER_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
400+
),
401+
"api_base": os.getenv(
402+
"IMAGE_PARSER_API_BASE",
403+
os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
404+
),
405+
"remove_think_prefix": True,
406+
},
407+
}
408+
# Fallback to general_llm config (which itself falls back to OpenAI)
409+
return APIConfig.get_memreader_general_llm_config()
410+
411+
@staticmethod
412+
def get_preference_extractor_llm_config() -> dict[str, Any]:
413+
"""Get LLM configuration for preference extraction.
414+
415+
Used for: extracting user preferences from conversations.
416+
417+
Fallback chain: PREFERENCE_EXTRACTOR_MODEL -> general_llm -> OpenAI config
418+
"""
419+
pref_model = os.getenv("PREFERENCE_EXTRACTOR_MODEL")
420+
if pref_model:
421+
return {
422+
"backend": os.getenv("PREFERENCE_EXTRACTOR_BACKEND", "openai"),
423+
"config": {
424+
"model_name_or_path": pref_model,
425+
"temperature": 0.6,
426+
"max_tokens": int(os.getenv("PREFERENCE_EXTRACTOR_MAX_TOKENS", "8000")),
427+
"top_p": 0.95,
428+
"top_k": 20,
429+
"api_key": os.getenv(
430+
"PREFERENCE_EXTRACTOR_API_KEY", os.getenv("OPENAI_API_KEY", "EMPTY")
431+
),
432+
"api_base": os.getenv(
433+
"PREFERENCE_EXTRACTOR_API_BASE",
434+
os.getenv("OPENAI_API_BASE", "https://api.openai.com/v1"),
435+
),
436+
"remove_think_prefix": True,
437+
},
438+
}
439+
# Fallback to general_llm config (which itself falls back to OpenAI)
440+
return APIConfig.get_memreader_general_llm_config()
441+
341442
@staticmethod
342443
def get_activation_vllm_config() -> dict[str, Any]:
343444
"""Get Ollama configuration."""
@@ -358,7 +459,7 @@ def get_preference_memory_config() -> dict[str, Any]:
358459
return {
359460
"backend": "pref_text",
360461
"config": {
361-
"extractor_llm": APIConfig.get_memreader_config(),
462+
"extractor_llm": APIConfig.get_preference_extractor_llm_config(),
362463
"vector_db": {
363464
"backend": "milvus",
364465
"config": APIConfig.get_milvus_config(),
@@ -802,6 +903,10 @@ def get_product_default_config() -> dict[str, Any]:
802903
"backend": reader_config["backend"],
803904
"config": {
804905
"llm": APIConfig.get_memreader_config(),
906+
# General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.)
907+
"general_llm": APIConfig.get_memreader_general_llm_config(),
908+
# Image parser LLM (requires vision model)
909+
"image_parser_llm": APIConfig.get_image_parser_llm_config(),
805910
"embedder": APIConfig.get_embedder_config(),
806911
"chunker": {
807912
"backend": "sentence",
@@ -924,6 +1029,10 @@ def create_user_config(user_name: str, user_id: str) -> tuple["MOSConfig", "Gene
9241029
"backend": reader_config["backend"],
9251030
"config": {
9261031
"llm": APIConfig.get_memreader_config(),
1032+
# General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.)
1033+
"general_llm": APIConfig.get_memreader_general_llm_config(),
1034+
# Image parser LLM (requires vision model)
1035+
"image_parser_llm": APIConfig.get_image_parser_llm_config(),
9271036
"embedder": APIConfig.get_embedder_config(),
9281037
"chunker": {
9291038
"backend": "sentence",

src/memos/api/handlers/config_builders.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,3 +201,26 @@ def build_nli_client_config() -> dict[str, Any]:
201201
NLI client configuration dictionary
202202
"""
203203
return APIConfig.get_nli_config()
204+
205+
206+
def build_general_llm_config() -> dict[str, Any]:
207+
"""
208+
Build general LLM configuration for non-chat/doc tasks.
209+
210+
Used for: hallucination filter, memory rewrite, memory merge,
211+
tool trajectory extraction, skill memory extraction.
212+
213+
Returns:
214+
Validated general LLM configuration dictionary
215+
"""
216+
return LLMConfigFactory.model_validate(APIConfig.get_memreader_general_llm_config())
217+
218+
219+
def build_image_parser_llm_config() -> dict[str, Any]:
220+
"""
221+
Build image parser LLM configuration (requires vision model).
222+
223+
Returns:
224+
Validated image parser LLM configuration dictionary
225+
"""
226+
return LLMConfigFactory.model_validate(APIConfig.get_image_parser_llm_config())

src/memos/configs/mem_reader.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,18 @@ def parse_datetime(cls, value):
2424
return datetime.fromisoformat(value.replace("Z", "+00:00"))
2525
return value
2626

27-
llm: LLMConfigFactory = Field(..., description="LLM configuration for the MemReader")
27+
llm: LLMConfigFactory = Field(
28+
..., description="LLM configuration for chat/doc memory extraction (fine-tuned model)"
29+
)
30+
general_llm: LLMConfigFactory | None = Field(
31+
default=None,
32+
description="General LLM for non-chat/doc tasks: hallucination filter, memory rewrite, "
33+
"memory merge, tool trajectory, skill memory. Falls back to main llm if not set.",
34+
)
35+
image_parser_llm: LLMConfigFactory | None = Field(
36+
default=None,
37+
description="Vision LLM for image parsing. Falls back to main llm if not set.",
38+
)
2839
embedder: EmbedderConfigFactory = Field(
2940
..., description="Embedder configuration for the MemReader"
3041
)

src/memos/mem_reader/multi_modal_struct.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
3939
config: Configuration object for the reader
4040
"""
4141
from memos.configs.mem_reader import SimpleStructMemReaderConfig
42+
from memos.llms.factory import LLMFactory
4243

4344
# Extract direct_markdown_hostnames before converting to SimpleStructMemReaderConfig
4445
direct_markdown_hostnames = getattr(config, "direct_markdown_hostnames", None)
@@ -56,10 +57,20 @@ def __init__(self, config: MultiModalStructMemReaderConfig):
5657
simple_config = SimpleStructMemReaderConfig(**config_dict)
5758
super().__init__(simple_config)
5859

60+
# Image parser LLM (requires vision model)
61+
# Falls back to main llm if not configured
62+
self.image_parser_llm = (
63+
LLMFactory.from_config(config.image_parser_llm)
64+
if config.image_parser_llm is not None
65+
else self.llm
66+
)
67+
5968
# Initialize MultiModalParser for routing to different parsers
69+
# Pass image_parser_llm for image parsing
6070
self.multi_modal_parser = MultiModalParser(
6171
embedder=self.embedder,
6272
llm=self.llm,
73+
image_parser_llm=self.image_parser_llm,
6374
parser=None,
6475
direct_markdown_hostnames=direct_markdown_hostnames,
6576
)
@@ -631,7 +642,8 @@ def _merge_memories_with_llm(
631642
)
632643

633644
try:
634-
response_text = self.llm.generate([{"role": "user", "content": merge_prompt}])
645+
# Use general_llm for memory merge (not fine-tuned for this task)
646+
response_text = self.general_llm.generate([{"role": "user", "content": merge_prompt}])
635647
merge_result = parse_json_result(response_text)
636648

637649
if merge_result.get("should_merge", False):
@@ -873,12 +885,14 @@ def get_chunk_idx(item_with_pos) -> int:
873885
def _get_llm_tool_trajectory_response(self, mem_str: str) -> dict:
874886
"""
875887
Generate a tool trajectory experience item via the LLM.
888+
Uses general_llm as this task is not fine-tuned for the main model.
876889
"""
877890
try:
878891
lang = detect_lang(mem_str)
879892
template = TOOL_TRAJECTORY_PROMPT_ZH if lang == "zh" else TOOL_TRAJECTORY_PROMPT_EN
880893
prompt = template.replace("{messages}", mem_str)
881-
rsp = self.llm.generate([{"role": "user", "content": prompt}])
894+
# Use general_llm for tool trajectory (not fine-tuned for this task)
895+
rsp = self.general_llm.generate([{"role": "user", "content": prompt}])
882896
rsp = rsp.replace("```json", "").replace("```", "")
883897
return json.loads(rsp)
884898
except Exception as e:
@@ -1000,13 +1014,14 @@ def _process_multi_modal_data(
10001014
future_tool = executor.submit(
10011015
self._process_tool_trajectory_fine, fast_memory_items, info, **kwargs
10021016
)
1017+
# Use general_llm for skill memory extraction (not fine-tuned for this task)
10031018
future_skill = executor.submit(
10041019
process_skill_memory_fine,
10051020
fast_memory_items=fast_memory_items,
10061021
info=info,
10071022
searcher=self.searcher,
10081023
graph_db=self.graph_db,
1009-
llm=self.llm,
1024+
llm=self.general_llm,
10101025
embedder=self.embedder,
10111026
oss_config=self.oss_config,
10121027
skills_dir_config=self.skills_dir_config,
@@ -1067,12 +1082,13 @@ def _process_transfer_multi_modal_data(
10671082
future_tool = executor.submit(
10681083
self._process_tool_trajectory_fine, raw_nodes, info, **kwargs
10691084
)
1085+
# Use general_llm for skill memory extraction (not fine-tuned for this task)
10701086
future_skill = executor.submit(
10711087
process_skill_memory_fine,
10721088
raw_nodes,
10731089
info,
10741090
searcher=self.searcher,
1075-
llm=self.llm,
1091+
llm=self.general_llm,
10761092
embedder=self.embedder,
10771093
graph_db=self.graph_db,
10781094
oss_config=self.oss_config,

src/memos/mem_reader/read_multi_modal/multi_modal_parser.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ def __init__(
3535
self,
3636
embedder: BaseEmbedder,
3737
llm: BaseLLM | None = None,
38+
image_parser_llm: BaseLLM | None = None,
3839
parser: Any | None = None,
3940
direct_markdown_hostnames: list[str] | None = None,
4041
):
@@ -43,14 +44,18 @@ def __init__(
4344
4445
Args:
4546
embedder: Embedder for generating embeddings
46-
llm: Optional LLM for fine mode processing
47+
llm: Optional LLM for fine mode processing (chat/doc extraction)
48+
image_parser_llm: Optional vision LLM for image parsing.
49+
Falls back to llm if not provided.
4750
parser: Optional parser for parsing file contents
4851
direct_markdown_hostnames: List of hostnames that should return markdown directly
4952
without parsing. If None, reads from FILE_PARSER_DIRECT_MARKDOWN_HOSTNAMES
5053
environment variable (comma-separated). Default: ["139.196.232.20"]
5154
"""
5255
self.embedder = embedder
5356
self.llm = llm
57+
# Image parser LLM (requires vision model), falls back to main llm
58+
self.image_parser_llm = image_parser_llm if image_parser_llm is not None else llm
5459
self.parser = parser
5560

5661
# Initialize parsers for different message types
@@ -63,7 +68,8 @@ def __init__(
6368
self.file_content_parser = FileContentParser(
6469
embedder, llm, parser, direct_markdown_hostnames=direct_markdown_hostnames
6570
)
66-
self.image_parser = ImageParser(embedder, llm)
71+
# Use dedicated image_parser_llm for image parsing (requires vision model)
72+
self.image_parser = ImageParser(embedder, self.image_parser_llm)
6773
self.audio_parser = None # future
6874

6975
self.role_parsers = {

src/memos/mem_reader/simple_struct.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,15 @@ def __init__(self, config: SimpleStructMemReaderConfig):
173173
config: Configuration object for the reader
174174
"""
175175
self.config = config
176+
# Main LLM for chat/doc memory extraction (fine-tuned model)
176177
self.llm = LLMFactory.from_config(config.llm)
178+
# General LLM for non-chat/doc tasks (hallucination filter, rewrite, merge, etc.)
179+
# Falls back to main llm if not configured
180+
self.general_llm = (
181+
LLMFactory.from_config(config.general_llm)
182+
if config.general_llm is not None
183+
else self.llm
184+
)
177185
self.embedder = EmbedderFactory.from_config(config.embedder)
178186
self.chunker = ChunkerFactory.from_config(config.chunker)
179187
self.save_rawfile = self.chunker.config.save_rawfile
@@ -505,8 +513,9 @@ def rewrite_memories(
505513
prompt = template.format(**prompt_args)
506514

507515
# Optionally run filter and parse the output
516+
# Use general_llm for rewrite (not fine-tuned for this task)
508517
try:
509-
raw = self.llm.generate([{"role": "user", "content": prompt}])
518+
raw = self.general_llm.generate([{"role": "user", "content": prompt}])
510519
success, parsed = parse_rewritten_response(raw)
511520
logger.info(
512521
f"[rewrite_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}"
@@ -565,8 +574,9 @@ def filter_hallucination_in_memories(
565574
prompt = template.format(**prompt_args)
566575

567576
# Optionally run filter and parse the output
577+
# Use general_llm for hallucination filter (not fine-tuned for this task)
568578
try:
569-
raw = self.llm.generate([{"role": "user", "content": prompt}])
579+
raw = self.general_llm.generate([{"role": "user", "content": prompt}])
570580
success, parsed = parse_keep_filter_response(raw)
571581
logger.info(
572582
f"[filter_hallucination_in_memories] Hallucination filter parsed successfully: {success};prompt: {prompt}"

0 commit comments

Comments
 (0)