Skip to content

Commit 49d74a3

Browse files
authored
fix: Refactor ADK store SQL table creation methods to be asynchronous (#144)
* fix: Refactor ADK store SQL table creation methods to be asynchronous This is done mainly to address issues with Oracle's JSON detection and in-memory flag on creation. - Changed `_get_create_sessions_table_sql` and `_get_create_events_table_sql` methods in various ADK store classes to be asynchronous. - Updated `create_tables` methods to use `execute_script` for executing SQL statements asynchronously. - Modified the Oracle ADK store to support INMEMORY PRIORITY HIGH clause for table creation. - Added tests to verify the correct creation of tables with INMEMORY settings for both async and sync stores. - Updated Litestar configuration to reflect changes in INMEMORY clause handling. * fix: assertions * refactor: switch from information_schema to pg_catalog for column and table queries in PostgreSQL adapters (#146)
1 parent 96cbe15 commit 49d74a3

File tree

24 files changed

+1605
-1322
lines changed

24 files changed

+1605
-1322
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ repos:
1717
- id: mixed-line-ending
1818
- id: trailing-whitespace
1919
- repo: https://github.com/charliermarsh/ruff-pre-commit
20-
rev: "v0.14.0"
20+
rev: "v0.14.1"
2121
hooks:
2222
- id: ruff
2323
args: ["--fix"]

sqlspec/adapters/adbc/data_dictionary.py

Lines changed: 27 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -300,22 +300,44 @@ def get_columns(
300300
for row in result.data or []
301301
]
302302

303+
if dialect == "postgres":
304+
schema_name = schema or "public"
305+
sql = """
306+
SELECT
307+
a.attname::text AS column_name,
308+
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
309+
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
310+
pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default
311+
FROM pg_catalog.pg_attribute a
312+
JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
313+
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
314+
LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
315+
WHERE c.relname = ?
316+
AND n.nspname = ?
317+
AND a.attnum > 0
318+
AND NOT a.attisdropped
319+
ORDER BY a.attnum
320+
"""
321+
result = adbc_driver.execute(sql, (table, schema_name))
322+
return result.data or []
323+
303324
if schema:
304-
sql = f"""
325+
sql = """
305326
SELECT column_name, data_type, is_nullable, column_default
306327
FROM information_schema.columns
307-
WHERE table_name = '{table}' AND table_schema = '{schema}'
328+
WHERE table_name = ? AND table_schema = ?
308329
ORDER BY ordinal_position
309330
"""
331+
result = adbc_driver.execute(sql, (table, schema))
310332
else:
311-
sql = f"""
333+
sql = """
312334
SELECT column_name, data_type, is_nullable, column_default
313335
FROM information_schema.columns
314-
WHERE table_name = '{table}'
336+
WHERE table_name = ?
315337
ORDER BY ordinal_position
316338
"""
339+
result = adbc_driver.execute(sql, (table,))
317340

318-
result = adbc_driver.execute(sql)
319341
return result.data or []
320342

321343
def list_available_features(self) -> "list[str]":

sqlspec/adapters/aiosqlite/adk/store.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def __init__(self, config: "AiosqliteConfig") -> None:
136136
"""
137137
super().__init__(config)
138138

139-
def _get_create_sessions_table_sql(self) -> str:
139+
async def _get_create_sessions_table_sql(self) -> str:
140140
"""Get SQLite CREATE TABLE SQL for sessions.
141141
142142
Returns:
@@ -163,7 +163,7 @@ def _get_create_sessions_table_sql(self) -> str:
163163
ON {self._session_table}(update_time DESC);
164164
"""
165165

166-
def _get_create_events_table_sql(self) -> str:
166+
async def _get_create_events_table_sql(self) -> str:
167167
"""Get SQLite CREATE TABLE SQL for events.
168168
169169
Returns:
@@ -228,11 +228,10 @@ async def _enable_foreign_keys(self, connection: Any) -> None:
228228

229229
async def create_tables(self) -> None:
230230
"""Create both sessions and events tables if they don't exist."""
231-
async with self._config.provide_connection() as conn:
232-
await self._enable_foreign_keys(conn)
233-
await conn.executescript(self._get_create_sessions_table_sql())
234-
await conn.executescript(self._get_create_events_table_sql())
235-
await conn.commit()
231+
async with self._config.provide_session() as driver:
232+
await self._enable_foreign_keys(driver.connection)
233+
await driver.execute_script(await self._get_create_sessions_table_sql())
234+
await driver.execute_script(await self._get_create_events_table_sql())
236235
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
237236

238237
async def create_session(

sqlspec/adapters/asyncmy/adk/store.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ def _parse_owner_id_column_for_mysql(self, column_ddl: str) -> "tuple[str, str]"
106106

107107
return (col_def, fk_constraint)
108108

109-
def _get_create_sessions_table_sql(self) -> str:
109+
async def _get_create_sessions_table_sql(self) -> str:
110110
"""Get MySQL CREATE TABLE SQL for sessions.
111111
112112
Returns:
@@ -145,7 +145,7 @@ def _get_create_sessions_table_sql(self) -> str:
145145
) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_unicode_ci
146146
"""
147147

148-
def _get_create_events_table_sql(self) -> str:
148+
async def _get_create_events_table_sql(self) -> str:
149149
"""Get MySQL CREATE TABLE SQL for events.
150150
151151
Returns:
@@ -199,9 +199,9 @@ def _get_drop_tables_sql(self) -> "list[str]":
199199

200200
async def create_tables(self) -> None:
201201
"""Create both sessions and events tables if they don't exist."""
202-
async with self._config.provide_connection() as conn, conn.cursor() as cursor:
203-
await cursor.execute(self._get_create_sessions_table_sql())
204-
await cursor.execute(self._get_create_events_table_sql())
202+
async with self._config.provide_session() as driver:
203+
await driver.execute_script(await self._get_create_sessions_table_sql())
204+
await driver.execute_script(await self._get_create_events_table_sql())
205205
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
206206

207207
async def create_session(

sqlspec/adapters/asyncpg/adk/store.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def __init__(self, config: AsyncConfigT) -> None:
8484
"""
8585
super().__init__(config)
8686

87-
def _get_create_sessions_table_sql(self) -> str:
87+
async def _get_create_sessions_table_sql(self) -> str:
8888
"""Get PostgreSQL CREATE TABLE SQL for sessions.
8989
9090
Returns:
@@ -125,7 +125,7 @@ def _get_create_sessions_table_sql(self) -> str:
125125
WHERE state != '{{}}'::jsonb;
126126
"""
127127

128-
def _get_create_events_table_sql(self) -> str:
128+
async def _get_create_events_table_sql(self) -> str:
129129
"""Get PostgreSQL CREATE TABLE SQL for events.
130130
131131
Returns:
@@ -181,9 +181,9 @@ def _get_drop_tables_sql(self) -> "list[str]":
181181

182182
async def create_tables(self) -> None:
183183
"""Create both sessions and events tables if they don't exist."""
184-
async with self.config.provide_connection() as conn:
185-
await conn.execute(self._get_create_sessions_table_sql())
186-
await conn.execute(self._get_create_events_table_sql())
184+
async with self.config.provide_session() as driver:
185+
await driver.execute_script(await self._get_create_sessions_table_sql())
186+
await driver.execute_script(await self._get_create_events_table_sql())
187187
logger.debug("Created ADK tables: %s, %s", self._session_table, self._events_table)
188188

189189
async def create_session(

sqlspec/adapters/asyncpg/data_dictionary.py

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ async def get_optimal_type(self, driver: AsyncDriverAdapterBase, type_category:
117117
async def get_columns(
118118
self, driver: AsyncDriverAdapterBase, table: str, schema: "str | None" = None
119119
) -> "list[dict[str, Any]]":
120-
"""Get column information for a table using information_schema.
120+
"""Get column information for a table using pg_catalog.
121121
122122
Args:
123123
driver: AsyncPG driver instance
@@ -130,25 +130,32 @@ async def get_columns(
130130
- data_type: PostgreSQL data type
131131
- is_nullable: Whether column allows NULL (YES/NO)
132132
- column_default: Default value if any
133+
134+
Notes:
135+
Uses pg_catalog instead of information_schema to avoid potential
136+
issues with PostgreSQL 'name' type in some drivers.
133137
"""
134138
asyncpg_driver = cast("AsyncpgDriver", driver)
135139

136-
if schema:
137-
sql = f"""
138-
SELECT column_name, data_type, is_nullable, column_default
139-
FROM information_schema.columns
140-
WHERE table_name = '{table}' AND table_schema = '{schema}'
141-
ORDER BY ordinal_position
142-
"""
143-
else:
144-
sql = f"""
145-
SELECT column_name, data_type, is_nullable, column_default
146-
FROM information_schema.columns
147-
WHERE table_name = '{table}' AND table_schema = 'public'
148-
ORDER BY ordinal_position
149-
"""
150-
151-
result = await asyncpg_driver.execute(sql)
140+
schema_name = schema or "public"
141+
sql = """
142+
SELECT
143+
a.attname::text AS column_name,
144+
pg_catalog.format_type(a.atttypid, a.atttypmod) AS data_type,
145+
CASE WHEN a.attnotnull THEN 'NO' ELSE 'YES' END AS is_nullable,
146+
pg_catalog.pg_get_expr(d.adbin, d.adrelid)::text AS column_default
147+
FROM pg_catalog.pg_attribute a
148+
JOIN pg_catalog.pg_class c ON a.attrelid = c.oid
149+
JOIN pg_catalog.pg_namespace n ON c.relnamespace = n.oid
150+
LEFT JOIN pg_catalog.pg_attrdef d ON a.attrelid = d.adrelid AND a.attnum = d.adnum
151+
WHERE c.relname = $1
152+
AND n.nspname = $2
153+
AND a.attnum > 0
154+
AND NOT a.attisdropped
155+
ORDER BY a.attnum
156+
"""
157+
158+
result = await asyncpg_driver.execute(sql, (table, schema_name))
152159
return result.data or []
153160

154161
def list_available_features(self) -> "list[str]":

sqlspec/adapters/bigquery/adk/store.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlspec.extensions.adk import BaseAsyncADKStore, EventRecord, SessionRecord
99
from sqlspec.utils.logging import get_logger
1010
from sqlspec.utils.serializers import from_json, to_json
11-
from sqlspec.utils.sync_tools import async_
11+
from sqlspec.utils.sync_tools import async_, run_
1212

1313
if TYPE_CHECKING:
1414
from sqlspec.adapters.bigquery.config import BigQueryConfig
@@ -102,7 +102,7 @@ def _get_full_table_name(self, table_name: str) -> str:
102102
return f"`{self._dataset_id}.{table_name}`"
103103
return f"`{table_name}`"
104104

105-
def _get_create_sessions_table_sql(self) -> str:
105+
async def _get_create_sessions_table_sql(self) -> str:
106106
"""Get BigQuery CREATE TABLE SQL for sessions.
107107
108108
Returns:
@@ -136,7 +136,7 @@ def _get_create_sessions_table_sql(self) -> str:
136136
CLUSTER BY app_name, user_id
137137
"""
138138

139-
def _get_create_events_table_sql(self) -> str:
139+
async def _get_create_events_table_sql(self) -> str:
140140
"""Get BigQuery CREATE TABLE SQL for events.
141141
142142
Returns:
@@ -193,9 +193,9 @@ def _get_drop_tables_sql(self) -> "list[str]":
193193

194194
def _create_tables(self) -> None:
195195
"""Synchronous implementation of create_tables."""
196-
with self._config.provide_connection() as conn:
197-
conn.query(self._get_create_sessions_table_sql()).result()
198-
conn.query(self._get_create_events_table_sql()).result()
196+
with self._config.provide_session() as driver:
197+
driver.execute_script(run_(self._get_create_sessions_table_sql)())
198+
driver.execute_script(run_(self._get_create_events_table_sql)())
199199
logger.debug("Created BigQuery ADK tables: %s, %s", self._session_table, self._events_table)
200200

201201
async def create_tables(self) -> None:

0 commit comments

Comments
 (0)