
Commit bd0c1ee

Migrate from pickle to JSON (MemMachine#927)
1 parent 1e7baf3 commit bd0c1ee

2 files changed: +180 -12 lines changed


src/memmachine/common/session_manager/session_data_manager_sql_impl.py

Lines changed: 88 additions & 12 deletions
@@ -1,6 +1,6 @@
 """Manages database for session config and short term data."""
 
-import io
+import json
 import os
 import pickle
 from typing import Annotated, Any
@@ -9,17 +9,18 @@
     JSON,
     ForeignKeyConstraint,
     Integer,
-    LargeBinary,
     PrimaryKeyConstraint,
     String,
     and_,
     func,
     insert,
+    inspect,
     select,
+    text,
     update,
 )
 from sqlalchemy.dialects.postgresql import JSONB
-from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker
+from sqlalchemy.ext.asyncio import AsyncConnection, AsyncEngine, async_sessionmaker
 from sqlalchemy.orm import (
     DeclarativeBase,
     Mapped,
@@ -44,7 +45,6 @@ class Base(DeclarativeBase): # pylint: disable=too-few-public-methods
 StringKeyColumn = Annotated[str, mapped_column(String, primary_key=True)]
 StringColumn = Annotated[str, mapped_column(String)]
 JSONColumn = Annotated[dict, mapped_column(JSON_AUTO)]
-BinaryColumn = Annotated[bytes, mapped_column(LargeBinary)]
 
 
 class SessionDataManagerSQL(SessionDataManager):
@@ -57,7 +57,7 @@ class SessionConfig(Base): # pylint: disable=too-few-public-methods
         session_key: Mapped[StringKeyColumn]
         timestamp: Mapped[IntColumn]
         configuration: Mapped[JSONColumn]
-        param_data: Mapped[BinaryColumn]
+        param_data: Mapped[JSONColumn]
         description: Mapped[StringColumn]
         user_metadata: Mapped[JSONColumn]
         __table_args__ = (PrimaryKeyConstraint("session_key"),)
@@ -93,7 +93,24 @@ def __init__(self, engine: AsyncEngine, schema: str | None = None) -> None:
 
     async def create_tables(self) -> None:
         """Create the necessary tables in the database."""
+
+        def _check_migration_needed(conn: AsyncConnection) -> bool:
+            inspector = inspect(conn)
+            schema = self.SessionConfig.__table__.schema
+            table_name = "sessions"
+            table_names = inspector.get_table_names(schema=schema)
+            if table_name not in table_names:
+                return False
+            columns = inspector.get_columns(table_name, schema=schema)
+            for column in columns:
+                if column["name"] == "param_data":
+                    return "JSON" not in str(column["type"]).upper()
+            return False
+
         async with self._engine.begin() as conn:
+            if await conn.run_sync(_check_migration_needed):
+                await self._migrate_pickle_to_json()
+                return
             await conn.run_sync(Base.metadata.create_all)
 
     async def drop_tables(self) -> None:
@@ -104,6 +121,64 @@ async def drop_tables(self) -> None:
     async def close(self) -> None:
         """Close any underlying connections."""
 
+    async def _migrate_pickle_to_json(self) -> None:
+        """Migrate param_data from pickle to JSON."""
+        schema = self.SessionConfig.__table__.schema
+        table_name = f"{schema}.sessions" if schema else "sessions"
+        dialect = self._engine.dialect.name
+        json_type = "JSONB" if dialect == "postgresql" else "JSON"
+
+        async with self._engine.begin() as conn:
+            await conn.execute(
+                text(
+                    f"ALTER TABLE {table_name} RENAME COLUMN param_data TO param_data_blob"
+                )
+            )
+            await conn.execute(
+                text(f"ALTER TABLE {table_name} ADD COLUMN param_data {json_type}")
+            )
+
+            result = await conn.execute(
+                text(f"SELECT session_key, param_data_blob FROM {table_name}")
+            )
+
+            for session_key, blob_data in result:
+                if not blob_data:
+                    continue
+                try:
+                    obj = pickle.loads(blob_data)
+                    if hasattr(obj, "model_dump"):
+                        json_data = obj.model_dump(mode="json")
+                    elif hasattr(obj, "dict"):
+                        json_data = obj.dict()
+                    else:
+                        json_data = obj.__dict__
+
+                    val = json.dumps(json_data, default=str)
+
+                    if dialect == "postgresql":
+                        await conn.execute(
+                            text(
+                                f"UPDATE {table_name} SET param_data = :val::jsonb WHERE session_key = :key"
+                            ),
+                            {"val": val, "key": session_key},
+                        )
+                    else:
+                        await conn.execute(
+                            text(
+                                f"UPDATE {table_name} SET param_data = :val WHERE session_key = :key"
+                            ),
+                            {"val": val, "key": session_key},
+                        )
+                except Exception as exc:
+                    raise RuntimeError(
+                        f"Error migrating session {session_key}"
+                    ) from exc
+
+            await conn.execute(
+                text(f"ALTER TABLE {table_name} DROP COLUMN param_data_blob")
+            )
+
     async def create_new_session(
         self,
         session_key: str,
@@ -113,10 +188,13 @@ async def create_new_session(
         metadata: dict[str, object],
     ) -> None:
         """Create a new session entry in the database."""
-        buffer = io.BytesIO()
-        pickle.dump(param, buffer)
-        buffer.seek(0)
-        param_data = buffer.getvalue()
+        if hasattr(param, "model_dump"):
+            param_data = param.model_dump(mode="json")
+        elif hasattr(param, "dict"):
+            param_data = param.dict()
+        else:
+            param_data = param.__dict__
+
         async with self._async_session() as dbsession:
             # Query for an existing session with the same ID
             sessions = await dbsession.execute(
@@ -164,9 +242,7 @@ async def get_session_info(
             session = sessions.scalars().first()
             if session is None:
                 return None
-            binary_buffer = io.BytesIO(session.param_data)
-            binary_buffer.seek(0)
-            param: EpisodicMemoryConf = pickle.load(binary_buffer)
+            param = EpisodicMemoryConf(**session.param_data)
 
             return SessionDataManager.SessionInfo(
                 configuration=session.configuration,
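
Note on the serialization strategy: the fallback chain used in both create_new_session and the migration (model_dump, then dict, then __dict__) suggests EpisodicMemoryConf behaves like a Pydantic model. Below is a minimal sketch of the pickle-to-JSON round-trip this commit implements, using a stand-in DemoConf class because the real EpisodicMemoryConf fields are not visible in this diff:

import json
import pickle

from pydantic import BaseModel


class DemoConf(BaseModel):
    """Stand-in for EpisodicMemoryConf; the fields here are illustrative only."""

    session_key: str
    window_size: int = 8  # hypothetical field


# Old storage format: a pickled blob in a LargeBinary column.
legacy_blob = pickle.dumps(DemoConf(session_key="s1"))

# Migration path: unpickle once, then persist as JSON from here on.
obj = pickle.loads(legacy_blob)
json_text = json.dumps(obj.model_dump(mode="json"), default=str)

# New read path: rebuild the config directly from the JSON column.
restored = DemoConf(**json.loads(json_text))
assert restored == obj

The read path in get_session_info, EpisodicMemoryConf(**session.param_data), only works if every field survives JSON serialization, which is presumably why the migration passes default=str for values that are not JSON-native.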
New test file
Lines changed: 92 additions & 0 deletions
@@ -0,0 +1,92 @@
+"""Tests for session data manager migration."""
+
+import pickle
+
+import pytest
+from sqlalchemy import (
+    JSON,
+    Column,
+    Integer,
+    LargeBinary,
+    MetaData,
+    String,
+    Table,
+    insert,
+    inspect,
+)
+from sqlalchemy.ext.asyncio import create_async_engine
+
+from memmachine.common.configuration.episodic_config import EpisodicMemoryConf
+from memmachine.common.session_manager.session_data_manager_sql_impl import (
+    SessionDataManagerSQL,
+)
+
+
+@pytest.mark.asyncio
+async def test_migrate_pickle_to_json() -> None:
+    """Test that the database migrates from pickle to JSON correctly."""
+    # Use an in-memory SQLite database
+    engine = create_async_engine("sqlite+aiosqlite:///:memory:")
+
+    # Define the old schema (before migration)
+    metadata = MetaData()
+    sessions_table = Table(
+        "sessions",
+        metadata,
+        Column("session_key", String, primary_key=True),
+        Column("timestamp", Integer),
+        Column("configuration", JSON),
+        Column("param_data", LargeBinary),  # Old type: LargeBinary (BLOB)
+        Column("description", String),
+        Column("user_metadata", JSON),
+    )
+
+    # Create the table with the old schema
+    async with engine.begin() as conn:
+        await conn.run_sync(metadata.create_all)
+
+    # Create a dummy EpisodicMemoryConf and pickle it
+    # We assume EpisodicMemoryConf can be instantiated with defaults.
+    session_key = "test_session_migration"
+    param = EpisodicMemoryConf(session_key=session_key)
+    pickled_data = pickle.dumps(param)
+
+    # Insert a record with the pickled data
+    async with engine.begin() as conn:
+        await conn.execute(
+            insert(sessions_table).values(
+                session_key=session_key,
+                timestamp=1234567890,
+                configuration={"test": "config"},
+                param_data=pickled_data,
+                description="Test Session",
+                user_metadata={"meta": "data"},
+            )
+        )
+
+    # Initialize the SessionDataManagerSQL
+    # This should trigger the migration in create_tables
+    manager = SessionDataManagerSQL(engine)
+    await manager.create_tables()
+
+    # Verify the migration
+    # 1. Check that we can retrieve the session info (which implies successful JSON deserialization)
+    session_info = await manager.get_session_info(session_key)
+    assert session_info is not None
+    assert session_info.episode_memory_conf.session_key == session_key
+
+    assert isinstance(session_info.episode_memory_conf, EpisodicMemoryConf)
+
+    # 2. Verify the column type in the database is now JSON (or TEXT in SQLite)
+    def check_column(conn):
+        inspector = inspect(conn)
+        columns = inspector.get_columns("sessions")
+        for col in columns:
+            if col["name"] == "param_data":
+                assert "JSON" in str(col["type"]).upper()
+
+    async with engine.connect() as conn:
+        await conn.run_sync(check_column)
+
+    await manager.close()
+    await engine.dispose()
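
Not part of the commit, but if you want to inspect the migrated rows outside the manager, a small verification sketch along these lines should work against the same engine used in the test; the table and column names come from the diff, everything else is illustrative:

import json

from sqlalchemy import text
from sqlalchemy.ext.asyncio import AsyncEngine


async def dump_migrated_params(engine: AsyncEngine) -> None:
    # Print the JSON now stored in sessions.param_data after create_tables()
    # has performed the migration.
    async with engine.connect() as conn:
        result = await conn.execute(
            text("SELECT session_key, param_data FROM sessions")
        )
        for session_key, param_data in result:
            # SQLite returns the JSON column as text; Postgres JSONB may
            # already come back as a dict depending on the driver.
            data = json.loads(param_data) if isinstance(param_data, str) else param_data
            print(session_key, data)

For example, calling await dump_migrated_params(engine) right after manager.create_tables() in the test above would print the migrated configuration as a plain dict.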
