Skip to content

Commit 65578fc

Browse files
authored
Fix episodic configuration is not set correctly (MemMachine#791)
1 parent 2728b88 commit 65578fc

File tree

5 files changed

+164
-72
lines changed

5 files changed

+164
-72
lines changed

src/memmachine/common/configuration/episodic_config.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,16 @@ def merge(self, other: Self) -> EpisodicMemoryConf:
232232
# ---- Step 4: update nested configuration in the base result ----
233233
return EpisodicMemoryConf(
234234
session_key=merged.session_key,
235-
metrics_factory_id=merged.metrics_factory_id,
235+
metrics_factory_id=merged.metrics_factory_id
236+
if merged.metrics_factory_id is not None
237+
else "prometheus",
236238
short_term_memory=stm_merged,
237239
long_term_memory=ltm_merged,
238-
enabled=merged.enabled,
240+
long_term_memory_enabled=True
241+
if merged.long_term_memory_enabled is None and ltm_merged is not None
242+
else merged.long_term_memory_enabled,
243+
short_term_memory_enabled=True
244+
if merged.short_term_memory_enabled is None and stm_merged is not None
245+
else merged.short_term_memory_enabled,
246+
enabled=True if merged.enabled is None else merged.enabled,
239247
)

src/memmachine/main/memmachine.py

Lines changed: 72 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,17 @@
77
from enum import Enum
88
from typing import Any, Final, Protocol, cast
99

10-
from pydantic import BaseModel, InstanceOf
10+
from pydantic import BaseModel, InstanceOf, ValidationError
1111

1212
from memmachine.common.configuration import Configuration
1313
from memmachine.common.configuration.episodic_config import (
1414
EpisodicMemoryConf,
15-
LongTermMemoryConf,
16-
ShortTermMemoryConf,
15+
EpisodicMemoryConfPartial,
16+
LongTermMemoryConfPartial,
17+
ShortTermMemoryConfPartial,
1718
)
1819
from memmachine.common.episode_store import Episode, EpisodeEntry, EpisodeIdT
20+
from memmachine.common.errors import ConfigurationError
1921
from memmachine.common.filter.filter_parser import (
2022
And as FilterAnd,
2123
)
@@ -74,17 +76,62 @@ def __init__(
7476
) -> None:
7577
"""Create a MemMachine using the provided configuration."""
7678
self._conf = conf
77-
7879
if resources is not None:
7980
self._resources = resources
8081
else:
8182
self._resources = ResourceManagerImpl(conf)
83+
self._initialize_default_episodic_configuration()
84+
self._started = False
85+
86+
def _initialize_default_episodic_configuration(self) -> None:
87+
# initialize the default value for episodic memory configuration
88+
# Can not put the logic into the data type
89+
default_prompt = "Based on the following episodes: {episodes}, and the previous summary: {summary}, please update the summary. Keep it under {max_length} characters."
90+
if self._conf.episodic_memory is None:
91+
self._conf.episodic_memory = EpisodicMemoryConfPartial()
92+
self._conf.episodic_memory.enabled = False
93+
if self._conf.episodic_memory.long_term_memory is None:
94+
self._conf.episodic_memory.long_term_memory = LongTermMemoryConfPartial()
95+
self._conf.episodic_memory.long_term_memory_enabled = False
96+
if self._conf.episodic_memory.short_term_memory is None:
97+
self._conf.episodic_memory.short_term_memory = ShortTermMemoryConfPartial()
98+
self._conf.episodic_memory.short_term_memory_enabled = False
99+
if self._conf.episodic_memory.long_term_memory.embedder is None:
100+
self._conf.episodic_memory.long_term_memory.embedder = (
101+
self._conf.default_long_term_memory_embedder
102+
)
103+
if self._conf.episodic_memory.long_term_memory.reranker is None:
104+
self._conf.episodic_memory.long_term_memory.reranker = (
105+
self._conf.default_long_term_memory_reranker
106+
)
107+
if self._conf.episodic_memory.short_term_memory.llm_model is None:
108+
self._conf.episodic_memory.short_term_memory.llm_model = "gpt-4.1"
109+
if self._conf.episodic_memory.short_term_memory.summary_prompt_system is None:
110+
self._conf.episodic_memory.short_term_memory.summary_prompt_system = (
111+
"You are a helpful assistant."
112+
)
113+
if self._conf.episodic_memory.short_term_memory.summary_prompt_user is None:
114+
self._conf.episodic_memory.short_term_memory.summary_prompt_user = (
115+
default_prompt
116+
)
117+
if self._conf.episodic_memory.long_term_memory.vector_graph_store is None:
118+
self._conf.episodic_memory.long_term_memory.vector_graph_store = (
119+
"default_store"
120+
)
82121

83122
async def start(self) -> None:
123+
if self._started:
124+
return
125+
self._started = True
126+
84127
semantic_service = await self._resources.get_semantic_service()
85128
await semantic_service.start()
86129

87130
async def stop(self) -> None:
131+
if not self._started:
132+
return
133+
self._started = False
134+
88135
semantic_service = await self._resources.get_semantic_service()
89136
await semantic_service.stop()
90137

@@ -93,75 +140,39 @@ async def stop(self) -> None:
93140
def _with_default_episodic_memory_conf(
94141
self,
95142
*,
96-
embedder_name: str | None = None,
97-
reranker_name: str | None = None,
143+
user_conf: EpisodicMemoryConfPartial | None = None,
98144
session_key: str,
99145
) -> EpisodicMemoryConf:
100146
# Get default prompts from config, with fallbacks
101-
short_term = self._conf.episodic_memory.short_term_memory
102-
summary_prompt_system = (
103-
short_term.summary_prompt_system
104-
if short_term and short_term.summary_prompt_system
105-
else "You are a helpful assistant."
106-
)
107-
summary_prompt_user = (
108-
short_term.summary_prompt_user
109-
if short_term and short_term.summary_prompt_user
110-
else "Based on the following episodes: {episodes}, and the previous summary: {summary}, please update the summary. Keep it under {max_length} characters."
111-
)
112-
113-
# Get default embedder and reranker from config
114-
long_term = self._conf.episodic_memory.long_term_memory
115-
116-
if not embedder_name:
117-
embedder_name = self._conf.default_long_term_memory_embedder
118-
if not reranker_name:
119-
reranker_name = self._conf.default_long_term_memory_reranker
120-
121-
self._conf.check_reranker(reranker_name)
122-
self._conf.check_embedder(embedder_name)
123-
124-
target_vector_store = (
125-
long_term.vector_graph_store
126-
if long_term and long_term.vector_graph_store
127-
else "default_store"
128-
)
129-
130-
target_short_llm_model = (
131-
short_term.llm_model if short_term and short_term.llm_model else "gpt-4.1"
132-
)
133-
134-
return EpisodicMemoryConf(
135-
session_key=session_key,
136-
long_term_memory=LongTermMemoryConf(
137-
session_id=session_key,
138-
vector_graph_store=target_vector_store,
139-
embedder=embedder_name,
140-
reranker=reranker_name,
141-
),
142-
short_term_memory=ShortTermMemoryConf(
143-
session_key=session_key,
144-
llm_model=target_short_llm_model,
145-
summary_prompt_system=summary_prompt_system,
146-
summary_prompt_user=summary_prompt_user,
147-
),
148-
long_term_memory_enabled=True,
149-
short_term_memory_enabled=True,
150-
enabled=True,
151-
)
147+
try:
148+
if user_conf is None:
149+
user_conf = EpisodicMemoryConfPartial()
150+
user_conf.session_key = session_key
151+
episodic_conf = user_conf.merge(self._conf.episodic_memory)
152+
if episodic_conf.long_term_memory is not None:
153+
if episodic_conf.long_term_memory.embedder is not None:
154+
self._conf.check_embedder(episodic_conf.long_term_memory.embedder)
155+
if episodic_conf.long_term_memory.reranker is not None:
156+
self._conf.check_reranker(episodic_conf.long_term_memory.reranker)
157+
except ValidationError as e:
158+
logger.exception(
159+
"Faield to merge configuration: %s, %s",
160+
str(user_conf),
161+
str(self._conf.episodic_memory),
162+
)
163+
raise ConfigurationError("Failed to merge configuration") from e
164+
return episodic_conf
152165

153166
async def create_session(
154167
self,
155168
session_key: str,
156169
*,
157170
description: str = "",
158-
embedder_name: str | None = None,
159-
reranker_name: str | None = None,
171+
user_conf: EpisodicMemoryConfPartial | None = None,
160172
) -> SessionDataManager.SessionInfo:
161173
"""Create a new session."""
162174
episodic_memory_conf = self._with_default_episodic_memory_conf(
163-
embedder_name=embedder_name,
164-
reranker_name=reranker_name,
175+
user_conf=user_conf,
165176
session_key=session_key,
166177
)
167178

src/memmachine/server/api_v2/router.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,10 @@
2323
SearchMemoriesSpec,
2424
SearchResult,
2525
)
26+
from memmachine.common.configuration.episodic_config import (
27+
EpisodicMemoryConfPartial,
28+
LongTermMemoryConfPartial,
29+
)
2630
from memmachine.common.errors import (
2731
ConfigurationError,
2832
InvalidArgumentError,
@@ -50,11 +54,16 @@ async def create_project(
5054
project_id=spec.project_id,
5155
)
5256
try:
57+
user_conf = EpisodicMemoryConfPartial(
58+
long_term_memory=LongTermMemoryConfPartial(
59+
embedder=spec.config.embedder if spec.config.embedder else None,
60+
reranker=spec.config.reranker if spec.config.reranker else None,
61+
)
62+
)
5363
session = await memmachine.create_session(
5464
session_key=session_data.session_key,
5565
description=spec.description,
56-
embedder_name=spec.config.embedder,
57-
reranker_name=spec.config.reranker,
66+
user_conf=user_conf,
5867
)
5968
except InvalidArgumentError as e:
6069
raise RestError(code=422, message="invalid argument: " + str(e)) from e

tests/memmachine/main/test_memmachine_mock.py

Lines changed: 69 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010

1111
from memmachine.common.configuration import (
1212
Configuration,
13-
EpisodicMemoryConfPartial,
1413
)
1514
from memmachine.common.configuration.episodic_config import (
15+
EpisodicMemoryConfPartial,
1616
LongTermMemoryConfPartial,
1717
ShortTermMemoryConfPartial,
1818
)
@@ -39,8 +39,9 @@ def session_key(self) -> str: # pragma: no cover - trivial accessor
3939
return self._session_key
4040

4141

42-
@pytest.fixture
43-
def minimal_conf() -> Configuration:
42+
def _minimal_conf(
43+
short_memory_enabled: bool = True, long_term_memory_enabled: bool = True
44+
) -> Configuration:
4445
"""Provide the minimal subset of configuration accessed in tests."""
4546
mock_rerankers = MagicMock()
4647
mock_rerankers.contains_reranker.return_value = True
@@ -65,12 +66,24 @@ def minimal_conf() -> Configuration:
6566
embedder="default-embedder",
6667
reranker="default-reranker",
6768
),
69+
short_term_memory_enabled=short_memory_enabled,
70+
long_term_memory_enabled=long_term_memory_enabled,
6871
)
6972
ret.default_long_term_memory_embedder = "default-embedder"
7073
ret.default_long_term_memory_reranker = "default-reranker"
7174
return ret
7275

