Skip to content

Commit ef3dc7d

Browse files
committed
Improve sqlite aggregration protocols
1 parent bbddfee commit ef3dc7d

File tree

3 files changed

+172
-23
lines changed

3 files changed

+172
-23
lines changed
Lines changed: 142 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
import sqlite3
2+
import sys
3+
4+
5+
class WindowSumInt:
6+
def __init__(self) -> None:
7+
self.count = 0
8+
9+
def step(self, param: int) -> None:
10+
self.count += param
11+
12+
def value(self) -> int:
13+
return self.count
14+
15+
def inverse(self, param: int) -> None:
16+
self.count -= param
17+
18+
def finalize(self) -> int:
19+
return self.count
20+
21+
22+
con = sqlite3.connect(":memory:")
23+
cur = con.execute("CREATE TABLE test(x, y)")
24+
values = [("a", 4), ("b", 5), ("c", 3), ("d", 8), ("e", 1)]
25+
cur.executemany("INSERT INTO test VALUES(?, ?)", values)
26+
27+
if sys.version_info >= (3, 11):
28+
con.create_window_function("sumint", 1, WindowSumInt)
29+
30+
con.create_aggregate("sumint", 1, WindowSumInt)
31+
cur.execute(
32+
"""
33+
SELECT x, sumint(y) OVER (
34+
ORDER BY x ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING
35+
) AS sum_y
36+
FROM test ORDER BY x
37+
"""
38+
)
39+
con.close()
40+
41+
42+
def _create_window_function() -> WindowSumInt:
43+
return WindowSumInt()
44+
45+
46+
# A callable should work as well.
47+
if sys.version_info >= (3, 11):
48+
con.create_window_function("sumint", 1, _create_window_function)
49+
con.create_aggregate("sumint", 1, _create_window_function)
50+
51+
# With num_args set to 1, the callable should not be called with more than one.
52+
53+
54+
class WindowSumIntMultiArgs:
55+
def __init__(self) -> None:
56+
self.count = 0
57+
58+
def step(self, *args: int) -> None:
59+
self.count += sum(args)
60+
61+
def value(self) -> int:
62+
return self.count
63+
64+
def inverse(self, *args: int) -> None:
65+
self.count -= sum(args)
66+
67+
def finalize(self) -> int:
68+
return self.count
69+
70+
71+
if sys.version_info >= (3, 11):
72+
con.create_window_function("sumint", 1, WindowSumIntMultiArgs)
73+
con.create_window_function("sumint", 2, WindowSumIntMultiArgs)
74+
75+
con.create_aggregate("sumint", 1, WindowSumIntMultiArgs)
76+
con.create_aggregate("sumint", 2, WindowSumIntMultiArgs)
77+
78+
79+
# Test case: Fixed parameter aggregates (the common case in practice)
80+
class FixedTwoParamAggregate:
81+
def __init__(self) -> None:
82+
self.total = 0
83+
84+
def step(self, a: int, b: int) -> None:
85+
self.total += a + b
86+
87+
def finalize(self) -> int:
88+
return self.total
89+
90+
91+
con.create_aggregate("sum2", 2, FixedTwoParamAggregate)
92+
93+
94+
class FixedThreeParamWindowAggregate:
95+
def __init__(self) -> None:
96+
self.total = 0
97+
98+
def step(self, a: int, b: int, c: int) -> None:
99+
self.total += a + b + c
100+
101+
def inverse(self, a: int, b: int, c: int) -> None:
102+
self.total -= a + b + c
103+
104+
def value(self) -> int:
105+
return self.total
106+
107+
def finalize(self) -> int:
108+
return self.total
109+
110+
111+
if sys.version_info >= (3, 11):
112+
con.create_window_function("sum3", 3, FixedThreeParamWindowAggregate)
113+
114+
115+
# What do protocols still catch?
116+
117+
118+
# Missing required method
119+
class MissingStep:
120+
def __init__(self) -> None:
121+
self.total = 0
122+
123+
def finalize(self) -> int:
124+
return self.total
125+
126+
127+
con.create_aggregate("bad", 2, MissingStep) # type: ignore[arg-type] # missing step method
128+
129+
130+
# Invalid return type from finalize (not a valid SQLite type)
131+
class BadFinalizeReturn:
132+
def __init__(self) -> None:
133+
self.items: list[int] = []
134+
135+
def step(self, x: int) -> None:
136+
self.items.append(x)
137+
138+
def finalize(self) -> list[int]: # list is not a valid SQLite type
139+
return self.items
140+
141+
142+
con.create_aggregate("bad2", 1, BadFinalizeReturn) # type: ignore[arg-type] # bad return type
File renamed without changes.

