Skip to content

Commit 14691f0

Browse files
committed
fix: improve write handling in Spanner
1 parent 1c6a2ad commit 14691f0

File tree

14 files changed

+1549
-394
lines changed

14 files changed

+1549
-394
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,7 @@ split-on-trailing-comma = false
501501
"sqlspec/builder/mixins/**/*.*" = ["SLF001"]
502502
"sqlspec/extensions/adk/converters.py" = ["S403"]
503503
"sqlspec/migrations/utils.py" = ["S404"]
504+
"sqlspec/adapters/spanner/config.py" = ["PLC2801"]
504505
"tests/**/*.*" = [
505506
"A",
506507
"ARG",

sqlspec/adapters/spanner/_type_handlers.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
"""
99

1010
import base64
11-
from datetime import datetime, timezone
11+
from datetime import date, datetime, timezone
1212
from typing import TYPE_CHECKING, Any
1313
from uuid import UUID
1414

@@ -167,6 +167,8 @@ def infer_spanner_param_types(params: "dict[str, Any] | None") -> "dict[str, Any
167167
types[key] = param_types.BYTES
168168
elif isinstance(value, datetime):
169169
types[key] = param_types.TIMESTAMP
170+
elif isinstance(value, date):
171+
types[key] = param_types.DATE
170172
elif isinstance(value, dict) and hasattr(param_types, "JSON"):
171173
types[key] = param_types.JSON
172174
elif isinstance(value, list):

sqlspec/adapters/spanner/config.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -178,19 +178,37 @@ def _close_pool(self) -> None:
178178
def provide_connection(
179179
self, *args: Any, transaction: "bool" = False, **kwargs: Any
180180
) -> Generator[SpannerConnection, None, None]:
181-
"""Yield a Snapshot (default) or Batch context from the configured pool.
181+
"""Yield a Snapshot (default) or Transaction context from the configured pool.
182182
183-
Note: Spanner does not support database.transaction() as a context manager.
184-
For write operations requiring conditional logic, use database.run_in_transaction()
185-
directly. The `transaction=True` option here uses database.batch() which is
186-
suitable for simple insert/update/delete mutations.
183+
Args:
184+
*args: Additional positional arguments (unused, for interface compatibility).
185+
transaction: If True, yields a Transaction context that supports
186+
execute_update() for DML statements. If False (default), yields
187+
a read-only Snapshot context for SELECT queries.
188+
**kwargs: Additional keyword arguments (unused, for interface compatibility).
189+
190+
Note: For complex transactional logic with retries, use database.run_in_transaction()
191+
directly. The Transaction context here auto-commits on successful exit.
187192
"""
188193
database = self.get_database()
189194
if transaction:
190-
with cast("Any", database).batch() as batch:
191-
yield cast("SpannerConnection", batch)
195+
session = cast("Any", database).session()
196+
session.create()
197+
try:
198+
txn = session.transaction()
199+
txn.__enter__()
200+
try:
201+
yield cast("SpannerConnection", txn)
202+
if hasattr(txn, "_transaction_id") and txn._transaction_id is not None:
203+
txn.commit()
204+
except Exception:
205+
if hasattr(txn, "_transaction_id") and txn._transaction_id is not None:
206+
txn.rollback()
207+
raise
208+
finally:
209+
session.delete()
192210
else:
193-
with cast("Any", database).snapshot() as snapshot:
211+
with cast("Any", database).snapshot(multi_use=True) as snapshot:
194212
yield cast("SpannerConnection", snapshot)
195213

196214
@contextmanager
@@ -209,7 +227,6 @@ def provide_session(
209227
def provide_write_session(
210228
self, *args: Any, statement_config: "StatementConfig | None" = None, **kwargs: Any
211229
) -> Generator[SpannerSyncDriver, None, None]:
212-
"""Convenience wrapper that always yields a write-capable transaction session."""
213230
with self.provide_session(*args, statement_config=statement_config, transaction=True, **kwargs) as driver:
214231
yield driver
215232

sqlspec/adapters/spanner/driver.py

Lines changed: 13 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -204,32 +204,21 @@ def _execute_many(self, cursor: "SpannerConnection", statement: "SQL") -> Execut
204204
raise SQLConversionError(msg)
205205
conn = cast("Any", cursor)
206206

207-
parameter_sets = statement.parameters if isinstance(statement.parameters, list) else []
208-
if not parameter_sets:
207+
sql, prepared_parameters = self._get_compiled_sql(statement, self.statement_config)
208+
209+
if not prepared_parameters or not isinstance(prepared_parameters, list):
209210
msg = "execute_many requires at least one parameter set"
210211
raise SQLConversionError(msg)
211212

212-
base_params = parameter_sets[0]
213-
base_statement = self.prepare_statement(
214-
statement.raw_sql, *[base_params], statement_config=statement.statement_config
215-
)
216-
compiled_sql, _ = self._get_compiled_sql(base_statement, self.statement_config)
217-
218-
batch_inputs: list[dict[str, Any]] = []
219-
for params in parameter_sets:
220-
per_statement = self.prepare_statement(
221-
statement.raw_sql, *[params], statement_config=statement.statement_config
222-
)
223-
_, processed_params = self._get_compiled_sql(per_statement, self.statement_config)
224-
coerced_params = self._coerce_params(processed_params)
213+
batch_args: list[tuple[str, dict[str, Any] | None, dict[str, Any]]] = []
214+
for params in prepared_parameters:
215+
coerced_params = self._coerce_params(params)
225216
if coerced_params is None:
226217
coerced_params = {}
227-
batch_inputs.append(coerced_params)
228-
229-
batch_args = [(compiled_sql, p, self._infer_param_types(p)) for p in batch_inputs]
218+
batch_args.append((sql, coerced_params, self._infer_param_types(coerced_params)))
230219

231-
row_counts = conn.batch_update(batch_args)
232-
total_rows = int(sum(int(count) for count in row_counts))
220+
_status, row_counts = conn.batch_update(batch_args)
221+
total_rows = sum(row_counts) if row_counts else 0
233222

234223
return self.create_execution_result(cursor, rowcount_override=total_rows, is_many_result=True)
235224

@@ -350,28 +339,28 @@ def _truncate_table_sync(self, table: str) -> None:
350339

351340

352341
def _build_spanner_profile() -> DriverParameterProfile:
353-
type_coercions: dict[type, Any] = {dict: to_json, list: to_json, tuple: to_json}
342+
type_coercions: dict[type, Any] = {dict: to_json}
354343
return DriverParameterProfile(
355344
name="Spanner",
356345
default_style=ParameterStyle.NAMED_AT,
357346
supported_styles={ParameterStyle.NAMED_AT},
358347
default_execution_style=ParameterStyle.NAMED_AT,
359348
supported_execution_styles={ParameterStyle.NAMED_AT},
360349
has_native_list_expansion=True,
361-
json_serializer_strategy="helper",
350+
json_serializer_strategy="none",
362351
default_dialect="spanner",
363352
preserve_parameter_format=True,
364353
needs_static_script_compilation=False,
365354
allow_mixed_parameter_styles=False,
366355
preserve_original_params_for_many=True,
367356
custom_type_coercions=type_coercions,
368-
extras={"type_coercion_overrides": type_coercions},
357+
extras={},
369358
)
370359

371360

372361
_SPANNER_PROFILE = _build_spanner_profile()
373362
register_driver_profile("spanner", _SPANNER_PROFILE)
374363

375364
spanner_statement_config = build_statement_config_from_profile(
376-
_SPANNER_PROFILE, statement_overrides={"dialect": "spanner"}, json_serializer=to_json
365+
_SPANNER_PROFILE, statement_overrides={"dialect": "spanner"}
377366
)

tests/integration/test_adapters/test_spanner/conftest.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def spanner_database(
3535

3636

3737
@pytest.fixture
38-
def spanner_config(spanner_service: SpannerService, spanner_connection: spanner.Client) -> SpannerSyncConfig:
38+
def spanner_config(
39+
spanner_service: SpannerService, spanner_connection: spanner.Client, spanner_database: "Database"
40+
) -> SpannerSyncConfig:
41+
"""Create SpannerSyncConfig after ensuring database exists."""
42+
_ = spanner_database # Ensure database is created before config
3943
api_endpoint = f"{spanner_service.host}:{spanner_service.port}"
4044

4145
return SpannerSyncConfig(
@@ -53,12 +57,27 @@ def spanner_config(spanner_service: SpannerService, spanner_connection: spanner.
5357

5458
@pytest.fixture
5559
def spanner_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
60+
"""Read-only session for SELECT operations."""
5661
sql = SQLSpec()
5762
c = sql.add_config(spanner_config)
5863
with sql.provide_session(c) as session:
5964
yield session
6065

6166

67+
@pytest.fixture
68+
def spanner_write_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
69+
"""Write-capable session for DML operations (INSERT/UPDATE/DELETE)."""
70+
with spanner_config.provide_write_session() as session:
71+
yield session
72+
73+
74+
@pytest.fixture
75+
def spanner_read_session(spanner_config: SpannerSyncConfig) -> Generator[SpannerSyncDriver, None, None]:
76+
"""Read-only session for SELECT operations."""
77+
with spanner_config.provide_session() as session:
78+
yield session
79+
80+
6281
def run_ddl(database: "Database", statements: "list[str]", timeout: int = 300) -> None:
6382
"""Execute DDL statements on Spanner database."""
6483
operation = database.update_ddl(statements) # type: ignore[no-untyped-call]
Lines changed: 33 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,27 @@
1-
"""Integration tests for Spanner Arrow support."""
1+
"""Integration tests for Spanner Arrow support.
22
3-
from typing import TYPE_CHECKING, Any
3+
All operations use SQLSpec interface, not raw SDK calls.
4+
"""
45

56
import pytest
67

78
from sqlspec._typing import PYARROW_INSTALLED
89
from sqlspec.adapters.spanner import SpannerSyncConfig
910

10-
if TYPE_CHECKING:
11-
from google.cloud.spanner_v1.database import Database
12-
1311
pytestmark = [pytest.mark.spanner, pytest.mark.skipif(not PYARROW_INSTALLED, reason="pyarrow not installed")]
1412

1513

16-
def test_select_to_arrow_basic(
17-
spanner_config: SpannerSyncConfig, spanner_database: "Database", test_arrow_table: str
18-
) -> None:
14+
def test_select_to_arrow_basic(spanner_config: SpannerSyncConfig, test_arrow_table: str) -> None:
1915
"""Test basic select_to_arrow functionality."""
2016
import pyarrow as pa
2117

22-
database = spanner_database
23-
24-
def insert_data(transaction: "Any") -> None:
18+
with spanner_config.provide_write_session() as session:
2519
for i, name in enumerate(["Alice", "Bob", "Charlie"], start=1):
26-
transaction.execute_update(
20+
session.execute(
2721
f"INSERT INTO {test_arrow_table} (id, name, value) VALUES (@id, @name, @value)",
28-
params={"id": i, "name": name, "value": i * 10},
29-
param_types={"id": {"code": "INT64"}, "name": {"code": "STRING"}, "value": {"code": "INT64"}},
22+
{"id": i, "name": name, "value": i * 10},
3023
)
3124

32-
database.run_in_transaction(insert_data) # type: ignore[no-untyped-call]
33-
3425
with spanner_config.provide_session() as session:
3526
result = session.select_to_arrow(f"SELECT * FROM {test_arrow_table} ORDER BY id")
3627

@@ -43,28 +34,19 @@ def insert_data(transaction: "Any") -> None:
4334
assert list(df["name"]) == ["Alice", "Bob", "Charlie"]
4435
assert list(df["value"]) == [10, 20, 30]
4536

46-
def cleanup(transaction: "Any") -> None:
47-
transaction.execute_update(f"DELETE FROM {test_arrow_table} WHERE TRUE")
48-
49-
database.run_in_transaction(cleanup) # type: ignore[no-untyped-call]
37+
with spanner_config.provide_write_session() as session:
38+
session.execute(f"DELETE FROM {test_arrow_table} WHERE TRUE")
5039

5140

52-
def test_select_to_arrow_with_parameters(
53-
spanner_config: SpannerSyncConfig, spanner_database: "Database", test_arrow_table: str
54-
) -> None:
41+
def test_select_to_arrow_with_parameters(spanner_config: SpannerSyncConfig, test_arrow_table: str) -> None:
5542
"""Test select_to_arrow with parameterized query."""
56-
database = spanner_database
57-
58-
def insert_data(transaction: "Any") -> None:
43+
with spanner_config.provide_write_session() as session:
5944
for i in range(1, 6):
60-
transaction.execute_update(
45+
session.execute(
6146
f"INSERT INTO {test_arrow_table} (id, name, value) VALUES (@id, @name, @value)",
62-
params={"id": i, "name": f"Item {i}", "value": i * 100},
63-
param_types={"id": {"code": "INT64"}, "name": {"code": "STRING"}, "value": {"code": "INT64"}},
47+
{"id": i, "name": f"Item {i}", "value": i * 100},
6448
)
6549

66-
database.run_in_transaction(insert_data) # type: ignore[no-untyped-call]
67-
6850
with spanner_config.provide_session() as session:
6951
result = session.select_to_arrow(
7052
f"SELECT * FROM {test_arrow_table} WHERE value > @min_value ORDER BY id", {"min_value": 200}
@@ -74,10 +56,8 @@ def insert_data(transaction: "Any") -> None:
7456
df = result.to_pandas()
7557
assert list(df["value"]) == [300, 400, 500]
7658

77-
def cleanup(transaction: "Any") -> None:
78-
transaction.execute_update(f"DELETE FROM {test_arrow_table} WHERE TRUE")
79-
80-
database.run_in_transaction(cleanup) # type: ignore[no-untyped-call]
59+
with spanner_config.provide_write_session() as session:
60+
session.execute(f"DELETE FROM {test_arrow_table} WHERE TRUE")
8161

8262

8363
def test_select_to_arrow_empty_result(spanner_config: SpannerSyncConfig, test_arrow_table: str) -> None:
@@ -89,91 +69,65 @@ def test_select_to_arrow_empty_result(spanner_config: SpannerSyncConfig, test_ar
8969
assert len(result.to_pandas()) == 0
9070

9171

92-
def test_select_to_arrow_table_format(
93-
spanner_config: SpannerSyncConfig, spanner_database: "Database", test_arrow_table: str
94-
) -> None:
72+
def test_select_to_arrow_table_format(spanner_config: SpannerSyncConfig, test_arrow_table: str) -> None:
9573
"""Test select_to_arrow with table return format (default)."""
9674
import pyarrow as pa
9775

98-
database = spanner_database
99-
100-
def insert_data(transaction: "Any") -> None:
76+
with spanner_config.provide_write_session() as session:
10177
for i in range(1, 4):
102-
transaction.execute_update(
78+
session.execute(
10379
f"INSERT INTO {test_arrow_table} (id, name, value) VALUES (@id, @name, @value)",
104-
params={"id": i, "name": f"Row {i}", "value": i},
105-
param_types={"id": {"code": "INT64"}, "name": {"code": "STRING"}, "value": {"code": "INT64"}},
80+
{"id": i, "name": f"Row {i}", "value": i},
10681
)
10782

108-
database.run_in_transaction(insert_data) # type: ignore[no-untyped-call]
109-
11083
with spanner_config.provide_session() as session:
11184
result = session.select_to_arrow(f"SELECT * FROM {test_arrow_table} ORDER BY id", return_format="table")
11285

11386
assert isinstance(result.data, pa.Table)
11487
assert result.rows_affected == 3
11588

116-
def cleanup(transaction: "Any") -> None:
117-
transaction.execute_update(f"DELETE FROM {test_arrow_table} WHERE TRUE")
118-
119-
database.run_in_transaction(cleanup) # type: ignore[no-untyped-call]
89+
with spanner_config.provide_write_session() as session:
90+
session.execute(f"DELETE FROM {test_arrow_table} WHERE TRUE")
12091

12192

122-
def test_select_to_arrow_batch_format(
123-
spanner_config: SpannerSyncConfig, spanner_database: "Database", test_arrow_table: str
124-
) -> None:
93+
def test_select_to_arrow_batch_format(spanner_config: SpannerSyncConfig, test_arrow_table: str) -> None:
12594
"""Test select_to_arrow with batch return format."""
12695
import pyarrow as pa
12796

128-
database = spanner_database
129-
130-
def insert_data(transaction: "Any") -> None:
97+
with spanner_config.provide_write_session() as session:
13198
for i in range(1, 3):
132-
transaction.execute_update(
99+
session.execute(
133100
f"INSERT INTO {test_arrow_table} (id, name, value) VALUES (@id, @name, @value)",
134-
params={"id": i, "name": f"Batch {i}", "value": i * 5},
135-
param_types={"id": {"code": "INT64"}, "name": {"code": "STRING"}, "value": {"code": "INT64"}},
101+
{"id": i, "name": f"Batch {i}", "value": i * 5},
136102
)
137103

138-
database.run_in_transaction(insert_data) # type: ignore[no-untyped-call]
139-
140104
with spanner_config.provide_session() as session:
141105
result = session.select_to_arrow(f"SELECT * FROM {test_arrow_table} ORDER BY id", return_format="batch")
142106

143107
assert isinstance(result.data, pa.RecordBatch)
144108
assert result.rows_affected == 2
145109

146-
def cleanup(transaction: "Any") -> None:
147-
transaction.execute_update(f"DELETE FROM {test_arrow_table} WHERE TRUE")
110+
with spanner_config.provide_write_session() as session:
111+
session.execute(f"DELETE FROM {test_arrow_table} WHERE TRUE")
148112

149-
database.run_in_transaction(cleanup) # type: ignore[no-untyped-call]
150113

151-
152-
def test_select_to_arrow_to_polars(
153-
spanner_config: SpannerSyncConfig, spanner_database: "Database", test_arrow_table: str
154-
) -> None:
114+
def test_select_to_arrow_to_polars(spanner_config: SpannerSyncConfig, test_arrow_table: str) -> None:
155115
"""Test select_to_arrow conversion to Polars DataFrame."""
156116
pytest.importorskip("polars")
157-
database = spanner_database
158117

159-
def insert_data(transaction: "Any") -> None:
118+
with spanner_config.provide_write_session() as session:
160119
for i in range(1, 3):
161-
transaction.execute_update(
120+
session.execute(
162121
f"INSERT INTO {test_arrow_table} (id, name, value) VALUES (@id, @name, @value)",
163-
params={"id": i, "name": f"Polars {i}", "value": i * 7},
164-
param_types={"id": {"code": "INT64"}, "name": {"code": "STRING"}, "value": {"code": "INT64"}},
122+
{"id": i, "name": f"Polars {i}", "value": i * 7},
165123
)
166124

167-
database.run_in_transaction(insert_data) # type: ignore[no-untyped-call]
168-
169125
with spanner_config.provide_session() as session:
170126
result = session.select_to_arrow(f"SELECT * FROM {test_arrow_table} ORDER BY id")
171127
df = result.to_polars()
172128

173129
assert len(df) == 2
174130
assert df["name"].to_list() == ["Polars 1", "Polars 2"]
175131

176-
def cleanup(transaction: "Any") -> None:
177-
transaction.execute_update(f"DELETE FROM {test_arrow_table} WHERE TRUE")
178-
179-
database.run_in_transaction(cleanup) # type: ignore[no-untyped-call]
132+
with spanner_config.provide_write_session() as session:
133+
session.execute(f"DELETE FROM {test_arrow_table} WHERE TRUE")

0 commit comments

Comments
 (0)