Commit 46267ec

mypy is happy

1 parent 3890939 · commit 46267ec

File tree

8 files changed: +106 −65 lines changed

File renamed without changes.
File renamed without changes.
File renamed without changes.

duckdb/_duckdb/__init__.pyi

Lines changed: 0 additions & 2 deletions
This file was deleted.

duckdb/filesystem.py

Lines changed: 20 additions & 15 deletions
@@ -1,28 +1,33 @@
-from io import TextIOBase  # noqa: D100
-from typing import IO
+"""In-memory filesystem to store ephemeral dependencies.
+
+Warning: Not for external use. May change at any moment. Likely to be made internal.
+"""
+
+from __future__ import annotations
+
+import io
+import typing

 from fsspec import AbstractFileSystem
 from fsspec.implementations.memory import MemoryFile, MemoryFileSystem

 from .bytes_io_wrapper import BytesIOWrapper


-def is_file_like(obj) -> bool:  # noqa: D103, ANN001
-    # We only care that we can read from the file
-    return hasattr(obj, "read") and hasattr(obj, "seek")
+class ModifiedMemoryFileSystem(MemoryFileSystem):
+    """In-memory filesystem implementation that uses its own protocol."""

-
-class ModifiedMemoryFileSystem(MemoryFileSystem):  # noqa: D101
     protocol = ("DUCKDB_INTERNAL_OBJECTSTORE",)
     # defer to the original implementation that doesn't hardcode the protocol
-    _strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__)
+    _strip_protocol: typing.Callable[[str], str] = classmethod(AbstractFileSystem._strip_protocol.__func__)  # type: ignore[assignment]

-    def add_file(self, object: IO, path: str) -> None:  # noqa: D102
-        if not is_file_like(object):
+    def add_file(self, obj: io.IOBase | BytesIOWrapper, path: str) -> None:
+        """Add a file to the filesystem."""
+        if not isinstance(obj, io.IOBase):
             msg = "Can not read from a non file-like object"
-            raise ValueError(msg)
-        path = self._strip_protocol(path)
-        if isinstance(object, TextIOBase):
+            raise TypeError(msg)
+        if isinstance(obj, io.TextIOBase):
             # Wrap this so that we can return a bytes object from 'read'
-            object = BytesIOWrapper(object)
-        self.store[path] = MemoryFile(self, path, object.read())
+            obj = BytesIOWrapper(obj)
+        path = self._strip_protocol(path)
+        self.store[path] = MemoryFile(self, path, obj.read())
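
To illustrate the retyped add_file contract, a minimal sketch (hypothetical usage only — the module's new docstring warns it is not for external use; it assumes the package layout in this commit). Note two behavior-adjacent details in the diff: the error type changed from ValueError to TypeError, and _strip_protocol now runs after the text-wrapping branch.

    import io

    from duckdb.filesystem import ModifiedMemoryFileSystem

    fs = ModifiedMemoryFileSystem()

    # Binary handles pass the io.IOBase check directly; text handles are
    # wrapped in BytesIOWrapper so that 'read' returns bytes.
    fs.add_file(io.BytesIO(b"a,b\n1,2\n"), "t1.csv")
    fs.add_file(io.StringIO("a,b\n3,4\n"), "t2.csv")

    # Non file-like inputs now raise TypeError (previously ValueError).
    try:
        fs.add_file("not a file", "t3.csv")  # type: ignore[arg-type]
    except TypeError as exc:
        print(exc)  # Can not read from a non file-like object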

duckdb/polars_io.py

Lines changed: 75 additions & 42 deletions
@@ -1,17 +1,24 @@
-import datetime  # noqa: D100
+from __future__ import annotations  # noqa: D100
+
+import datetime
 import json
-from collections.abc import Iterator
+import typing
 from decimal import Decimal
-from typing import Optional

 import polars as pl
 from polars.io.plugins import register_io_source

 import duckdb
-from duckdb import SQLExpression

+if typing.TYPE_CHECKING:
+    from collections.abc import Iterator
+
+    import typing_extensions
+
+    _ExpressionTree: typing_extensions.TypeAlias = typing.Dict[str, typing.Union[str, int, "_ExpressionTree", typing.Any]]  # noqa: UP006

-def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]:
+
+def _predicate_to_expression(predicate: pl.Expr) -> duckdb.Expression | None:
     """Convert a Polars predicate expression to a DuckDB-compatible SQL expression.

     Parameters:
@@ -31,7 +38,7 @@ def _predicate_to_expression(predicate: pl.Expr) -> Optional[SQLExpression]:
     try:
         # Convert the tree to SQL
         sql_filter = _pl_tree_to_sql(tree)
-        return SQLExpression(sql_filter)
+        return duckdb.SQLExpression(sql_filter)
     except Exception:
         # If the conversion fails, we return None
         return None
@@ -70,7 +77,7 @@ def _escape_sql_identifier(identifier: str) -> str:
     return f'"{escaped}"'


-def _pl_tree_to_sql(tree: dict) -> str:
+def _pl_tree_to_sql(tree: _ExpressionTree) -> str:
     """Recursively convert a Polars expression tree (as JSON) to a SQL string.

     Parameters:
@@ -91,38 +98,51 @@ def _pl_tree_to_sql(tree: dict) -> str:
     Output: "(foo > 5)"
     """
     [node_type] = tree.keys()
-    subtree = tree[node_type]

     if node_type == "BinaryExpr":
         # Binary expressions: left OP right
-        return (
-            "("
-            + " ".join(
-                (
-                    _pl_tree_to_sql(subtree["left"]),
-                    _pl_operation_to_sql(subtree["op"]),
-                    _pl_tree_to_sql(subtree["right"]),
-                )
-            )
-            + ")"
-        )
+        bin_expr_tree = tree[node_type]
+        assert isinstance(bin_expr_tree, dict), f"A {node_type} should be a dict but got {type(bin_expr_tree)}"
+        lhs, op, rhs = bin_expr_tree["left"], bin_expr_tree["op"], bin_expr_tree["right"]
+        assert isinstance(lhs, dict), f"LHS of a {node_type} should be a dict but got {type(lhs)}"
+        assert isinstance(op, str), f"The op of a {node_type} should be a str but got {type(op)}"
+        assert isinstance(rhs, dict), f"RHS of a {node_type} should be a dict but got {type(rhs)}"
+        return f"({_pl_tree_to_sql(lhs)} {_pl_operation_to_sql(op)} {_pl_tree_to_sql(rhs)})"
     if node_type == "Column":
         # A reference to a column name
         # Wrap in quotes to handle special characters
-        return _escape_sql_identifier(subtree)
+        col_name = tree[node_type]
+        assert isinstance(col_name, str), f"The col name of a {node_type} should be a str but got {type(col_name)}"
+        return _escape_sql_identifier(col_name)

     if node_type in ("Literal", "Dyn"):
         # Recursively process dynamic or literal values
-        return _pl_tree_to_sql(subtree)
+        val_tree = tree[node_type]
+        assert isinstance(val_tree, dict), f"A {node_type} should be a dict but got {type(val_tree)}"
+        return _pl_tree_to_sql(val_tree)

     if node_type == "Int":
         # Direct integer literals
-        return str(subtree)
+        int_literal = tree[node_type]
+        assert isinstance(int_literal, (int, str)), (
+            f"The value of an Int should be an int or str but got {type(int_literal)}"
+        )
+        return str(int_literal)

     if node_type == "Function":
         # Handle boolean functions like IsNull, IsNotNull
-        inputs = subtree["input"]
-        func_dict = subtree["function"]
+        func_tree = tree[node_type]
+        assert isinstance(func_tree, dict), f"A {node_type} should be a dict but got {type(func_tree)}"
+        inputs = func_tree["input"]
+        assert isinstance(inputs, list), f"A {node_type} should have a list of dicts as input but got {type(inputs)}"
+        input_tree = inputs[0]
+        assert isinstance(input_tree, dict), (
+            f"A {node_type} should have a list of dicts as input but got {type(input_tree)}"
+        )
+        func_dict = func_tree["function"]
+        assert isinstance(func_dict, dict), (
+            f"A {node_type} should have a function dict as input but got {type(func_dict)}"
+        )

         if "Boolean" in func_dict:
             func = func_dict["Boolean"]
@@ -140,24 +160,31 @@ def _pl_tree_to_sql(tree: dict) -> str:

     if node_type == "Scalar":
         # Detect format: old style (dtype/value) or new style (direct type key)
-        if "dtype" in subtree and "value" in subtree:
-            dtype = str(subtree["dtype"])
-            value = subtree["value"]
+        scalar_tree = tree[node_type]
+        assert isinstance(scalar_tree, dict), f"A {node_type} should be a dict but got {type(scalar_tree)}"
+        if "dtype" in scalar_tree and "value" in scalar_tree:
+            dtype = str(scalar_tree["dtype"])
+            value = scalar_tree["value"]
         else:
             # New style: dtype is the single key in the dict
-            dtype = next(iter(subtree.keys()))
-            value = subtree
+            dtype = next(iter(scalar_tree.keys()))
+            value = scalar_tree
+        assert isinstance(dtype, str), f"A {node_type} should have a str dtype but got {type(dtype)}"
+        assert isinstance(value, dict), f"A {node_type} should have a dict value but got {type(value)}"

         # Decimal support
         if dtype.startswith("{'Decimal'") or dtype == "Decimal":
             decimal_value = value["Decimal"]
-            decimal_value = Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1])
-            return str(decimal_value)
+            assert isinstance(decimal_value, list), (
+                f"A {dtype} should be a two member list but got {type(decimal_value)}"
+            )
+            return str(Decimal(decimal_value[0]) / Decimal(10 ** decimal_value[1]))

         # Datetime with microseconds since epoch
         if dtype.startswith("{'Datetime'") or dtype == "Datetime":
-            micros = value["Datetime"][0]
-            dt_timestamp = datetime.datetime.fromtimestamp(micros / 1_000_000, tz=datetime.UTC)
+            micros = value["Datetime"]
+            assert isinstance(micros, list), f"A {dtype} should be a one member list but got {type(micros)}"
+            dt_timestamp = datetime.datetime.fromtimestamp(micros[0] / 1_000_000, tz=datetime.timezone.utc)
             return f"'{dt_timestamp!s}'::TIMESTAMP"

         # Match simple numeric/boolean types
@@ -179,6 +206,7 @@ def _pl_tree_to_sql(tree: dict) -> str:
         # Time type
         if dtype == "Time":
             nanoseconds = value["Time"]
+            assert isinstance(nanoseconds, int), f"A {dtype} should be an int but got {type(nanoseconds)}"
             seconds = nanoseconds // 1_000_000_000
             microseconds = (nanoseconds % 1_000_000_000) // 1_000
             dt_time = (datetime.datetime.min + datetime.timedelta(seconds=seconds, microseconds=microseconds)).time()
@@ -187,36 +215,41 @@ def _pl_tree_to_sql(tree: dict) -> str:
         # Date type
         if dtype == "Date":
             days_since_epoch = value["Date"]
+            assert isinstance(days_since_epoch, (float, int)), (
+                f"A {dtype} should be a number but got {type(days_since_epoch)}"
+            )
             date = datetime.date(1970, 1, 1) + datetime.timedelta(days=days_since_epoch)
             return f"'{date}'::DATE"

         # Binary type
         if dtype == "Binary":
-            binary_data = bytes(value["Binary"])
+            bin_value = value["Binary"]
+            assert isinstance(bin_value, bytes), f"A {dtype} should be bytes but got {type(bin_value)}"
+            binary_data = bytes(bin_value)
             escaped = "".join(f"\\x{b:02x}" for b in binary_data)
             return f"'{escaped}'::BLOB"

         # String type
         if dtype == "String" or dtype == "StringOwned":
             # Some new formats may store directly under StringOwned
-            string_val = value.get("StringOwned", value.get("String", None))
+            string_val: object | None = value.get("StringOwned", value.get("String", None))
             return f"'{string_val}'"

         msg = f"Unsupported scalar type {dtype!s}, with value {value}"
         raise NotImplementedError(msg)

-    msg = f"Node type: {node_type} is not implemented. {subtree}"
+    msg = f"Node type: {node_type} is not implemented. {tree[node_type]}"
    raise NotImplementedError(msg)


 def duckdb_source(relation: duckdb.DuckDBPyRelation, schema: pl.schema.Schema) -> pl.LazyFrame:
     """A polars IO plugin for DuckDB."""

     def source_generator(
-        with_columns: Optional[list[str]],
-        predicate: Optional[pl.Expr],
-        n_rows: Optional[int],
-        batch_size: Optional[int],
+        with_columns: list[str] | None,
+        predicate: pl.Expr | None,
+        n_rows: int | None,
+        batch_size: int | None,
     ) -> Iterator[pl.DataFrame]:
         duck_predicate = None
         relation_final = relation
@@ -239,8 +272,8 @@ def source_generator(
         for record_batch in iter(results.read_next_batch, None):
             if predicate is not None and duck_predicate is None:
                 # We have a predicate, but did not manage to push it down, we fallback here
-                yield pl.from_arrow(record_batch).filter(predicate)
+                yield pl.from_arrow(record_batch).filter(predicate)  # type: ignore[arg-type,misc]
             else:
-                yield pl.from_arrow(record_batch)
+                yield pl.from_arrow(record_batch)  # type: ignore[misc]

     return register_io_source(source_generator, schema=schema)
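
The annotations leave the runtime flow unchanged: polars serializes a predicate to a JSON expression tree (now typed by the _ExpressionTree alias), _pl_tree_to_sql walks it, and a successful conversion is pushed down as a SQLExpression, while failures fall back to filtering the arrow batches in Python. A rough round-trip sketch (assumes a recent polars with pl.Schema; the column name foo is illustrative):

    import duckdb
    import polars as pl

    from duckdb.polars_io import duckdb_source

    con = duckdb.connect()
    rel = con.sql("SELECT range AS foo FROM range(10)")

    # pl.col("foo") > 5 serializes to a BinaryExpr tree over Column and Int
    # nodes; when _predicate_to_expression succeeds, DuckDB applies the
    # filter before any batch reaches polars.
    lf = duckdb_source(rel, pl.Schema({"foo": pl.Int64}))
    print(lf.filter(pl.col("foo") > 5).collect())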

duckdb/udf.py

Lines changed: 2 additions & 2 deletions
@@ -1,8 +1,8 @@
 # ruff: noqa: D100
-from typing import Callable
+import typing


-def vectorized(func: Callable) -> Callable:
+def vectorized(func: typing.Callable[..., typing.Any]) -> typing.Callable[..., typing.Any]:
     """Decorate a function with annotated function parameters.

     This allows DuckDB to infer that the function should be provided with pyarrow arrays and should expect
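
This change is type-only: a bare Callable is implicitly Callable[..., Any], which strict mypy settings reject, so the explicit form spells it out. For orientation, a hedged sketch of what a vectorized UDF looks like (assumes pyarrow; the registration path, e.g. create_function with type="arrow", is not shown here and may differ):

    import pyarrow as pa
    import pyarrow.compute as pc

    from duckdb.udf import vectorized


    @vectorized
    def plus_one(x: pa.ChunkedArray) -> pa.ChunkedArray:
        # A vectorized UDF receives whole pyarrow arrays rather than being
        # invoked once per scalar row.
        return pc.add(x, 1)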

pyproject.toml

Lines changed: 9 additions & 4 deletions
@@ -305,19 +305,24 @@ pretty = true
 python_version = "3.9"
 exclude = [
     "duckdb/experimental/",
-    "duckdb/functional/",
     "duckdb/query_graph/",
-    "duckdb/typing/",
     "duckdb/value/",
 ]

 [[tool.mypy.overrides]]
 module = [
-    # "some_untyped_dependency.*",
-    # "another_untyped_lib"
+    "fsspec.*",
+    "pandas",
+    "polars",
+    "pyarrow",
+    "torch",
 ]
 ignore_missing_imports = true

+[[tool.mypy.overrides]]
+module = "duckdb.filesystem"
+disallow_subclassing_any = false
+
 [tool.pytest.ini_options]
 minversion = "6.0"
 addopts = "-ra -q"