stdlib/sqlite3/__init__.pyi

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,7 @@ if sys.version_info < (3, 10):
216216

217217
_CursorT = TypeVar("_CursorT", bound=Cursor)
218218
_SqliteData: TypeAlias = str | ReadableBuffer | int | float | None
219+
_SQLType = TypeVar("_SQLType", bound=_SqliteData)
219220
# Data that is passed through adapters can be of any type accepted by an adapter.
220221
_AdaptedInputData: TypeAlias = _SqliteData | Any
221222
# The Mapping must really be a dict, but making it invariant is too annoying.
@@ -225,28 +226,29 @@ _IsolationLevel: TypeAlias = Literal["DEFERRED", "EXCLUSIVE", "IMMEDIATE"] | Non
225226
_RowFactoryOptions: TypeAlias = type[Row] | Callable[[Cursor, Row], object] | None
226227

227228
@type_check_only
228-
class _AnyParamWindowAggregateClass(Protocol):
229-
def step(self, *args: Any) -> object: ...
230-
def inverse(self, *args: Any) -> object: ...
231-
def value(self) -> _SqliteData: ...
232-
def finalize(self) -> _SqliteData: ...
229+
class _SingleParamAggregateProtocol(Protocol[_SQLType]):
230+
def step(self, param: _SQLType, /) -> object: ...
231+
def finalize(self) -> _SQLType: ...
233232

234233
@type_check_only
235-
class _WindowAggregateClass(Protocol):
236-
step: Callable[..., object]
237-
inverse: Callable[..., object]
238-
def value(self) -> _SqliteData: ...
234+
class _AnyParamAggregateProtocol(Protocol):
235+
@property
236+
def step(self) -> Callable[..., object]: ...
239237
def finalize(self) -> _SqliteData: ...
240238

241239
@type_check_only
242-
class _AggregateProtocol(Protocol):
243-
def step(self, value: int, /) -> object: ...
244-
def finalize(self) -> int: ...
240+
class _SingleParamWindowAggregateClass(Protocol[_SQLType]):
241+
def step(self, param: _SQLType, /) -> object: ...
242+
def inverse(self, param: _SQLType, /) -> object: ...
243+
def value(self) -> _SQLType: ...
244+
def finalize(self) -> _SQLType: ...
245245

246246
@type_check_only
247-
class _SingleParamWindowAggregateClass(Protocol):
248-
def step(self, param: Any, /) -> object: ...
249-
def inverse(self, param: Any, /) -> object: ...
247+
class _AnyParamWindowAggregateClass(Protocol):
248+
@property
249+
def step(self) -> Callable[..., object]: ...
250+
@property
251+
def inverse(self) -> Callable[..., object]: ...
250252
def value(self) -> _SqliteData: ...
251253
def finalize(self) -> _SqliteData: ...
252254

@@ -334,22 +336,27 @@ class Connection:
334336
def blobopen(self, table: str, column: str, row: int, /, *, readonly: bool = False, name: str = "main") -> Blob: ...
335337

336338
def commit(self) -> None: ...
337-
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AggregateProtocol]) -> None: ...
339+
@overload
340+
def create_aggregate(
341+
self, name: str, n_arg: Literal[1], aggregate_class: Callable[[], _SingleParamAggregateProtocol[_SQLType]]
342+
) -> None: ...
343+
@overload
344+
def create_aggregate(self, name: str, n_arg: int, aggregate_class: Callable[[], _AnyParamAggregateProtocol]) -> None: ...
345+
338346
if sys.version_info >= (3, 11):
339347
# num_params determines how many params will be passed to the aggregate class. We provide an overload
340348
# for the case where num_params = 1, which is expected to be the common case.
341349
@overload
342350
def create_window_function(
343-
self, name: str, num_params: Literal[1], aggregate_class: Callable[[], _SingleParamWindowAggregateClass] | None, /
344-
) -> None: ...
345-
# And for num_params = -1, which means the aggregate must accept any number of parameters.
346-
@overload
347-
def create_window_function(
348-
self, name: str, num_params: Literal[-1], aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, /
351+
self,
352+
name: str,
353+
num_params: Literal[1],
354+
aggregate_class: Callable[[], _SingleParamWindowAggregateClass[_SQLType]] | None,
355+
/,
349356
) -> None: ...
350357
@overload
351358
def create_window_function(
352-
self, name: str, num_params: int, aggregate_class: Callable[[], _WindowAggregateClass] | None, /
359+
self, name: str, num_params: int, aggregate_class: Callable[[], _AnyParamWindowAggregateClass] | None, /
353360
) -> None: ...
354361

355362
def create_collation(self, name: str, callback: Callable[[str, str], int | SupportsIndex] | None, /) -> None: ...

0 commit comments

Comments
 (0)