Skip to content

Commit a07199f

Browse files
committed
feat(events): improve connection handling and error management across event backends
1 parent 1471480 commit a07199f

File tree

22 files changed

+119
-72
lines changed

22 files changed

+119
-72
lines changed

sqlspec/adapters/asyncpg/events/backend.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportPrivateUsage=false
12
"""Native and hybrid PostgreSQL backends for EventChannel."""
23

34
import asyncio
@@ -68,7 +69,8 @@ async def _ensure_listener(self, channel: str) -> Any:
6869
if self._listen_connection is None:
6970
self._listen_connection_cm = self._config.provide_connection()
7071
self._listen_connection = await self._listen_connection_cm.__aenter__()
71-
await self._listen_connection.execute(f"LISTEN {channel}")
72+
if self._listen_connection is not None:
73+
await self._listen_connection.execute(f"LISTEN {channel}")
7274
return self._listen_connection
7375

7476
async def ack_async(self, event_id: str) -> None:
@@ -240,7 +242,8 @@ async def _ensure_listener(self, channel: str) -> Any:
240242
self._notify_mode = "add_listener"
241243
elif getattr(self._listen_connection, "notifies", None) is not None:
242244
self._notify_mode = "notifies"
243-
await self._listen_connection.execute(f"LISTEN {channel}")
245+
if self._listen_connection is not None:
246+
await self._listen_connection.execute(f"LISTEN {channel}")
244247
else:
245248
msg = "PostgreSQL connection does not support LISTEN/NOTIFY callbacks"
246249
raise EventChannelError(msg)

