Skip to content

Commit 2e29fd7

Browse files
committed
feat: enhance event channel type hints and improve async handling in tests
1 parent cb52532 commit 2e29fd7

File tree

16 files changed

+124
-57
lines changed

16 files changed

+124
-57
lines changed

sqlspec/base.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,25 @@ def configs(self) -> "dict[int, DatabaseConfigProtocol[Any, Any, Any]]":
159159
"""
160160
return self._configs
161161

162+
@overload
163+
def event_channel(self, config: "type[SyncConfigT]") -> "SyncEventChannel": ...
164+
165+
@overload
166+
def event_channel(self, config: "type[AsyncConfigT]") -> "AsyncEventChannel": ...
167+
168+
@overload
162169
def event_channel(
163-
self, config: "type[SyncConfigT | AsyncConfigT] | SyncConfigT | AsyncConfigT"
170+
self, config: "SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any]"
171+
) -> "SyncEventChannel": ...
172+
173+
@overload
174+
def event_channel(
175+
self, config: "AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]"
176+
) -> "AsyncEventChannel": ...
177+
178+
def event_channel(
179+
self,
180+
config: "type[SyncConfigT | AsyncConfigT] | SyncDatabaseConfig[Any, Any, Any] | NoPoolSyncConfig[Any, Any] | AsyncDatabaseConfig[Any, Any, Any] | NoPoolAsyncConfig[Any, Any]",
164181
) -> "SyncEventChannel | AsyncEventChannel":
165182
"""Create an event channel for the provided configuration.
166183

sqlspec/utils/sync_tools.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT
161161
from sqlspec.utils.portal import get_global_portal
162162

163163
portal = get_global_portal()
164-
return portal.call(async_function, *args, **kwargs)
164+
typed_partial = cast("Callable[[], Coroutine[Any, Any, ReturnT]]", partial_f)
165+
return portal.call(typed_partial)
165166
else:
166167
if loop.is_running():
167168
try:
@@ -180,7 +181,8 @@ def wrapper(*args: "ParamSpecT.args", **kwargs: "ParamSpecT.kwargs") -> "ReturnT
180181
from sqlspec.utils.portal import get_global_portal
181182

182183
portal = get_global_portal()
183-
return portal.call(async_function, *args, **kwargs)
184+
typed_partial = cast("Callable[[], Coroutine[Any, Any, ReturnT]]", partial_f)
185+
return portal.call(typed_partial)
184186

185187
return wrapper
186188

tests/integration/test_adapters/test_asyncmy/test_extensions/test_events.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# pyright: reportPrivateUsage=false, reportAttributeAccessIssue=false, reportArgumentType=false
22
"""AsyncMy integration tests for the EventChannel queue backend."""
33

4-
from typing import Any
4+
from typing import TYPE_CHECKING, Any, cast
55

66
import pytest
77
from pytest_databases.docker.mysql import MySQLService
@@ -10,6 +10,9 @@
1010
from sqlspec.adapters.asyncmy import AsyncmyConfig
1111
from sqlspec.migrations.commands import AsyncMigrationCommands
1212

13+
if TYPE_CHECKING:
14+
from sqlspec.extensions.events import AsyncEventChannel
15+
1316
pytestmark = pytest.mark.xdist_group("mysql")
1417

1518

