Skip to content

Commit 94b4d9c

Browse files
committed
feat(channels): refine type hints and improve async handling in SQLSpecChannelsBackend
1 parent f7abb60 commit 94b4d9c

File tree

4 files changed

+15
-16
lines changed

4 files changed

+15
-16
lines changed

sqlspec/extensions/events/channel.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __init__(self, config: "DatabaseConfigProtocol[Any, Any, Any]") -> None:
114114
self._config = config
115115
self._backend_name = backend_label
116116
self._is_async = bool(config.is_async)
117-
self._portal_bridge = bool(extension_settings.get("portal_bridge", False)) and self._is_async
117+
self._portal_bridge = bool(extension_settings.get("portal_bridge")) and self._is_async
118118
self._portal = None
119119
self._runtime = config.get_observability_runtime()
120120
self._listeners_async: dict[str, AsyncEventListener] = {}
@@ -626,7 +626,7 @@ async def shutdown_async(self) -> None:
626626
if backend_shutdown is not None and callable(backend_shutdown):
627627
result = backend_shutdown()
628628
if result is not None:
629-
await result # type: ignore[misc]
629+
await result
630630
except Exception as error:
631631
self._end_event_span(span, error=error)
632632
raise

sqlspec/extensions/litestar/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
from sqlspec.extensions.litestar.cli import database_group
21
from sqlspec.extensions.litestar.channels import SQLSpecChannelsBackend
2+
from sqlspec.extensions.litestar.cli import database_group
33
from sqlspec.extensions.litestar.config import LitestarConfig
44
from sqlspec.extensions.litestar.plugin import (
55
DEFAULT_COMMIT_MODE,

sqlspec/extensions/litestar/channels.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,17 @@
44
import base64
55
import hashlib
66
import re
7-
from collections.abc import AsyncGenerator, Iterable
8-
from typing import Any
7+
from typing import TYPE_CHECKING, Any
98

109
from litestar.channels.backends.base import ChannelsBackend
1110

12-
from sqlspec.extensions.events.channel import EventChannel
1311
from sqlspec.utils.logging import get_logger
1412

13+
if TYPE_CHECKING:
14+
from collections.abc import AsyncGenerator, Iterable
15+
16+
from sqlspec.extensions.events.channel import EventChannel
17+
1518
logger = get_logger("extensions.litestar.channels")
1619

1720
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
@@ -42,11 +45,7 @@ class SQLSpecChannelsBackend(ChannelsBackend):
4245
)
4346

4447
def __init__(
45-
self,
46-
event_channel: "EventChannel",
47-
*,
48-
channel_prefix: str = "litestar",
49-
poll_interval: float = 0.2,
48+
self, event_channel: "EventChannel", *, channel_prefix: str = "litestar", poll_interval: float = 0.2
5049
) -> None:
5150
if not _IDENTIFIER_PATTERN.match(channel_prefix):
5251
msg = f"channel_prefix must be a valid identifier, got: {channel_prefix!r}"
@@ -57,9 +56,9 @@ def __init__(
5756
self._event_channel = event_channel
5857
self._channel_prefix = channel_prefix
5958
self._poll_interval = poll_interval
60-
self._output_queue: "asyncio.Queue[tuple[str, bytes]] | None" = None
59+
self._output_queue: asyncio.Queue[tuple[str, bytes]] | None = None
6160
self._shutdown = asyncio.Event()
62-
self._tasks: dict[str, "asyncio.Task[None]"] = {}
61+
self._tasks: dict[str, asyncio.Task[None]] = {}
6362
self._to_db_channel: dict[str, str] = {}
6463
self._to_litestar_channel: dict[str, str] = {}
6564

@@ -95,7 +94,7 @@ async def subscribe(self, channels: "Iterable[str]") -> None:
9594
self._tasks[channel] = task
9695

9796
async def unsubscribe(self, channels: "Iterable[str]") -> None:
98-
cancelled: list["asyncio.Task[None]"] = []
97+
cancelled: list[asyncio.Task[None]] = []
9998
for channel in channels:
10099
task = self._tasks.pop(channel, None)
101100
if task is None:

tests/integration/test_extensions/test_litestar/test_channels_backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import asyncio
22
import tempfile
3-
from typing import Any
3+
from typing import Any, cast
44

55
import msgspec.json
66
import pytest
@@ -14,7 +14,7 @@
1414

1515
async def _next_event(subscriber: "Any") -> bytes:
1616
async for event in subscriber.iter_events():
17-
return event
17+
return cast("bytes", event)
1818
msg = "Subscriber stopped without yielding an event"
1919
raise RuntimeError(msg)
2020

0 commit comments

Comments
 (0)