Skip to content

Commit ada1b11

Browse files
committed
Ruff ANN001: fixed missing annotations for function args
1 parent adff8cb commit ada1b11

File tree

12 files changed

+29
-23
lines changed

12 files changed

+29
-23
lines changed

duckdb/experimental/spark/exception.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ class ContributionsAcceptedError(NotImplementedError): # noqa: D100
44
feel free to open up a PR or a Discussion over on https://github.com/duckdb/duckdb.
55
""" # noqa: D205
66

7-
def __init__(self, message=None) -> None: # noqa: D107
7+
def __init__(self, message: str=None) -> None: # noqa: D107
88
doc = self.__class__.__doc__
99
if message:
1010
doc = message + "\n" + doc

duckdb/experimental/spark/sql/catalog.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, session: SparkSession) -> None: # noqa: D107
4040
def listDatabases(self) -> list[Database]: # noqa: D102
4141
res = self._session.conn.sql("select database_name from duckdb_databases()").fetchall()
4242

43-
def transform_to_database(x) -> Database:
43+
def transform_to_database(x: list[str]) -> Database:
4444
return Database(name=x[0], description=None, locationUri="")
4545

4646
databases = [transform_to_database(x) for x in res]
@@ -49,7 +49,7 @@ def transform_to_database(x) -> Database:
4949
def listTables(self) -> list[Table]: # noqa: D102
5050
res = self._session.conn.sql("select table_name, database_name, sql, temporary from duckdb_tables()").fetchall()
5151

52-
def transform_to_table(x) -> Table:
52+
def transform_to_table(x: list[str]) -> Table:
5353
return Table(name=x[0], database=x[1], description=x[2], tableType="", isTemporary=x[3])
5454

5555
tables = [transform_to_table(x) for x in res]
@@ -63,7 +63,7 @@ def listColumns(self, tableName: str, dbName: Optional[str] = None) -> list[Colu
6363
query += f" and database_name = '{dbName}'"
6464
res = self._session.conn.sql(query).fetchall()
6565

66-
def transform_to_column(x) -> Column:
66+
def transform_to_column(x: list[str|bool]) -> Column:
6767
return Column(name=x[0], description=None, dataType=x[1], nullable=x[2], isPartition=False, isBucket=False)
6868

6969
columns = [transform_to_column(x) for x in res]

duckdb/experimental/spark/sql/column.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
__all__ = ["Column"]
1313

1414

15-
def _get_expr(x) -> Expression:
15+
def _get_expr(x: 'Column' | str) -> Expression:
1616
return x.expr if isinstance(x, Column) else ConstantExpression(x)
1717

1818

duckdb/experimental/spark/sql/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -698,7 +698,7 @@ def join(
698698
on = str(on)
699699
assert isinstance(how, str), "how should be a string"
700700

701-
def map_to_recognized_jointype(how):
701+
def map_to_recognized_jointype(how: str):
702702
known_aliases = {
703703
"inner": [],
704704
"outer": ["full", "fullouter", "full_outer"],
@@ -1354,7 +1354,7 @@ def collect(self) -> list[Row]: # noqa: D102
13541354
columns = self.relation.columns
13551355
result = self.relation.fetchall()
13561356

1357-
def construct_row(values, names) -> Row:
1357+
def construct_row(values: list, names: list[str]) -> Row:
13581358
row = tuple.__new__(Row, list(values))
13591359
row.__fields__ = list(names)
13601360
return row

duckdb/experimental/spark/sql/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def when(condition: "Column", value: Any) -> Column: # noqa: D103
9999
return Column(expr)
100100

101101

102-
def _inner_expr_or_val(val):
102+
def _inner_expr_or_val(val: Column | str) -> Column | str:
103103
return val.expr if isinstance(val, Column) else val
104104

105105

duckdb/experimental/spark/sql/session.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid # noqa: D100
22
from collections.abc import Iterable
3-
from typing import TYPE_CHECKING, Any, Optional, Union
3+
from typing import TYPE_CHECKING, Any, Optional, Union, Sized
44

55
if TYPE_CHECKING:
66
from pandas.core.frame import DataFrame as PandasDataFrame
@@ -58,7 +58,7 @@ def _create_dataframe(self, data: Union[Iterable[Any], "PandasDataFrame"]) -> Da
5858
self.conn.register(unique_name, data)
5959
return DataFrame(self.conn.sql(f'select * from "{unique_name}"'), self)
6060

61-
def verify_tuple_integrity(tuples):
61+
def verify_tuple_integrity(tuples: list[tuple]):
6262
if len(tuples) <= 1:
6363
return
6464
expected_length = len(tuples[0])
@@ -80,8 +80,8 @@ def verify_tuple_integrity(tuples):
8080
data = list(data)
8181
verify_tuple_integrity(data)
8282

83-
def construct_query(tuples) -> str:
84-
def construct_values_list(row, start_param_idx):
83+
def construct_query(tuples: Iterable) -> str:
84+
def construct_values_list(row: Sized, start_param_idx: int):
8585
parameter_count = len(row)
8686
parameters = [f"${x + start_param_idx}" for x in range(parameter_count)]
8787
parameters = "(" + ", ".join(parameters) + ")"
@@ -98,7 +98,7 @@ def construct_values_list(row, start_param_idx):
9898

9999
query = construct_query(data)
100100

101-
def construct_parameters(tuples):
101+
def construct_parameters(tuples: Iterable):
102102
parameters = []
103103
for row in tuples:
104104
parameters.extend(list(row))
@@ -109,7 +109,7 @@ def construct_parameters(tuples):
109109
rel = self.conn.sql(query, params=parameters)
110110
return DataFrame(rel, self)
111111

112-
def _createDataFrameFromPandas(self, data: "PandasDataFrame", types, names) -> DataFrame:
112+
def _createDataFrameFromPandas(self, data: "PandasDataFrame", types: list[str] | None, names: list[str] | None) -> DataFrame:
113113
df = self._create_dataframe(data)
114114

115115
# Cast to types

duckdb/experimental/spark/sql/types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
class DataType:
6767
"""Base class for data types."""
6868

69-
def __init__(self, duckdb_type) -> None: # noqa: D107
69+
def __init__(self, duckdb_type: DuckDBPyType) -> None: # noqa: D107
7070
self.duckdb_type = duckdb_type
7171

7272
def __repr__(self) -> str: # noqa: D105

duckdb/filesystem.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from io import TextIOBase # noqa: D100
2+
from typing import IO
23

34
from fsspec import AbstractFileSystem
45
from fsspec.implementations.memory import MemoryFile, MemoryFileSystem
56

67
from .bytes_io_wrapper import BytesIOWrapper
78

89

9-
def is_file_like(obj): # noqa: D103
10+
def is_file_like(obj): # noqa: D103, ANN001
1011
# We only care that we can read from the file
1112
return hasattr(obj, "read") and hasattr(obj, "seek")
1213

@@ -16,7 +17,7 @@ class ModifiedMemoryFileSystem(MemoryFileSystem): # noqa: D101
1617
# defer to the original implementation that doesn't hardcode the protocol
1718
_strip_protocol = classmethod(AbstractFileSystem._strip_protocol.__func__)
1819

19-
def add_file(self, object, path): # noqa: D102
20+
def add_file(self, object: IO, path: str): # noqa: D102
2021
if not is_file_like(object):
2122
msg = "Can not read from a non file-like object"
2223
raise ValueError(msg)

duckdb/udf.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1-
def vectorized(func): # noqa: D100
1+
from typing import Callable
2+
3+
4+
def vectorized(func: Callable): # noqa: D100
25
"""Decorate a function with annotated function parameters, so DuckDB can infer that the function should be provided with pyarrow arrays and should expect pyarrow array(s) as output."""
36
import types
47
from inspect import signature

duckdb_packaging/_versioning.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def strip_post_from_version(version: str) -> str:
147147
return re.sub(r"[\.-]post[0-9]+", "", version)
148148

149149

150-
def get_git_describe(repo_path: Optional[pathlib.Path] = None, since_major=False, since_minor=False) -> Optional[str]:
150+
def get_git_describe(
151+
repo_path: Optional[pathlib.Path] = None, since_major: bool = False, since_minor: bool = False
152+
) -> Optional[str]:
151153
"""Get git describe output for version determination.
152154
153155
Returns:

0 commit comments

Comments
 (0)