Skip to content

Commit 9fd27e4

Browse files
committed
unnecessarily strict return types
1 parent ef3dc7d commit 9fd27e4

File tree

2 files changed

+78
-10
lines changed

2 files changed

+78
-10
lines changed

stdlib/@tests/test_cases/sqlite3/check_aggregations.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,48 @@ def finalize(self) -> int:
7575
con.create_aggregate("sumint", 1, WindowSumIntMultiArgs)
7676
con.create_aggregate("sumint", 2, WindowSumIntMultiArgs)
7777

78+
# n_arg=-1 requires *args to handle any number of arguments
79+
if sys.version_info >= (3, 11):
80+
con.create_window_function("sumint_varargs", -1, WindowSumIntMultiArgs)
81+
82+
con.create_aggregate("sumint_varargs", -1, WindowSumIntMultiArgs)
83+
84+
85+
# n_arg=-1 should reject fixed-arity methods
86+
class FixedArityAggregate:
87+
def __init__(self) -> None:
88+
self.total = 0
89+
90+
def step(self, a: int, b: int) -> None:
91+
self.total += a + b
92+
93+
def finalize(self) -> int:
94+
return self.total
95+
96+
97+
con.create_aggregate("bad_varargs", -1, FixedArityAggregate) # type: ignore[arg-type]
98+
99+
100+
class FixedArityWindowAggregate:
101+
def __init__(self) -> None:
102+
self.total = 0
103+
104+
def step(self, a: int, b: int) -> None:
105+
self.total += a + b
106+
107+
def inverse(self, a: int, b: int) -> None:
108+
self.total -= a + b
109+
110+
def value(self) -> int:
111+
return self.total
112+
113+
def finalize(self) -> int:
114+
return self.total
115+
116+
117+
if sys.version_info >= (3, 11):
118+
con.create_window_function("bad_varargs", -1, FixedArityWindowAggregate) # type: ignore[arg-type]
119+
78120

79121
# Test case: Fixed parameter aggregates (the common case in practice)
80122
class FixedTwoParamAggregate:

stdlib/sqlite3/__init__.pyi

Lines changed: 36 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,8 @@ if sys.version_info < (3, 10):
217217
_CursorT = TypeVar("_CursorT", bound=Cursor)
218218
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
219219
_SQLType = TypeVar("_SQLType", bound=_SqliteData)
220+
_SQLType_contra = TypeVar("_SQLType_contra", bound=_SqliteData, contravariant=True)
221+
220222
# Data that is passed through adapters can be of any type accepted by an adapter.
221223
_AdaptedInputData: TypeAlias = _SqliteData | Any
222224
# The Mapping must really be a dict, but making it invariant is too annoying.
@@ -226,9 +228,9 @@ _IsolationLevel: TypeAlias = Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | Non
226228
_RowFactoryOptions: TypeAlias = type[Row] | Callable[[Cursor, Row], object] | None
227229

228230
@type_check_only
229-
class _SingleParamAggregateProtocol(Protocol[_SQLType]):
230-
def step(self, param: _SQLType, /) -> object: ...
231-
def finalize(self) -> _SQLType: ...
231+
class _SingleParamAggregateProtocol(Protocol):
232+
def step(self, param: _SqliteData, /) -> object: ...
233+
def finalize(self) -> _SqliteData: ...
232234

233235
@type_check_only
234236
class _AnyParamAggregateProtocol(Protocol):
@@ -237,11 +239,16 @@ class _AnyParamAggregateProtocol(Protocol):
237239
def finalize(self) -> _SqliteData: ...
238240

239241
@type_check_only
240-
class _SingleParamWindowAggregateClass(Protocol[_SQLType]):
241-
def step(self, param: _SQLType, /) -> object: ...
242-
def inverse(self, param: _SQLType, /) -> object: ...
243-
def value(self) -> _SQLType: ...
244-
def finalize(self) -> _SQLType: ...
242+
class _AnyArgsAggregateProtocol(Protocol):
243+
def step(self, *args: _SqliteData) -> object: ...
244+
def finalize(self) -> _SqliteData: ...
245+
246+
@type_check_only
247+
class _SingleParamWindowAggregateClass(Protocol[_SQLType_contra]):
248+
def step(self, param: _SQLType_contra, /) -> object: ...
249+
def inverse(self, param: _SQLType_contra, /) -> object: ...
250+
def value(self) -> _SqliteData: ...
251+
def finalize(self) -> _SqliteData: ...
245252

246253
@type_check_only
247254
class _AnyParamWindowAggregateClass(Protocol):
@@ -252,6 +259,13 @@ class _AnyParamWindowAggregateClass(Protocol):
252259
def value(self) -> _SqliteData: ...
253260
def finalize(self) -> _SqliteData: ...
254261

262+
@type_check_only
263+
class _AnyArgsWindowAggregateClass(Protocol):
264+
def step(self, *args: _SqliteData) -> object: ...
265+
def inverse(self, *args: _SqliteData) -> object: ...
266+
def value(self) -> _SqliteData: ...
267+
def finalize(self) -> _SqliteData: ...
268+
255269
# These classes are implemented in the C module _sqlite3. At runtime, they're imported
256270
# from there into sqlite3.dbapi2 and from that module to here. However, they
257271
# consider themselves to live in the sqlite3.* namespace, so we'll define them here.
@@ -338,7 +352,11 @@ class Connection:
338352
def commit(self) -> None: ...
339353
@overload
340354
def create_aggregate(
341-
self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]]
355+
self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol]
356+
) -> None: ...
357+
@overload
358+
def create_aggregate(
359+
self, name: str, n_arg: Literal[-1], aggregate_class: Callable[[], _AnyArgsAggregateProtocol]
342360
) -> None: ...
343361
@overload
344362
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ...
@@ -351,7 +369,15 @@ class Connection:
351369
self,
352370
name: str,
353371
num_params: Literal[1],
354-
aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None,
372+
aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType_contra]] | None,
373+
/,
374+
) -> None: ...
375+
@overload
376+
def create_window_function(
377+
self,
378+
name: str,
379+
num_params: Literal[-1],
380+
aggregate_class: Callable[[], _AnyArgsWindowAggregateClass] | None,
355381
/,
356382
) -> None: ...
357383
@overload

0 commit comments

Comments
 (0)