Skip to content

Commit 956cec6

Browse files
authored
chore: add correct TypeIs annotation for is_coroutine_fn (#1306)
1 parent f9ec466 commit 956cec6

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ authors = [{ name = "CocoIndex", email = "[email protected]" }]
1010
readme = "README.md"
1111
requires-python = ">=3.11"
1212
dependencies = [
13-
"typing-extensions>=4.12; python_version < '3.13'",
13+
"typing-extensions>=4.12",
1414
"click>=8.1.8",
1515
"rich>=14.0.0",
1616
"python-dotenv>=1.1.0",

python/cocoindex/query_handler.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,7 @@
44
from typing import Generic, Any
55
from .index import VectorSimilarityMetric
66
import sys
7-
8-
if sys.version_info >= (3, 13):
9-
from typing import TypeVar
10-
else:
11-
from typing_extensions import TypeVar # PEP 696 backport
7+
from typing_extensions import TypeVar
128

139

1410
@dataclasses.dataclass

python/cocoindex/runtime.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
import inspect
99
import warnings
1010

11-
from typing import Any, Callable, Awaitable, TypeVar, Coroutine
11+
from typing import Any, Callable, Awaitable, TypeVar, Coroutine, ParamSpec
12+
from typing_extensions import TypeIs
1213

1314
T = TypeVar("T")
15+
P = ParamSpec("P")
1416

1517

1618
class _ExecutionContext:
@@ -68,14 +70,16 @@ def run(self, coro: Coroutine[Any, Any, T]) -> T:
6870
execution_context = _ExecutionContext()
6971

7072

71-
def is_coroutine_fn(fn: Callable[..., Any]) -> bool:
73+
def is_coroutine_fn(
74+
fn: Callable[P, T] | Callable[P, Coroutine[Any, Any, T]],
75+
) -> TypeIs[Callable[P, Coroutine[Any, Any, T]]]:
7276
if isinstance(fn, (staticmethod, classmethod)):
7377
return inspect.iscoroutinefunction(fn.__func__)
7478
else:
7579
return inspect.iscoroutinefunction(fn)
7680

7781

78-
def to_async_call(fn: Callable[..., Any]) -> Callable[..., Awaitable[Any]]:
82+
def to_async_call(fn: Callable[P, T]) -> Callable[P, Awaitable[T]]:
7983
if is_coroutine_fn(fn):
8084
return fn
8185
return lambda *args, **kwargs: asyncio.to_thread(fn, *args, **kwargs)

0 commit comments

Comments
 (0)