sqlspec/adapters/oracledb/events/backend.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __init__(self, config: "DatabaseConfigProtocol[Any, Any, Any]", settings: di
5151
def publish_sync(self, channel: str, payload: dict[str, Any], metadata: dict[str, Any] | None = None) -> str:
5252
event_id = uuid.uuid4().hex
5353
envelope = self._build_envelope(channel, event_id, payload, metadata)
54-
with self._config.provide_session() as driver:
54+
session_cm = self._config.provide_session()
55+
with session_cm as driver: # type: ignore[union-attr]
5556
connection = getattr(driver, "connection", None)
5657
if connection is None:
5758
msg = "Oracle driver does not expose a raw connection"
@@ -67,7 +68,8 @@ async def publish_async(self, *_: Any, **__: Any) -> str: # pragma: no cover -
6768
raise ImproperConfigurationError(msg)
6869

6970
def dequeue_sync(self, channel: str, poll_interval: float) -> EventMessage | None:
70-
with self._config.provide_session() as driver:
71+
session_cm = self._config.provide_session()
72+
with session_cm as driver: # type: ignore[union-attr]
7173
connection = getattr(driver, "connection", None)
7274
if connection is None:
7375
msg = "Oracle driver does not expose a raw connection"
@@ -76,10 +78,13 @@ def dequeue_sync(self, channel: str, poll_interval: float) -> EventMessage | Non
7678
options = oracledb.AQDequeueOptions() # type: ignore[attr-defined]
7779
options.wait = max(int(self._wait_seconds), 0)
7880
if self._visibility:
79-
options.visibility = getattr(oracledb, self._visibility, None) or oracledb.AQMSG_VISIBLE
81+
default_visibility = getattr(oracledb, "AQMSG_VISIBLE", None)
82+
options.visibility = getattr(oracledb, self._visibility, None) or default_visibility
8083
try:
8184
message = queue.deqone(options=options)
82-
except oracledb.DatabaseError as error: # pragma: no cover - driver surfaced runtime
85+
except Exception as error: # pragma: no cover - driver surfaced runtime
86+
if oracledb is None or not isinstance(error, oracledb.DatabaseError):
87+
raise
8388
logger.warning("Oracle AQ dequeue failed: %s", error)
8489
driver.rollback()
8590
return None

sqlspec/adapters/psycopg/events/backend.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportPrivateUsage=false
12
"""Psycopg LISTEN/NOTIFY and hybrid event backends."""
23

34
import contextlib
@@ -55,7 +56,8 @@ async def publish_async(
5556
) -> str:
5657
event_id = uuid.uuid4().hex
5758
envelope = self._encode_payload(event_id, payload, metadata)
58-
async with self._config.provide_session() as driver:
59+
session_cm = self._config.provide_session()
60+
async with session_cm as driver: # type: ignore[union-attr]
5961
await driver.execute(SQL("SELECT pg_notify(:channel, :payload)", {"channel": channel, "payload": envelope}))
6062
await driver.commit()
6163
self._runtime.increment_metric("events.publish.native")
@@ -64,7 +66,8 @@ async def publish_async(
6466
def publish_sync(self, channel: str, payload: "dict[str, Any]", metadata: "dict[str, Any] | None" = None) -> str:
6567
event_id = uuid.uuid4().hex
6668
envelope = self._encode_payload(event_id, payload, metadata)
67-
with self._config.provide_session() as driver:
69+
session_cm = self._config.provide_session()
70+
with session_cm as driver: # type: ignore[union-attr]
6871
driver.execute(SQL("SELECT pg_notify(:channel, :payload)", {"channel": channel, "payload": envelope}))
6972
driver.commit()
7073
self._runtime.increment_metric("events.publish.native")
@@ -122,9 +125,10 @@ async def _ensure_async_listener(self, channel: str) -> Any:
122125
if self._listen_connection_async is None:
123126
validated_channel = normalize_event_channel_name(channel)
124127
self._listen_connection_async_cm = self._config.provide_connection()
125-
self._listen_connection_async = await self._listen_connection_async_cm.__aenter__()
126-
await self._listen_connection_async.set_autocommit(True)
127-
await self._listen_connection_async.execute(f"LISTEN {validated_channel}")
128+
self._listen_connection_async = await self._listen_connection_async_cm.__aenter__() # type: ignore[union-attr]
129+
if self._listen_connection_async is not None:
130+
await self._listen_connection_async.set_autocommit(True)
131+
await self._listen_connection_async.execute(f"LISTEN {validated_channel}")
128132
return self._listen_connection_async
129133

130134
def _ensure_sync_listener(self, channel: str) -> Any:
@@ -139,9 +143,10 @@ def _ensure_sync_listener(self, channel: str) -> Any:
139143
if self._listen_connection_sync is None:
140144
validated_channel = normalize_event_channel_name(channel)
141145
self._listen_connection_sync_cm = self._config.provide_connection()
142-
self._listen_connection_sync = self._listen_connection_sync_cm.__enter__()
143-
self._listen_connection_sync.autocommit = True
144-
self._listen_connection_sync.execute(f"LISTEN {validated_channel}")
146+
self._listen_connection_sync = self._listen_connection_sync_cm.__enter__() # type: ignore[union-attr]
147+
if self._listen_connection_sync is not None:
148+
self._listen_connection_sync.autocommit = True
149+
self._listen_connection_sync.execute(f"LISTEN {validated_channel}")
145150
return self._listen_connection_sync
146151

147152
async def shutdown_async(self) -> None:
@@ -280,9 +285,10 @@ async def _ensure_async_listener(self, channel: str) -> Any:
280285
if self._listen_connection_async is None:
281286
validated_channel = normalize_event_channel_name(channel)
282287
self._listen_connection_async_cm = self._config.provide_connection()
283-
self._listen_connection_async = await self._listen_connection_async_cm.__aenter__()
284-
await self._listen_connection_async.set_autocommit(True)
285-
await self._listen_connection_async.execute(f"LISTEN {validated_channel}")
288+
self._listen_connection_async = await self._listen_connection_async_cm.__aenter__() # type: ignore[union-attr]
289+
if self._listen_connection_async is not None:
290+
await self._listen_connection_async.set_autocommit(True)
291+
await self._listen_connection_async.execute(f"LISTEN {validated_channel}")
286292
return self._listen_connection_async
287293

288294
def _ensure_sync_listener(self, channel: str) -> Any:
@@ -297,9 +303,10 @@ def _ensure_sync_listener(self, channel: str) -> Any:
297303
if self._listen_connection_sync is None:
298304
validated_channel = normalize_event_channel_name(channel)
299305
self._listen_connection_sync_cm = self._config.provide_connection()
300-
self._listen_connection_sync = self._listen_connection_sync_cm.__enter__()
301-
self._listen_connection_sync.autocommit = True
302-
self._listen_connection_sync.execute(f"LISTEN {validated_channel}")
306+
self._listen_connection_sync = self._listen_connection_sync_cm.__enter__() # type: ignore[union-attr]
307+
if self._listen_connection_sync is not None:
308+
self._listen_connection_sync.autocommit = True
309+
self._listen_connection_sync.execute(f"LISTEN {validated_channel}")
303310
return self._listen_connection_sync
304311

305312
async def ack_async(self, event_id: str) -> None:
@@ -364,7 +371,8 @@ async def _publish_durable_async(
364371
"""
365372
now = datetime.now(timezone.utc)
366373
queue = self._get_table_queue()
367-
async with self._config.provide_session() as driver:
374+
session_cm = self._config.provide_session()
375+
async with session_cm as driver: # type: ignore[union-attr]
368376
await driver.execute(
369377
SQL(
370378
queue._upsert_sql,
@@ -403,7 +411,8 @@ def _publish_durable_sync(
403411
"""
404412
now = datetime.now(timezone.utc)
405413
queue = self._get_table_queue()
406-
with self._config.provide_session() as driver:
414+
session_cm = self._config.provide_session()
415+
with session_cm as driver: # type: ignore[union-attr]
407416
driver.execute(
408417
SQL(
409418
queue._upsert_sql,

sqlspec/base.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,11 @@ def event_channel(self, config: "type[SyncConfigT | AsyncConfigT] | SyncConfigT
165165
from sqlspec.extensions.events import EventChannel
166166

167167
if isinstance(config, type):
168-
config_obj = self._configs.get(config)
168+
config_obj: DatabaseConfigProtocol[Any, Any, Any] | None = None
169+
for registered_config in self._configs.values():
170+
if isinstance(registered_config, config):
171+
config_obj = registered_config
172+
break
169173
if config_obj is None:
170174
msg = f"Configuration {self._get_config_name(config)} is not registered"
171175
raise ImproperConfigurationError(msg)

sqlspec/extensions/events/_queue.py

Lines changed: 14 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -270,8 +270,9 @@ def _cleanup_sync(self, reference: "datetime") -> None:
270270

271271
async def _fetch_candidate_async(self, channel: str) -> "dict[str, Any] | None":
272272
current_time = self._utcnow()
273-
async with self._config.provide_session() as driver:
274-
return await driver.select_one_or_none(
273+
session_cm = self._config.provide_session()
274+
async with session_cm as driver: # type: ignore[union-attr]
275+
result: dict[str, Any] | None = await driver.select_one_or_none(
275276
SQL(
276277
self._select_sql,
277278
{
@@ -284,11 +285,13 @@ async def _fetch_candidate_async(self, channel: str) -> "dict[str, Any] | None":
284285
statement_config=self._statement_config,
285286
)
286287
)
288+
return result
287289

288290
def _fetch_candidate_sync(self, channel: str) -> "dict[str, Any] | None":
289291
current_time = self._utcnow()
290-
with self._config.provide_session() as driver:
291-
return driver.select_one_or_none(
292+
session_cm = self._config.provide_session()
293+
with session_cm as driver: # type: ignore[union-attr]
294+
result: dict[str, Any] | None = driver.select_one_or_none(
292295
SQL(
293296
self._select_sql,
294297
{
@@ -301,20 +304,23 @@ def _fetch_candidate_sync(self, channel: str) -> "dict[str, Any] | None":
301304
statement_config=self._statement_config,
302305
)
303306
)
307+
return result
304308

305309
async def _execute_async(self, sql: str, parameters: "dict[str, Any]") -> int:
306-
async with self._config.provide_session() as driver:
310+
session_cm = self._config.provide_session()
311+
async with session_cm as driver: # type: ignore[union-attr]
307312
result = await driver.execute(SQL(sql, parameters, statement_config=self._statement_config))
308313
if result.rows_affected:
309314
await driver.commit()
310-
return result.rows_affected
315+
return int(result.rows_affected)
311316

312317
def _execute_sync(self, sql: str, parameters: "dict[str, Any]") -> int:
313-
with self._config.provide_session() as driver:
318+
session_cm = self._config.provide_session()
319+
with session_cm as driver: # type: ignore[union-attr]
314320
result = driver.execute(SQL(sql, parameters, statement_config=self._statement_config))
315321
if result.rows_affected:
316322
driver.commit()
317-
return result.rows_affected
323+
return int(result.rows_affected)
318324

319325
@staticmethod
320326
def _coerce_datetime(value: Any) -> "datetime":

sqlspec/extensions/events/channel.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ def publish_sync(self, channel: str, payload: "dict[str, Any]", metadata: "dict[
149149
channel = self._normalize_channel_name(channel)
150150
if self._is_async:
151151
if self._should_bridge_sync_calls():
152-
return self._bridge_sync_call(self.publish_async, channel, payload, metadata)
152+
result: str = self._bridge_sync_call(self.publish_async, channel, payload, metadata)
153+
return result
153154
msg = "publish_sync requires a sync configuration"
154155
raise ImproperConfigurationError(msg)
155156
if not getattr(self._backend, "supports_sync", False):
@@ -456,11 +457,14 @@ def _dequeue_for_sync(self, channel: str, poll_interval: float) -> "EventMessage
456457
except Exception as error:
457458
self._end_event_span(span, error=error)
458459
raise
459-
result = "empty" if event is None else "delivered"
460-
self._end_event_span(span, result=result)
460+
span_result = "empty" if event is None else "delivered"
461+
self._end_event_span(span, result=span_result)
461462
return event
462463
if self._should_bridge_sync_calls():
463-
return self._bridge_sync_call(self._dequeue_async_with_span, channel, poll_interval)
464+
bridged_result: EventMessage | None = self._bridge_sync_call(
465+
self._dequeue_async_with_span, channel, poll_interval
466+
)
467+
return bridged_result
464468
return None
465469

466470
def _ack_for_sync(self, event_id: str) -> None:
@@ -489,7 +493,7 @@ def _bridge_sync_call(self, func: Any, *args: Any, **kwargs: Any) -> Any:
489493

490494
def _ensure_portal(self) -> Any:
491495
if self._portal is None:
492-
self._portal = get_global_portal()
496+
self._portal = get_global_portal() # type: ignore[assignment]
493497
return self._portal
494498

495499
@staticmethod

tests/integration/test_adapters/test_asyncmy/test_extensions/test_events.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportPrivateUsage=false, reportAttributeAccessIssue=false
12
"""AsyncMy integration tests for the EventChannel queue backend."""
23

34
import pytest

tests/integration/test_adapters/test_asyncpg/test_extensions/test_events.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportPrivateUsage=false
12
"""AsyncPG integration tests for the EventChannel native backend."""
23

34
import pytest

tests/integration/test_adapters/test_asyncpg/test_extensions/test_events_listen_notify.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportPrivateUsage=false
12
"""PostgreSQL LISTEN/NOTIFY event channel tests for asyncpg adapter."""
23

34
import asyncio
@@ -58,8 +59,8 @@ async def test_asyncpg_listen_notify_message_delivery(postgres_service: "Any") -
5859

5960
received: list[Any] = []
6061

61-
async def _handler(msg: Any) -> None:
62-
received.append(msg)
62+
async def _handler(message: Any) -> None:
63+
received.append(message)
6364

6465
listener = channel.listen_async("notifications", _handler, poll_interval=0.2)
6566
event_id = await channel.publish_async("notifications", {"action": "async_delivery"})
@@ -105,8 +106,8 @@ async def test_asyncpg_hybrid_listen_notify_durable(postgres_service: "Any", tmp
105106

106107
received: list[Any] = []
107108

108-
async def _handler(msg: Any) -> None:
109-
received.append(msg)
109+
async def _handler(message: Any) -> None:
110+
received.append(message)
110111

111112
listener = channel.listen_async("alerts", _handler, poll_interval=0.2)
112113
event_id = await channel.publish_async("alerts", {"action": "hybrid_async"})
@@ -143,8 +144,8 @@ async def test_asyncpg_listen_notify_metadata(postgres_service: "Any") -> None:
143144

144145
received: list[Any] = []
145146

146-
async def _handler(msg: Any) -> None:
147-
received.append(msg)
147+
async def _handler(message: Any) -> None:
148+
received.append(message)
148149

149150
listener = channel.listen_async("meta_channel", _handler, poll_interval=0.2)
150151
event_id = await channel.publish_async(

tests/integration/test_adapters/test_oracledb/test_extensions/test_events.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# pyright: reportPossiblyUnboundVariable=false, reportAttributeAccessIssue=false
12
"""OracleDB integration tests for the EventChannel queue backend."""
23

34
import asyncio

0 commit comments

Comments
 (0)