7376

77+
@pytest.fixture
78+
def minimal_conf() -> Configuration:
79+
return _minimal_conf()
80+
81+
82+
@pytest.fixture
83+
def minimal_conf_factory():
84+
return _minimal_conf
85+
86+
7487
@pytest.fixture
7588
def patched_resource_manager(monkeypatch):
7689
"""Replace :class:`ResourceManagerImpl` with a controllable double."""
@@ -122,6 +135,52 @@ def test_with_default_episodic_memory_conf_uses_fallbacks(
122135
)
123136

124137

138+
def test_with_default_short_conf_enable_status(
139+
minimal_conf_factory, patched_resource_manager
140+
):
141+
min_conf = minimal_conf_factory(
142+
short_memory_enabled=False, long_term_memory_enabled=True
143+
)
144+
memmachine = MemMachine(min_conf, patched_resource_manager)
145+
conf = memmachine._with_default_episodic_memory_conf(session_key="session-2")
146+
assert min_conf.episodic_memory.short_term_memory_enabled is False
147+
assert min_conf.episodic_memory.long_term_memory_enabled is True
148+
assert conf.short_term_memory_enabled is False
149+
assert conf.long_term_memory_enabled is True
150+
user_conf = EpisodicMemoryConfPartial(
151+
short_term_memory_enabled=True,
152+
long_term_memory_enabled=False,
153+
)
154+
conf = memmachine._with_default_episodic_memory_conf(
155+
session_key="session-2", user_conf=user_conf
156+
)
157+
assert conf.short_term_memory_enabled is True
158+
assert conf.long_term_memory_enabled is False
159+
160+
161+
def test_with_default_long_conf_enable_status(
162+
minimal_conf_factory, patched_resource_manager
163+
):
164+
min_conf = minimal_conf_factory(
165+
short_memory_enabled=True, long_term_memory_enabled=False
166+
)
167+
memmachine = MemMachine(min_conf, patched_resource_manager)
168+
conf = memmachine._with_default_episodic_memory_conf(session_key="session-2")
169+
assert min_conf.episodic_memory.short_term_memory_enabled is True
170+
assert min_conf.episodic_memory.long_term_memory_enabled is False
171+
assert conf.short_term_memory_enabled is True
172+
assert conf.long_term_memory_enabled is False
173+
user_conf = EpisodicMemoryConfPartial(
174+
short_term_memory_enabled=False,
175+
long_term_memory_enabled=True,
176+
)
177+
conf = memmachine._with_default_episodic_memory_conf(
178+
session_key="session-2", user_conf=user_conf
179+
)
180+
assert conf.short_term_memory_enabled is False
181+
assert conf.long_term_memory_enabled is True
182+
183+
125184
@pytest.mark.asyncio
126185
async def test_create_session_passes_generated_config(
127186
minimal_conf, patched_resource_manager
@@ -133,11 +192,16 @@ async def test_create_session_passes_generated_config(
133192

134193
memmachine = MemMachine(minimal_conf, patched_resource_manager)
135194

195+
user_conf = EpisodicMemoryConfPartial(
196+
long_term_memory=LongTermMemoryConfPartial(
197+
embedder="custom-embed",
198+
reranker="custom-reranker",
199+
)
200+
)
136201
await memmachine.create_session(
137202
"alpha",
138203
description="demo",
139-
embedder_name="custom-embed",
140-
reranker_name="custom-reranker",
204+
user_conf=user_conf,
141205
)
142206

143207
session_manager.create_new_session.assert_awaited_once()

tests/memmachine/server/api_v2/test_router.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,8 +60,8 @@ def test_create_project(client, mock_memmachine):
6060
call_args = mock_memmachine.create_session.call_args[1]
6161
assert call_args["session_key"] == "test_org/test_proj"
6262
assert call_args["description"] == "A test project"
63-
assert call_args["embedder_name"] == "openai"
64-
assert call_args["reranker_name"] == "cohere"
63+
assert call_args["user_conf"].long_term_memory.embedder == "openai"
64+
assert call_args["user_conf"].long_term_memory.reranker == "cohere"
6565

6666
mock_memmachine.create_session.reset_mock()
6767
mock_memmachine.create_session.side_effect = InvalidArgumentError(

0 commit comments

Comments
 (0)