11"""Manages database for session config and short term data."""
22
3- import io
3+ import json
44import os
55import pickle
66from typing import Annotated , Any
99 JSON ,
1010 ForeignKeyConstraint ,
1111 Integer ,
12- LargeBinary ,
1312 PrimaryKeyConstraint ,
1413 String ,
1514 and_ ,
1615 func ,
1716 insert ,
17+ inspect ,
1818 select ,
19+ text ,
1920 update ,
2021)
2122from sqlalchemy .dialects .postgresql import JSONB
22- from sqlalchemy .ext .asyncio import AsyncEngine , async_sessionmaker
23+ from sqlalchemy .ext .asyncio import AsyncConnection , AsyncEngine , async_sessionmaker
2324from sqlalchemy .orm import (
2425 DeclarativeBase ,
2526 Mapped ,
@@ -44,7 +45,6 @@ class Base(DeclarativeBase): # pylint: disable=too-few-public-methods
4445StringKeyColumn = Annotated [str , mapped_column (String , primary_key = True )]
4546StringColumn = Annotated [str , mapped_column (String )]
4647JSONColumn = Annotated [dict , mapped_column (JSON_AUTO )]
47- BinaryColumn = Annotated [bytes , mapped_column (LargeBinary )]
4848
4949
5050class SessionDataManagerSQL (SessionDataManager ):
@@ -57,7 +57,7 @@ class SessionConfig(Base): # pylint: disable=too-few-public-methods
5757 session_key : Mapped [StringKeyColumn ]
5858 timestamp : Mapped [IntColumn ]
5959 configuration : Mapped [JSONColumn ]
60- param_data : Mapped [BinaryColumn ]
60+ param_data : Mapped [JSONColumn ]
6161 description : Mapped [StringColumn ]
6262 user_metadata : Mapped [JSONColumn ]
6363 __table_args__ = (PrimaryKeyConstraint ("session_key" ),)
@@ -93,7 +93,24 @@ def __init__(self, engine: AsyncEngine, schema: str | None = None) -> None:
9393
9494 async def create_tables (self ) -> None :
9595 """Create the necessary tables in the database."""
96+
97+ def _check_migration_needed (conn : AsyncConnection ) -> bool :
98+ inspector = inspect (conn )
99+ schema = self .SessionConfig .__table__ .schema
100+ table_name = "sessions"
101+ table_names = inspector .get_table_names (schema = schema )
102+ if table_name not in table_names :
103+ return False
104+ columns = inspector .get_columns (table_name , schema = schema )
105+ for column in columns :
106+ if column ["name" ] == "param_data" :
107+ return "JSON" not in str (column ["type" ]).upper ()
108+ return False
109+
96110 async with self ._engine .begin () as conn :
111+ if await conn .run_sync (_check_migration_needed ):
112+ await self ._migrate_pickle_to_json ()
113+ return
97114 await conn .run_sync (Base .metadata .create_all )
98115
99116 async def drop_tables (self ) -> None :
@@ -104,6 +121,64 @@ async def drop_tables(self) -> None:
104121 async def close (self ) -> None :
105122 """Close any underlying connections."""
106123
124+ async def _migrate_pickle_to_json (self ) -> None :
125+ """Migrate param_data from pickle to JSON."""
126+ schema = self .SessionConfig .__table__ .schema
127+ table_name = f"{ schema } .sessions" if schema else "sessions"
128+ dialect = self ._engine .dialect .name
129+ json_type = "JSONB" if dialect == "postgresql" else "JSON"
130+
131+ async with self ._engine .begin () as conn :
132+ await conn .execute (
133+ text (
134+ f"ALTER TABLE { table_name } RENAME COLUMN param_data TO param_data_blob"
135+ )
136+ )
137+ await conn .execute (
138+ text (f"ALTER TABLE { table_name } ADD COLUMN param_data { json_type } " )
139+ )
140+
141+ result = await conn .execute (
142+ text (f"SELECT session_key, param_data_blob FROM { table_name } " )
143+ )
144+
145+ for session_key , blob_data in result :
146+ if not blob_data :
147+ continue
148+ try :
149+ obj = pickle .loads (blob_data )
150+ if hasattr (obj , "model_dump" ):
151+ json_data = obj .model_dump (mode = "json" )
152+ elif hasattr (obj , "dict" ):
153+ json_data = obj .dict ()
154+ else :
155+ json_data = obj .__dict__
156+
157+ val = json .dumps (json_data , default = str )
158+
159+ if dialect == "postgresql" :
160+ await conn .execute (
161+ text (
162+ f"UPDATE { table_name } SET param_data = :val::jsonb WHERE session_key = :key"
163+ ),
164+ {"val" : val , "key" : session_key },
165+ )
166+ else :
167+ await conn .execute (
168+ text (
169+ f"UPDATE { table_name } SET param_data = :val WHERE session_key = :key"
170+ ),
171+ {"val" : val , "key" : session_key },
172+ )
173+ except Exception as exc :
174+ raise RuntimeError (
175+ f"Error migrating session { session_key } "
176+ ) from exc
177+
178+ await conn .execute (
179+ text (f"ALTER TABLE { table_name } DROP COLUMN param_data_blob" )
180+ )
181+
107182 async def create_new_session (
108183 self ,
109184 session_key : str ,
@@ -113,10 +188,13 @@ async def create_new_session(
113188 metadata : dict [str , object ],
114189 ) -> None :
115190 """Create a new session entry in the database."""
116- buffer = io .BytesIO ()
117- pickle .dump (param , buffer )
118- buffer .seek (0 )
119- param_data = buffer .getvalue ()
191+ if hasattr (param , "model_dump" ):
192+ param_data = param .model_dump (mode = "json" )
193+ elif hasattr (param , "dict" ):
194+ param_data = param .dict ()
195+ else :
196+ param_data = param .__dict__
197+
120198 async with self ._async_session () as dbsession :
121199 # Query for an existing session with the same ID
122200 sessions = await dbsession .execute (
@@ -164,9 +242,7 @@ async def get_session_info(
164242 session = sessions .scalars ().first ()
165243 if session is None :
166244 return None
167- binary_buffer = io .BytesIO (session .param_data )
168- binary_buffer .seek (0 )
169- param : EpisodicMemoryConf = pickle .load (binary_buffer )
245+ param = EpisodicMemoryConf (** session .param_data )
170246
171247 return SessionDataManager .SessionInfo (
172248 configuration = session .configuration ,
0 commit comments