Skip to content

Commit c1007ae

Browse files
committed
feat(spanner): enhance Spanner configuration and add integration tests for ADK and Litestar session store
1 parent f3be644 commit c1007ae

File tree

12 files changed

+49
-33
lines changed

12 files changed

+49
-33
lines changed

sqlspec/adapters/spanner/config.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,20 @@ def _create_pool(self) -> AbstractSessionPool:
156156
database = instance.database(database_id, pool=None)
157157

158158
pool_type = cast("type[AbstractSessionPool]", self.pool_config.get("pool_type", FixedSizePool))
159-
connection_keys = {"project", "instance_id", "database_id", "credentials", "client_options", "pool_type"}
160-
pool_kwargs = {k: v for k, v in self.pool_config.items() if k not in connection_keys and v is not None}
161159

162-
if pool_type is FixedSizePool and "size" not in pool_kwargs and "max_sessions" in self.pool_config:
163-
pool_kwargs["size"] = self.pool_config["max_sessions"]
160+
pool_kwargs: dict[str, Any] = {}
161+
if pool_type is FixedSizePool:
162+
if "size" in self.pool_config:
163+
pool_kwargs["size"] = self.pool_config["size"]
164+
elif "max_sessions" in self.pool_config:
165+
pool_kwargs["size"] = self.pool_config["max_sessions"]
166+
if "labels" in self.pool_config:
167+
pool_kwargs["labels"] = self.pool_config["labels"]
168+
else:
169+
valid_pool_keys = {"size", "labels", "ping_interval"}
170+
pool_kwargs = {k: v for k, v in self.pool_config.items() if k in valid_pool_keys and v is not None}
171+
if "size" not in pool_kwargs and "max_sessions" in self.pool_config:
172+
pool_kwargs["size"] = self.pool_config["max_sessions"]
164173

