Skip to content

Commit 069a248

Browse files
committed
feat(spanner): add Litestar integration with session store and shard support
1 parent 13ad36a commit 069a248

File tree

8 files changed

+788
-396
lines changed

8 files changed

+788
-396
lines changed

sqlspec/adapters/spanner/adk/__init__.py

Lines changed: 2 additions & 396 deletions
Large diffs are not rendered by default.

sqlspec/adapters/spanner/adk/store.py

Lines changed: 431 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
"""Litestar integration for Spanner adapter."""
2+
3+
from sqlspec.adapters.spanner.litestar.store import SpannerSyncStore
4+
5+
__all__ = ("SpannerSyncStore",)
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
"""Spanner session store for Litestar integration."""
2+
3+
from datetime import datetime, timedelta, timezone
4+
from typing import TYPE_CHECKING, Any, cast
5+
6+
from google.api_core import exceptions as api_exceptions
7+
8+
from sqlspec.extensions.litestar.store import BaseSQLSpecStore
9+
from sqlspec.utils.logging import get_logger
10+
from sqlspec.utils.sync_tools import async_
11+
12+
if TYPE_CHECKING:
13+
from sqlspec.adapters.spanner.config import SpannerSyncConfig
14+
15+
logger = get_logger("adapters.spanner.litestar.store")
16+
17+
__all__ = ("SpannerSyncStore",)
18+
19+
20+
class SpannerSyncStore(BaseSQLSpecStore["SpannerSyncConfig"]):
21+
"""Spanner-backed Litestar session store using sync driver wrapped as async."""
22+
23+
__slots__ = ("_index_options", "_shard_count", "_table_options")
24+
25+
def __init__(self, config: "SpannerSyncConfig") -> None:
26+
super().__init__(config)
27+
litestar_cfg = cast("dict[str, Any]", getattr(config, "extension_config", {}).get("litestar", {}))
28+
self._shard_count: int = int(litestar_cfg.get("shard_count", 0)) if litestar_cfg.get("shard_count") else 0
29+
self._table_options: str | None = litestar_cfg.get("table_options")
30+
self._index_options: str | None = litestar_cfg.get("index_options")
31+
32+
def _datetime_to_timestamp(self, dt: "datetime | None") -> "datetime | None":
33+
if dt is None:
34+
return None
35+
if dt.tzinfo is None:
36+
return dt.replace(tzinfo=timezone.utc)
37+
return dt
38+
39+
def _timestamp_to_datetime(self, ts: "datetime | None") -> "datetime | None":
40+
if ts is None:
41+
return None
42+
if ts.tzinfo is None:
43+
return ts.replace(tzinfo=timezone.utc)
44+
return ts
45+
46+
def _build_params(
47+
self, key: str, expires_at: "datetime | None" = None, data: "bytes | None" = None
48+
) -> dict[str, Any]:
49+
return {"session_id": key, "data": data, "expires_at": self._datetime_to_timestamp(expires_at)}
50+
51+
async def get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
52+
return await async_(self._get)(key, renew_for)
53+
54+
def _get(self, key: str, renew_for: "int | timedelta | None" = None) -> "bytes | None":
55+
with self._config.provide_session() as driver:
56+
sql = f"""
57+
SELECT data, expires_at
58+
FROM {self._table_name}
59+
WHERE session_id = @session_id
60+
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP())
61+
"""
62+
if self._shard_count > 1:
63+
sql = f"""
64+
SELECT data, expires_at
65+
FROM {self._table_name}
66+
WHERE shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})
67+
AND session_id = @session_id
68+
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP())
69+
"""
70+
result = driver.select_one(sql, {"session_id": key})
71+
if result is None:
72+
return None
73+
74+
data = result.get("data")
75+
expires_at = self._timestamp_to_datetime(result.get("expires_at"))
76+
77+
if renew_for is not None and expires_at is not None:
78+
new_expires = self._calculate_expires_at(renew_for)
79+
update_sql = f"""
80+
UPDATE {self._table_name}
81+
SET expires_at = @expires_at, updated_at = PENDING_COMMIT_TIMESTAMP()
82+
WHERE session_id = @session_id
83+
"""
84+
if self._shard_count > 1:
85+
update_sql = f"{update_sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})"
86+
driver.execute(update_sql, self._build_params(key, new_expires))
87+
88+
return bytes(data) if data is not None else None
89+
90+
async def set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
91+
await async_(self._set)(key, value, expires_in)
92+
93+
def _set(self, key: str, value: "str | bytes", expires_in: "int | timedelta | None" = None) -> None:
94+
data = self._value_to_bytes(value)
95+
expires_at = self._calculate_expires_at(expires_in)
96+
params = self._build_params(key, expires_at, data)
97+
98+
upsert_sql = f"""
99+
UPDATE {self._table_name}
100+
SET data = @data,
101+
expires_at = @expires_at,
102+
updated_at = PENDING_COMMIT_TIMESTAMP()
103+
WHERE session_id = @session_id
104+
"""
105+
if self._shard_count > 1:
106+
upsert_sql = f"{upsert_sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})"
107+
insert_sql = f"""
108+
INSERT {self._table_name} (session_id, data, expires_at, created_at, updated_at)
109+
VALUES (@session_id, @data, @expires_at, PENDING_COMMIT_TIMESTAMP(), PENDING_COMMIT_TIMESTAMP())
110+
"""
111+
with self._config.provide_session() as driver:
112+
update_result = driver.execute(upsert_sql, params)
113+
if update_result.rows_affected == 0:
114+
driver.execute(insert_sql, params)
115+
116+
async def delete(self, key: str) -> None:
117+
await async_(self._delete)(key)
118+
119+
def _delete(self, key: str) -> None:
120+
sql = f"DELETE FROM {self._table_name} WHERE session_id = @session_id"
121+
if self._shard_count > 1:
122+
sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})"
123+
with self._config.provide_session() as driver:
124+
driver.execute(sql, {"session_id": key})
125+
126+
async def delete_all(self) -> None:
127+
await async_(self._delete_all)()
128+
129+
def _delete_all(self) -> None:
130+
sql = f"DELETE FROM {self._table_name}"
131+
with self._config.provide_session() as driver:
132+
driver.execute(sql)
133+
134+
async def exists(self, key: str) -> bool:
135+
return await async_(self._exists)(key)
136+
137+
def _exists(self, key: str) -> bool:
138+
sql = f"""
139+
SELECT 1 FROM {self._table_name}
140+
WHERE session_id = @session_id
141+
AND (expires_at IS NULL OR expires_at > CURRENT_TIMESTAMP())
142+
LIMIT 1
143+
"""
144+
with self._config.provide_session() as driver:
145+
row = driver.select_one(sql, {"session_id": key})
146+
return row is not None
147+
148+
async def expires_in(self, key: str) -> "int | None":
149+
return await async_(self._expires_in)(key)
150+
151+
def _expires_in(self, key: str) -> "int | None":
152+
sql = f"""
153+
SELECT expires_at FROM {self._table_name}
154+
WHERE session_id = @session_id
155+
"""
156+
if self._shard_count > 1:
157+
sql = f"{sql} AND shard_id = MOD(FARM_FINGERPRINT(@session_id), {self._shard_count})"
158+
with self._config.provide_session() as driver:
159+
row = driver.select_one(sql, {"session_id": key})
160+
if row is None:
161+
return None
162+
expires_at = self._timestamp_to_datetime(row.get("expires_at"))
163+
if expires_at is None:
164+
return None
165+
delta = expires_at - datetime.now(timezone.utc)
166+
return max(int(delta.total_seconds()), 0)
167+
168+
async def delete_expired(self) -> int:
169+
return await async_(self._delete_expired)()
170+
171+
def _delete_expired(self) -> int:
172+
sql = f"""
173+
DELETE FROM {self._table_name}
174+
WHERE expires_at IS NOT NULL AND expires_at <= CURRENT_TIMESTAMP()
175+
"""
176+
with self._config.provide_session() as driver:
177+
result = driver.execute(sql)
178+
return result.rows_affected or 0
179+
180+
async def create_table(self) -> None:
181+
await async_(self._create_table)()
182+
183+
def _create_table(self) -> None:
184+
ddl_statements = [self._get_create_table_sql(), self._get_create_index_sql()]
185+
try:
186+
self._config.get_database().update_ddl(ddl_statements).result(300) # type: ignore[no-untyped-call]
187+
except api_exceptions.AlreadyExists:
188+
return
189+
190+
def _get_create_table_sql(self) -> str:
191+
shard_column = ""
192+
pk = "PRIMARY KEY (session_id)"
193+
if self._shard_count > 1:
194+
shard_column = f",\n shard_id INT64 AS (MOD(FARM_FINGERPRINT(session_id), {self._shard_count})) STORED"
195+
pk = "PRIMARY KEY (shard_id, session_id)"
196+
options = ""
197+
if self._table_options:
198+
options = f"\nOPTIONS ({self._table_options})"
199+
return f"""
200+
CREATE TABLE {self._table_name} (
201+
session_id STRING(128) NOT NULL,
202+
data BYTES(MAX) NOT NULL,
203+
expires_at TIMESTAMP,
204+
created_at TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true),
205+
updated_at TIMESTAMP NOT NULL OPTIONS (allow_commit_timestamp=true){shard_column}
206+
) {pk}{options}
207+
"""
208+
209+
def _get_create_index_sql(self) -> str:
210+
leading = "expires_at"
211+
if self._shard_count > 1:
212+
leading = "shard_id, expires_at"
213+
opts = ""
214+
if self._index_options:
215+
opts = f" OPTIONS ({self._index_options})"
216+
return f"CREATE INDEX idx_{self._table_name}_expires_at ON {self._table_name}({leading}){opts}"
217+
218+
def _get_drop_table_sql(self) -> "list[str]":
219+
return [f"DROP INDEX idx_{self._table_name}_expires_at", f"DROP TABLE {self._table_name}"]