@@ -38,7 +41,7 @@ async def test_asyncmy_event_channel_queue_fallback(mysql_service: MySQLService,
3841

3942
spec = SQLSpec()
4043
spec.add_config(config)
41-
channel = spec.event_channel(config)
44+
channel = cast("AsyncEventChannel", spec.event_channel(config))
4245

4346
assert channel._backend_name == "table_queue"
4447

@@ -84,7 +87,7 @@ async def test_asyncmy_event_channel_multiple_messages(mysql_service: MySQLServi
8487

8588
spec = SQLSpec()
8689
spec.add_config(config)
87-
channel = spec.event_channel(config)
90+
channel = cast("AsyncEventChannel", spec.event_channel(config))
8891

8992
event_ids = [
9093
await channel.publish("multi_test", {"index": 0}),
@@ -131,7 +134,7 @@ async def test_asyncmy_event_channel_nack_redelivery(mysql_service: MySQLService
131134

132135
spec = SQLSpec()
133136
spec.add_config(config)
134-
channel = spec.event_channel(config)
137+
channel = cast("AsyncEventChannel", spec.event_channel(config))
135138

136139
event_id = await channel.publish("nack_test", {"retry": True})
137140

tests/integration/test_adapters/test_asyncpg/test_extensions/test_events_listen_notify.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
"""PostgreSQL LISTEN/NOTIFY event channel tests for asyncpg adapter."""
33

44
import asyncio
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import pytest
88

99
from sqlspec import SQLSpec
1010
from sqlspec.adapters.asyncpg import AsyncpgConfig
1111
from sqlspec.migrations.commands import AsyncMigrationCommands
1212

13+
if TYPE_CHECKING:
14+
from sqlspec.extensions.events import AsyncEventChannel
15+
1316
pytestmark = pytest.mark.xdist_group("postgres")
1417

1518

@@ -29,7 +32,7 @@ async def test_asyncpg_listen_notify_publish_and_ack(postgres_service: "Any") ->
2932

3033
spec = SQLSpec()
3134
spec.add_config(config)
32-
channel = spec.event_channel(config)
35+
channel = cast("AsyncEventChannel", spec.event_channel(config))
3336

3437
assert channel._backend_name == "listen_notify"
3538

@@ -55,7 +58,7 @@ async def test_asyncpg_listen_notify_message_delivery(postgres_service: "Any") -
5558

5659
spec = SQLSpec()
5760
spec.add_config(config)
58-
channel = spec.event_channel(config)
61+
channel = cast("AsyncEventChannel", spec.event_channel(config))
5962

6063
received: list[Any] = []
6164

@@ -102,7 +105,7 @@ async def test_asyncpg_hybrid_listen_notify_durable(postgres_service: "Any", tmp
102105

103106
spec = SQLSpec()
104107
spec.add_config(config)
105-
channel = spec.event_channel(config)
108+
channel = cast("AsyncEventChannel", spec.event_channel(config))
106109

107110
assert channel._backend_name == "listen_notify_durable"
108111

@@ -144,7 +147,7 @@ async def test_asyncpg_listen_notify_metadata(postgres_service: "Any") -> None:
144147

145148
spec = SQLSpec()
146149
spec.add_config(config)
147-
channel = spec.event_channel(config)
150+
channel = cast("AsyncEventChannel", spec.event_channel(config))
148151

149152
received: list[Any] = []
150153

tests/integration/test_adapters/test_duckdb/test_extensions/test_events.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,16 @@
11
"""DuckDB integration tests for EventChannel queue fallback."""
22

3+
from typing import TYPE_CHECKING, cast
4+
35
import pytest
46

57
from sqlspec import SQLSpec
68
from sqlspec.adapters.duckdb import DuckDBConfig
79
from sqlspec.migrations.commands import SyncMigrationCommands
810

11+
if TYPE_CHECKING:
12+
from sqlspec.extensions.events import SyncEventChannel
13+
914

1015
@pytest.mark.integration
1116
@pytest.mark.duckdb
@@ -26,7 +31,7 @@ def test_duckdb_event_channel_queue_fallback(tmp_path) -> None:
2631

2732
spec = SQLSpec()
2833
spec.add_config(config)
29-
channel = spec.event_channel(config)
34+
channel = cast("SyncEventChannel", spec.event_channel(config))
3035

3136
event_id = channel.publish("notifications", {"action": "duck"})
3237
iterator = channel.iter_events("notifications", poll_interval=0.05)

tests/integration/test_adapters/test_oracledb/test_extensions/test_events.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import asyncio
55
import os
66
from textwrap import dedent
7-
from typing import TYPE_CHECKING
7+
from typing import TYPE_CHECKING, cast
88

99
import pytest
1010

@@ -15,6 +15,8 @@
1515
if TYPE_CHECKING:
1616
from pathlib import Path
1717

18+
from sqlspec.extensions.events import AsyncEventChannel, SyncEventChannel
19+
1820
pytestmark = pytest.mark.xdist_group("oracle")
1921

2022
ORACLE_HOST = os.environ.get("ORACLE_TEST_HOST", "127.0.0.1")
@@ -106,7 +108,7 @@ def test_oracle_sync_event_channel_queue_fallback(tmp_path: "Path") -> None:
106108

107109
spec = SQLSpec()
108110
spec.add_config(config)
109-
channel = spec.event_channel(config)
111+
channel = cast("SyncEventChannel", spec.event_channel(config))
110112

111113
event_id = channel.publish("notifications", {"action": "oracle"})
112114
iterator = channel.iter_events("notifications", poll_interval=0.5)
@@ -146,7 +148,7 @@ async def test_oracle_async_event_channel_queue_fallback(tmp_path: "Path") -> No
146148

147149
spec = SQLSpec()
148150
spec.add_config(config)
149-
channel = spec.event_channel(config)
151+
channel = cast("AsyncEventChannel", spec.event_channel(config))
150152

151153
event_id = await channel.publish("notifications", {"action": "oracle_async"})
152154
iterator = channel.iter_events("notifications", poll_interval=0.5)

tests/integration/test_adapters/test_psqlpy/test_extensions/test_events.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Psqlpy integration tests for the EventChannel queue backend."""
33

44
import asyncio
5+
from typing import TYPE_CHECKING, cast
56

67
import pytest
78
from pytest_databases.docker.postgres import PostgresService
@@ -10,6 +11,9 @@
1011
from sqlspec.adapters.psqlpy import PsqlpyConfig
1112
from sqlspec.migrations.commands import AsyncMigrationCommands
1213

14+
if TYPE_CHECKING:
15+
from sqlspec.extensions.events import AsyncEventChannel
16+
1317
pytestmark = pytest.mark.xdist_group("postgres")
1418

1519

@@ -36,7 +40,7 @@ async def test_psqlpy_event_channel_queue_fallback(tmp_path, postgres_service: P
3640

3741
spec = SQLSpec()
3842
spec.add_config(config)
39-
channel = spec.event_channel(config)
43+
channel = cast("AsyncEventChannel", spec.event_channel(config))
4044

4145
event_id = await channel.publish("notifications", {"action": "psqlpy"})
4246
iterator = channel.iter_events("notifications", poll_interval=0.1)

tests/integration/test_adapters/test_psqlpy/test_extensions/test_events_listen_notify.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,17 @@
22
"""PostgreSQL LISTEN/NOTIFY event channel tests for psqlpy adapter."""
33

44
import asyncio
5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any, cast
66

77
import pytest
88

99
from sqlspec import SQLSpec
1010
from sqlspec.adapters.psqlpy import PsqlpyConfig
1111
from sqlspec.migrations.commands import AsyncMigrationCommands
1212

13+
if TYPE_CHECKING:
14+
from sqlspec.extensions.events import AsyncEventChannel
15+
1316
pytestmark = pytest.mark.xdist_group("postgres")
1417

1518

@@ -28,7 +31,7 @@ async def test_psqlpy_listen_notify_native(postgres_service: "Any") -> None:
2831

2932
spec = SQLSpec()
3033
spec.add_config(config)
31-
channel = spec.event_channel(config)
34+
channel = cast("AsyncEventChannel", spec.event_channel(config))
3235

3336
received: list[Any] = []
3437

@@ -75,7 +78,7 @@ async def test_psqlpy_listen_notify_hybrid(postgres_service: "Any", tmp_path) ->
7578

7679
spec = SQLSpec()
7780
spec.add_config(config)
78-
channel = spec.event_channel(config)
81+
channel = cast("AsyncEventChannel", spec.event_channel(config))
7982

8083
received: list[Any] = []
8184

tests/integration/test_adapters/test_psycopg/test_extensions/test_events.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
"""Psycopg integration tests for the EventChannel queue backend."""
33

44
import asyncio
5+
from typing import TYPE_CHECKING, cast
56

67
import pytest
78
from pytest_databases.docker.postgres import PostgresService
@@ -10,6 +11,9 @@
1011
from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgSyncConfig
1112
from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands
1213

14+
if TYPE_CHECKING:
15+
from sqlspec.extensions.events import AsyncEventChannel, SyncEventChannel
16+
1317
pytestmark = pytest.mark.xdist_group("postgres")
1418

1519

@@ -34,7 +38,7 @@ def test_psycopg_sync_event_channel_queue_fallback(tmp_path, postgres_service: P
3438

3539
spec = SQLSpec()
3640
spec.add_config(config)
37-
channel = spec.event_channel(config)
41+
channel = cast("SyncEventChannel", spec.event_channel(config))
3842

3943
event_id = channel.publish("notifications", {"action": "queue"})
4044
iterator = channel.iter_events("notifications", poll_interval=0.1)
@@ -70,7 +74,7 @@ async def test_psycopg_async_event_channel_queue_fallback(tmp_path, postgres_ser
7074

7175
spec = SQLSpec()
7276
spec.add_config(config)
73-
channel = spec.event_channel(config)
77+
channel = cast("AsyncEventChannel", spec.event_channel(config))
7478

7579
event_id = await channel.publish("notifications", {"action": "async_queue"})
7680
iterator = channel.iter_events("notifications", poll_interval=0.1)

tests/integration/test_adapters/test_psycopg/test_extensions/test_events_listen_notify.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,17 @@
33

44
import asyncio
55
import time
6-
from typing import Any
6+
from typing import TYPE_CHECKING, Any, cast
77

88
import pytest
99

1010
from sqlspec import SQLSpec
1111
from sqlspec.adapters.psycopg import PsycopgAsyncConfig, PsycopgSyncConfig
1212
from sqlspec.migrations.commands import AsyncMigrationCommands, SyncMigrationCommands
1313

14+
if TYPE_CHECKING:
15+
from sqlspec.extensions.events import AsyncEventChannel, SyncEventChannel
16+
1417
pytestmark = pytest.mark.xdist_group("postgres")
1518

1619

@@ -28,7 +31,7 @@ def test_psycopg_sync_listen_notify(postgres_service: "Any") -> None:
2831

2932
spec = SQLSpec()
3033
spec.add_config(config)
31-
channel = spec.event_channel(config)
34+
channel = cast("SyncEventChannel", spec.event_channel(config))
3235
backend = channel._backend
3336
assert "_ensure_sync_listener" in dir(backend)
3437

@@ -59,7 +62,7 @@ async def test_psycopg_async_listen_notify(postgres_service: "Any") -> None:
5962

6063
spec = SQLSpec()
6164
spec.add_config(config)
62-
channel = spec.event_channel(config)
65+
channel = cast("AsyncEventChannel", spec.event_channel(config))
6366

6467
received: list[Any] = []
6568

@@ -98,7 +101,7 @@ def test_psycopg_sync_hybrid_listen_notify_durable(postgres_service: "Any", tmp_
98101

99102
spec = SQLSpec()
100103
spec.add_config(config)
101-
channel = spec.event_channel(config)
104+
channel = cast("SyncEventChannel", spec.event_channel(config))
102105

103106
received: list[Any] = []
104107
listener = channel.listen("alerts", lambda message: received.append(message), poll_interval=0.2)
@@ -134,7 +137,7 @@ async def test_psycopg_async_hybrid_listen_notify_durable(postgres_service: "Any
134137

135138
spec = SQLSpec()
136139
spec.add_config(config)
137-
channel = spec.event_channel(config)
140+
channel = cast("AsyncEventChannel", spec.event_channel(config))
138141

139142
received: list[Any] = []
140143

0 commit comments

Comments
 (0)