Skip to content

Commit a9e5e4e

Browse files
committed
Create non-existent projects when adding semantic memories
1 parent 24f4397 commit a9e5e4e

File tree

3 files changed

+69
-24
lines changed

3 files changed

+69
-24
lines changed

src/memmachine/common/session_manager/session_data_manager_sql_impl.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
update,
2121
)
2222
from sqlalchemy.dialects.postgresql import JSONB
23+
from sqlalchemy.exc import IntegrityError
2324
from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, async_sessionmaker
2425
from sqlalchemy.orm import (
2526
DeclarativeBase,
@@ -196,15 +197,6 @@ async def create_new_session(
196197
param_data = param.__dict__
197198

198199
async with self._async_session() as dbsession:
199-
# Query for an existing session with the same ID
200-
sessions = await dbsession.execute(
201-
select(self.SessionConfig).where(
202-
self.SessionConfig.session_key == session_key,
203-
),
204-
)
205-
session = sessions.first()
206-
if session is not None:
207-
raise SessionAlreadyExistsError(session_key)
208200
# create a new entry
209201
new_session = self.SessionConfig(
210202
session_key=session_key,
@@ -215,7 +207,11 @@ async def create_new_session(
215207
user_metadata=metadata,
216208
)
217209
dbsession.add(new_session)
218-
await dbsession.commit()
210+
try:
211+
await dbsession.commit()
212+
except IntegrityError as exc:
213+
await dbsession.rollback()
214+
raise SessionAlreadyExistsError(session_key) from exc
219215

220216
async def delete_session(self, session_key: str) -> None:
221217
"""Delete a session and its related data from the database."""

src/memmachine/main/memmachine.py

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,11 @@
1717
ShortTermMemoryConfPartial,
1818
)
1919
from memmachine.common.episode_store import Episode, EpisodeEntry, EpisodeIdT
20-
from memmachine.common.errors import ConfigurationError, SessionNotFoundError
20+
from memmachine.common.errors import (
21+
ConfigurationError,
22+
SessionAlreadyExistsError,
23+
SessionNotFoundError,
24+
)
2125
from memmachine.common.filter.filter_parser import (
2226
And as FilterAnd,
2327
)
@@ -162,6 +166,7 @@ async def create_session(
162166
*,
163167
description: str = "",
164168
user_conf: EpisodicMemoryConfPartial | None = None,
169+
exist_ok: bool = False,
165170
) -> SessionDataManager.SessionInfo:
166171
"""Create a new session."""
167172
episodic_memory_conf = self._with_default_episodic_memory_conf(
@@ -170,13 +175,18 @@ async def create_session(
170175
)
171176

172177
session_data_manager = await self._resources.get_session_data_manager()
173-
await session_data_manager.create_new_session(
174-
session_key=session_key,
175-
configuration={},
176-
param=episodic_memory_conf,
177-
description=description,
178-
metadata={},
179-
)
178+
try:
179+
await session_data_manager.create_new_session(
180+
session_key=session_key,
181+
configuration={},
182+
param=episodic_memory_conf,
183+
description=description,
184+
metadata={},
185+
)
186+
except SessionAlreadyExistsError:
187+
if not exist_ok:
188+
raise
189+
180190
ret = await self.get_session(session_key=session_key)
181191
if ret is None:
182192
raise RuntimeError(f"Failed to create session {session_key}")
@@ -267,19 +277,17 @@ async def add_episodes(
267277
)
268278
episode_ids = [e.uid for e in episodes]
269279

280+
if await self.get_session(session_data.session_key) is None:
281+
await self.create_session(session_data.session_key, exist_ok=True)
282+
270283
tasks = []
271284

272285
if MemoryType.Episodic in target_memories:
273286
episodic_memory_manager = (
274287
await self._resources.get_episodic_memory_manager()
275288
)
276-
async with episodic_memory_manager.open_or_create_episodic_memory(
289+
async with episodic_memory_manager.open_episodic_memory(
277290
session_key=session_data.session_key,
278-
description="",
279-
episodic_memory_config=self._with_default_episodic_memory_conf(
280-
session_key=session_data.session_key
281-
),
282-
metadata={},
283291
) as episodic_session:
284292
tasks.append(episodic_session.add_memory_episodes(episodes))
285293

tests/memmachine/main/test_memmachine_mock.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,6 +392,47 @@ async def test_add_episodes_dispatches_to_all_memories(
392392
)
393393

394394

395+
@pytest.mark.asyncio
396+
async def test_add_semantic_memory_will_create_nonexistent_session(
397+
minimal_conf, patched_resource_manager
398+
):
399+
memmachine = MemMachine(minimal_conf, patched_resource_manager)
400+
session = DummySessionData("new-session")
401+
402+
entries = [
403+
EpisodeEntry(content="hello", producer_id="user", producer_role="assistant"),
404+
]
405+
406+
mock_session_manager = AsyncMock()
407+
patched_resource_manager.get_session_data_manager = AsyncMock(
408+
return_value=mock_session_manager
409+
)
410+
mock_session_manager.get_session_info.return_value = None
411+
412+
semantic_manager = MagicMock()
413+
semantic_manager.add_message = AsyncMock()
414+
patched_resource_manager.get_semantic_session_manager = AsyncMock(
415+
return_value=semantic_manager
416+
)
417+
418+
async def _create_session_side_effect(*args, **kwargs):
419+
mock_session_manager.get_session_info.return_value = session
420+
421+
mock_session_manager.create_new_session = AsyncMock(
422+
side_effect=_create_session_side_effect
423+
)
424+
425+
await memmachine.add_episodes(
426+
session, entries, target_memories=[MemoryType.Semantic]
427+
)
428+
429+
mock_session_manager.create_new_session.assert_awaited_once()
430+
call_args = mock_session_manager.create_new_session.await_args[1]
431+
assert call_args["session_key"] == session.session_key
432+
433+
semantic_manager.add_message.assert_awaited_once()
434+
435+
395436
@pytest.mark.asyncio
396437
async def test_add_episodes_skips_memories_not_requested(
397438
minimal_conf, patched_resource_manager

0 commit comments

Comments
 (0)