sqlspec/config.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,28 @@ class ADKConfig(TypedDict):
403403
- Ignored by non-Oracle adapters
404404
"""
405405

406+
shard_count: NotRequired[int]
407+
"""Optional hash shard count for session/event tables to reduce hotspotting.
408+
409+
When set (>1), adapters that support computed shard columns will create a
410+
generated shard_id using MOD(FARM_FINGERPRINT(primary_key), shard_count) and
411+
include it in the primary key and filters. Ignored by adapters that do not
412+
support computed shards.
413+
"""
414+
415+
session_table_options: NotRequired[str]
416+
"""Adapter-specific table OPTIONS/clauses for the sessions table.
417+
418+
Passed verbatim when supported (e.g., Spanner columnar/tiered storage). Ignored by
419+
adapters without table OPTIONS support.
420+
"""
421+
422+
events_table_options: NotRequired[str]
423+
"""Adapter-specific table OPTIONS/clauses for the events table."""
424+
425+
expires_index_options: NotRequired[str]
426+
"""Adapter-specific options for the expires/index used in ADK stores."""
427+
406428

407429
class OpenTelemetryConfig(TypedDict):
408430
"""Configuration options for OpenTelemetry integration.

sqlspec/extensions/litestar/config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,27 @@ class LitestarConfig(TypedDict):
6464
- Tables created with INMEMORY PRIORITY HIGH clause
6565
- Ignored by unsupported adapters
6666
"""
67+
68+
shard_count: NotRequired[int]
69+
"""Optional hash shard count for session table primary key.
70+
71+
When set (>1), adapters that support computed shard columns (e.g., Spanner)
72+
will create a generated shard_id using MOD(FARM_FINGERPRINT(session_id), shard_count)
73+
and include it in the primary key to reduce hotspotting. Ignored by adapters
74+
that do not support computed shards.
75+
"""
76+
77+
table_options: NotRequired[str]
78+
"""Optional raw OPTIONS/engine-specific table options string.
79+
80+
Passed verbatim when the adapter supports table-level OPTIONS/clauses
81+
(e.g., Spanner columnar/tiered storage). Ignored by adapters that do not
82+
support table options.
83+
"""
84+
85+
index_options: NotRequired[str]
86+
"""Optional raw OPTIONS/engine-specific options for the expires_at index.
87+
88+
Passed verbatim to the index definition for adapters that support index
89+
OPTIONS/clauses. Ignored by adapters that do not support index options.
90+
"""
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from collections.abc import AsyncGenerator
2+
from typing import Any, cast
3+
4+
import pytest
5+
from google.auth.credentials import AnonymousCredentials
6+
7+
from sqlspec.adapters.spanner import SpannerSyncConfig
8+
from sqlspec.adapters.spanner.litestar import SpannerSyncStore
9+
10+
11+
@pytest.fixture(scope="session")
12+
def spanner_litestar_config(spanner_service: Any) -> SpannerSyncConfig:
13+
host = getattr(spanner_service, "host", "localhost")
14+
port = getattr(spanner_service, "port", 9010)
15+
project_id = getattr(spanner_service, "project", "test-project")
16+
instance_id = getattr(spanner_service, "instance_id", getattr(spanner_service, "instance", "test-instance"))
17+
database_id = getattr(spanner_service, "database_id", getattr(spanner_service, "database", "test-database"))
18+
api_endpoint = f"{host}:{port}"
19+
20+
return SpannerSyncConfig(
21+
pool_config={
22+
"project": project_id,
23+
"instance_id": instance_id,
24+
"database_id": database_id,
25+
"credentials": cast(Any, AnonymousCredentials()), # type: ignore[no-untyped-call]
26+
"client_options": {"api_endpoint": api_endpoint},
27+
"min_sessions": 1,
28+
"max_sessions": 5,
29+
},
30+
extension_config={"litestar": {"session_table": "litestar_sessions"}},
31+
)
32+
33+
34+
@pytest.fixture
35+
async def spanner_store(spanner_litestar_config: SpannerSyncConfig) -> AsyncGenerator[SpannerSyncStore, None]:
36+
store = SpannerSyncStore(spanner_litestar_config)
37+
await store.create_table()
38+
try:
39+
yield store
40+
finally:
41+
await store.delete_all()
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
"""Integration tests for Spanner session store."""
2+
3+
import asyncio
4+
5+
import pytest
6+
7+
from sqlspec.adapters.spanner.litestar import SpannerSyncStore
8+
9+
pytestmark = [pytest.mark.spanner, pytest.mark.integration]
10+
11+
12+
async def test_store_set_and_get(spanner_store: SpannerSyncStore) -> None:
13+
data = b"payload"
14+
await spanner_store.set("s1", data)
15+
assert await spanner_store.get("s1") == data
16+
17+
18+
async def test_store_expiration(spanner_store: SpannerSyncStore) -> None:
19+
await spanner_store.set("expiring", b"data", expires_in=1)
20+
assert await spanner_store.exists("expiring")
21+
await asyncio.sleep(1.1)
22+
assert await spanner_store.get("expiring") is None
23+
24+
25+
async def test_store_delete(spanner_store: SpannerSyncStore) -> None:
26+
await spanner_store.set("todelete", b"d")
27+
await spanner_store.delete("todelete")
28+
assert await spanner_store.get("todelete") is None
29+
30+
31+
async def test_store_renew(spanner_store: SpannerSyncStore) -> None:
32+
await spanner_store.set("renew", b"r", expires_in=1)
33+
await asyncio.sleep(0.5)
34+
await spanner_store.get("renew", renew_for=1)
35+
await asyncio.sleep(0.7)
36+
assert await spanner_store.get("renew") == b"r"
37+
38+
39+
async def test_delete_expired_returns_count(spanner_store: SpannerSyncStore) -> None:
40+
await spanner_store.set("exp1", b"x", expires_in=1)
41+
await spanner_store.set("exp2", b"x", expires_in=1)
42+
await asyncio.sleep(1.1)
43+
deleted = await spanner_store.delete_expired()
44+
assert deleted >= 2

0 commit comments

Comments
 (0)