Skip to content

Commit f06f4bd

Browse files
committed
refactor: rename internal classes and update related tests for clarity
1 parent 14691f0 commit f06f4bd

File tree

2 files changed

+105
-28
lines changed

2 files changed

+105
-28
lines changed

sqlspec/extensions/litestar/plugin.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,16 @@
7171
CORRELATION_STATE_KEY = "sqlspec_correlation_id"
7272

7373
__all__ = (
74+
"CORRELATION_STATE_KEY",
7475
"DEFAULT_COMMIT_MODE",
7576
"DEFAULT_CONNECTION_KEY",
77+
"DEFAULT_CORRELATION_HEADER",
7678
"DEFAULT_POOL_KEY",
7779
"DEFAULT_SESSION_KEY",
80+
"TRACE_CONTEXT_FALLBACK_HEADERS",
7881
"CommitMode",
82+
"CorrelationMiddleware",
83+
"PluginConfigState",
7984
"SQLSpecPlugin",
8085
)
8186

@@ -117,7 +122,7 @@ def _build_correlation_headers(*, primary: str, configured: list[str], auto_trac
117122
return tuple(_dedupe_headers(header_order))
118123

119124

120-
class _CorrelationMiddleware:
125+
class CorrelationMiddleware:
121126
__slots__ = ("_app", "_headers")
122127

123128
def __init__(self, app: "ASGIApp", *, headers: tuple[str, ...]) -> None:
@@ -153,7 +158,7 @@ async def __call__(self, scope: "Scope", receive: "Receive", send: "Send") -> No
153158

154159

155160
@dataclass
156-
class _PluginConfigState:
161+
class PluginConfigState:
157162
"""Internal state for each database configuration."""
158163

159164
config: "DatabaseConfigProtocol[Any, Any, Any]"
@@ -219,7 +224,7 @@ def __init__(self, sqlspec: SQLSpec, *, loader: "SQLFileLoader | None" = None) -
219224
"""
220225
self._sqlspec = sqlspec
221226

222-
self._plugin_configs: list[_PluginConfigState] = []
227+
self._plugin_configs: list[PluginConfigState] = []
223228
for cfg in self._sqlspec.configs.values():
224229
config_union = cast(
225230
"SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
@@ -276,9 +281,9 @@ def _create_config_state(
276281
self,
277282
config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
278283
settings: "dict[str, Any]",
279-
) -> _PluginConfigState:
284+
) -> PluginConfigState:
280285
"""Create plugin state with handlers for the given configuration."""
281-
state = _PluginConfigState(
286+
state = PluginConfigState(
282287
config=config,
283288
connection_key=settings["connection_key"],
284289
pool_key=settings["pool_key"],
@@ -296,7 +301,7 @@ def _create_config_state(
296301
self._setup_handlers(state)
297302
return state
298303

299-
def _setup_handlers(self, state: _PluginConfigState) -> None:
304+
def _setup_handlers(self, state: PluginConfigState) -> None:
300305
"""Setup handlers for the plugin state."""
301306
connection_key = state.connection_key
302307
pool_key = state.pool_key
@@ -403,7 +408,7 @@ def store_sqlspec_in_state() -> None:
403408
app_config.type_decoders = decoders_list
404409

405410
if self._correlation_headers:
406-
middleware = DefineMiddleware(_CorrelationMiddleware, headers=self._correlation_headers)
411+
middleware = DefineMiddleware(CorrelationMiddleware, headers=self._correlation_headers)
407412
existing_middleware = list(app_config.middleware or [])
408413
existing_middleware.append(middleware)
409414
app_config.middleware = existing_middleware
@@ -579,7 +584,7 @@ def provide_request_connection(
579584

580585
def _get_plugin_state(
581586
self, key: "str | SyncConfigT | AsyncConfigT | type[SyncConfigT | AsyncConfigT]"
582-
) -> _PluginConfigState:
587+
) -> PluginConfigState:
583588
"""Get plugin state for a configuration by key."""
584589
if isinstance(key, str):
585590
for state in self._plugin_configs:

tests/unit/test_adapters/test_spanner/test_config.py

Lines changed: 92 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,7 @@ def test_driver_features_defaults() -> None:
5353

5454

5555
def test_provide_connection_batch_and_snapshot() -> None:
56-
"""Ensure provide_connection selects snapshot vs batch correctly."""
57-
batch_obj = object()
56+
"""Ensure provide_connection selects snapshot vs transaction correctly."""
5857
snap_obj = object()
5958

6059
class _Ctx:
@@ -67,39 +66,88 @@ def __enter__(self):
6766
def __exit__(self, *_):
6867
return False
6968

69+
class _Txn:
70+
_transaction_id = "test-txn-id"
71+
72+
def __enter__(self):
73+
return self
74+
75+
def __exit__(self, *_):
76+
return False
77+
78+
def commit(self):
79+
pass
80+
81+
def rollback(self):
82+
pass
83+
84+
class _Session:
85+
def create(self):
86+
pass
87+
88+
def delete(self):
89+
pass
90+
91+
def transaction(self):
92+
return _Txn()
93+
7094
class _DB:
71-
def batch(self):
72-
return _Ctx(batch_obj)
95+
def session(self):
96+
return _Session()
7397

74-
def snapshot(self):
98+
def snapshot(self, multi_use: bool = False):
7599
return _Ctx(snap_obj)
76100

77101
config = SpannerSyncConfig(pool_config={"project": "p", "instance_id": "i", "database_id": "d"})
78102
config.get_database = lambda: _DB() # type: ignore[assignment]
79103

80104
with config.provide_connection(transaction=True) as conn:
81-
assert conn is batch_obj
105+
assert isinstance(conn, _Txn)
82106

83107
with config.provide_connection(transaction=False) as conn:
84108
assert conn is snap_obj
85109

86110

87111
def test_provide_session_uses_batch_when_transaction_requested() -> None:
88-
"""Driver should receive batch connection when transaction=True."""
89-
batch_obj = object()
112+
"""Driver should receive transaction connection when transaction=True."""
113+
114+
class _Txn:
115+
_transaction_id = "test-txn-id"
116+
117+
def __enter__(self):
118+
return self
119+
120+
def __exit__(self, *_):
121+
return False
122+
123+
def commit(self):
124+
pass
125+
126+
def rollback(self):
127+
pass
128+
129+
class _Session:
130+
def create(self):
131+
pass
132+
133+
def delete(self):
134+
pass
135+
136+
def transaction(self):
137+
return _Txn()
90138

91139
class _Ctx:
92140
def __enter__(self):
93-
return batch_obj
141+
return object()
94142

95143
def __exit__(self, *_):
96144
return False
97145

98146
class _DB:
99-
def batch(self):
100-
return _Ctx()
147+
def session(self):
148+
return _Session()
101149

102-
def snapshot(self):
150+
def snapshot(self, multi_use: bool = False):
103151
return _Ctx()
104152

105153
config = SpannerSyncConfig(pool_config={"project": "p", "instance_id": "i", "database_id": "d"})
@@ -108,30 +156,54 @@ def snapshot(self):
108156

109157
with config.provide_session(transaction=True) as driver:
110158
assert isinstance(driver, _DummyDriver)
111-
assert driver.connection is batch_obj
159+
assert isinstance(driver.connection, _Txn)
112160

113161

114162
def test_provide_write_session_alias() -> None:
115-
"""provide_write_session should always give a batch-backed driver."""
116-
batch_obj = object()
163+
"""provide_write_session should always give a transaction-backed driver."""
164+
165+
class _Txn:
166+
_transaction_id = "test-txn-id"
167+
168+
def __enter__(self):
169+
return self
170+
171+
def __exit__(self, *_):
172+
return False
173+
174+
def commit(self):
175+
pass
176+
177+
def rollback(self):
178+
pass
179+
180+
class _Session:
181+
def create(self):
182+
pass
183+
184+
def delete(self):
185+
pass
186+
187+
def transaction(self):
188+
return _Txn()
117189

118190
class _Ctx:
119191
def __enter__(self):
120-
return batch_obj
192+
return object()
121193

122194
def __exit__(self, *_):
123195
return False
124196

125197
class _DB:
126-
def batch(self):
127-
return _Ctx()
198+
def session(self):
199+
return _Session()
128200

129-
def snapshot(self):
201+
def snapshot(self, multi_use: bool = False):
130202
return _Ctx()
131203

132204
config = SpannerSyncConfig(pool_config={"project": "p", "instance_id": "i", "database_id": "d"})
133205
config.get_database = lambda: _DB() # type: ignore[assignment]
134206
config.driver_type = _DummyDriver # type: ignore[assignment,misc]
135207

136208
with config.provide_write_session() as driver:
137-
assert driver.connection is batch_obj
209+
assert isinstance(driver.connection, _Txn)

0 commit comments

Comments
 (0)