165174
pool_factory = cast("Callable[..., AbstractSessionPool]", pool_type)
166175
return pool_factory(database, **pool_kwargs)
@@ -176,7 +185,7 @@ def provide_connection(
176185
"""Yield a Snapshot (default) or Transaction from the configured pool."""
177186
database = self.get_database()
178187
if transaction:
179-
with cast("Any", database).transaction() as txn: # type: ignore[no-untyped-call]
188+
with cast("Any", database).transaction() as txn:
180189
yield cast("SpannerConnection", txn)
181190
else:
182191
with cast("Any", database).snapshot() as snapshot:
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Spanner dialect submodule."""
22

3-
from sqlspec.adapters.spanner.dialect._spanner import Spanner
43
from sqlspec.adapters.spanner.dialect._spangres import Spangres
4+
from sqlspec.adapters.spanner.dialect._spanner import Spanner
55

6-
__all__ = ("Spanner", "Spangres")
6+
__all__ = ("Spangres", "Spanner")

sqlspec/adapters/spanner/dialect/_spangres.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
"""Google Cloud Spanner PostgreSQL-interface dialect (\"Spangres\")."""
1+
r"""Google Cloud Spanner PostgreSQL-interface dialect ("Spangres")."""
22

33
from typing import Any, cast
44

sqlspec/adapters/spanner/dialect/_spanner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def _parse_property(self) -> exp.Expression:
7676
this=exp.Literal.string(_ROW_DELETION_NAME), value=exp.Tuple(expressions=[column, interval])
7777
)
7878

79-
if self._match_text_seq("TTL"): # PostgreSQL-dialect style, keep for compatibility
79+
if self._match_text_seq("TTL"): # type: ignore[no-untyped-call] # PostgreSQL-dialect style, keep for compatibility
8080
self._match_text_seq("INTERVAL") # type: ignore[no-untyped-call]
8181
interval = cast("exp.Expression", self._parse_expression())
8282
self._match_text_seq("ON") # type: ignore[no-untyped-call]

sqlspec/builder/_ddl.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1155,8 +1155,8 @@ def _create_base_expression(self) -> exp.Expression:
11551155
if self._select_query is None:
11561156
self._raise_sql_builder_error("SELECT query must be set for CREATE MATERIALIZED VIEW.")
11571157

1158-
select_expr = None
1159-
select_parameters = None
1158+
select_expr: exp.Expression | None = None
1159+
select_parameters: dict[str, Any] | None = None
11601160

11611161
if isinstance(self._select_query, SQL):
11621162
select_expr = self._select_query.expression
@@ -1251,8 +1251,8 @@ def _create_base_expression(self) -> exp.Expression:
12511251
if self._select_query is None:
12521252
self._raise_sql_builder_error("SELECT query must be set for CREATE VIEW.")
12531253

1254-
select_expr = None
1255-
select_parameters = None
1254+
select_expr: exp.Expression | None = None
1255+
select_parameters: dict[str, Any] | None = None
12561256

12571257
if isinstance(self._select_query, SQL):
12581258
select_expr = self._select_query.expression

sqlspec/builder/_vector_expressions.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
"""Custom SQLGlot expressions for vector distance operations.
22
33
Provides dialect-specific SQL generation for vector similarity search
4-
across PostgreSQL (pgvector), MySQL 9+, and Oracle 23ai+.
4+
across PostgreSQL (pgvector), MySQL 9+, Oracle 23ai+, BigQuery, and Spanner.
55
"""
66

7+
from contextlib import suppress
78
from typing import Any
89

910
from sqlglot import exp
@@ -194,11 +195,13 @@ def _register_with_sqlglot() -> None:
194195
from sqlglot.dialects.postgres import Postgres
195196
from sqlglot.generator import Generator
196197

197-
try: # optional, only when Spanner dialects are present
198-
from sqlspec.adapters.spanner.dialect import Spanner, Spangres
199-
except Exception: # pragma: no cover - optional import
200-
Spanner = None
201-
Spangres = None
198+
spanner_dialect: type | None = None
199+
spangres_dialect: type | None = None
200+
with suppress(ImportError):
201+
from sqlspec.adapters.spanner.dialect import Spangres, Spanner
202+
203+
spanner_dialect = Spanner
204+
spangres_dialect = Spangres
202205

203206
def vector_distance_sql_base(generator: "Generator", expression: "VectorDistance") -> str:
204207
"""Base generator for VectorDistance expressions."""
@@ -247,10 +250,10 @@ def vector_distance_sql_duckdb(generator: "Generator", expression: "VectorDistan
247250
Oracle.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_oracle
248251
BigQuery.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_bigquery
249252
DuckDB.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_duckdb
250-
if Spanner is not None:
251-
Spanner.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_spanner
252-
if Spangres is not None:
253-
Spangres.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_postgres
253+
if spanner_dialect is not None:
254+
spanner_dialect.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_spanner # type: ignore[attr-defined]
255+
if spangres_dialect is not None:
256+
spangres_dialect.Generator.TRANSFORMS[VectorDistance] = vector_distance_sql_postgres # type: ignore[attr-defined]
254257

255258

256259
_register_with_sqlglot()

sqlspec/extensions/litestar/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def add_sessions_delete_expired_command() -> None:
5252
except ImportError:
5353
return
5454

55-
@sessions_group.command("delete-expired") # type: ignore[misc]
55+
@sessions_group.command("delete-expired") # type: ignore[untyped-decorator]
5656
@click.option(
5757
"--verbose", is_flag=True, default=False, help="Show detailed information about the cleanup operation"
5858
)

tests/integration/test_adapters/test_spanner/conftest.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from collections.abc import Generator
2-
from typing import TYPE_CHECKING, Any, cast
2+
from typing import TYPE_CHECKING
33

44
import pytest
55
from google.api_core import exceptions as api_exceptions
@@ -17,7 +17,9 @@
1717

1818

1919
@pytest.fixture(scope="session")
20-
def spanner_database(spanner_service: SpannerService, spanner_connection: spanner.Client) -> Generator["Database", None, None]:
20+
def spanner_database(
21+
spanner_service: SpannerService, spanner_connection: spanner.Client
22+
) -> Generator["Database", None, None]:
2123
"""Ensure emulator instance and database exist, yield Database."""
2224
instance = spanner_connection.instance(spanner_service.instance_name)
2325
if not instance.exists():

tests/integration/test_adapters/test_spanner/test_extensions/test_adk/test_store.py renamed to tests/integration/test_adapters/test_spanner/test_extensions/test_adk/test_adk_store.py

File renamed without changes.

tests/integration/test_adapters/test_spanner/test_extensions/test_litestar/test_store.py renamed to tests/integration/test_adapters/test_spanner/test_extensions/test_litestar/test_litestar_store.py

File renamed without changes.

0 commit comments

Comments
 (0)