Skip to content

Commit 948070a

Browse files
committed
Ruff ANN401: specialized Any type hints where possible
1 parent 83304f6 commit 948070a

File tree

10 files changed

+115
-102
lines changed

10 files changed

+115
-102
lines changed

duckdb/bytes_io_wrapper.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def __init__(self, buffer: Union[StringIO, TextIOBase], encoding: str = "utf-8")
4848
# overflow to the front of the bytestring the next time reading is performed
4949
self.overflow = b""
5050

51-
def __getattr__(self, attr: str) -> Any: # noqa: D105
51+
def __getattr__(self, attr: str) -> Any: # noqa: D105, ANN401
5252
return getattr(self.buffer, attr)
5353

5454
def read(self, n: Union[int, None] = -1) -> bytes: # noqa: D102

duckdb/experimental/spark/sql/column.py

Lines changed: 7 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
1-
from typing import TYPE_CHECKING, Any, Callable, Union, cast # noqa: D100
1+
from collections.abc import Iterable # noqa: D100
2+
from typing import TYPE_CHECKING, Any, Callable, Union, cast
23

34
from ..exception import ContributionsAcceptedError
45
from .types import DataType
@@ -136,7 +137,7 @@ def __neg__(self) -> "Column": # noqa: D105
136137

137138
__rpow__ = _bin_op("__rpow__")
138139

139-
def __getitem__(self, k: Any) -> "Column":
140+
def __getitem__(self, k: Any) -> "Column": # noqa: ANN401
140141
"""An expression that gets an item at position ``ordinal`` out of a list,
141142
or gets an item by key out of a dict.
142143
@@ -176,7 +177,7 @@ def __getitem__(self, k: Any) -> "Column":
176177
expr_str = str(self.expr) + "." + str(k)
177178
return Column(ColumnExpression(expr_str))
178179

179-
def __getattr__(self, item: Any) -> "Column":
180+
def __getattr__(self, item: Any) -> "Column": # noqa: ANN401
180181
"""An expression that gets an item at position ``ordinal`` out of a list,
181182
or gets an item by key out of a dict.
182183
@@ -208,15 +209,15 @@ def __getattr__(self, item: Any) -> "Column":
208209
def alias(self, alias: str) -> "Column": # noqa: D102
209210
return Column(self.expr.alias(alias))
210211

211-
def when(self, condition: "Column", value: Any) -> "Column": # noqa: D102
212+
def when(self, condition: "Column", value: Union["Column", str]) -> "Column": # noqa: D102
212213
if not isinstance(condition, Column):
213214
msg = "condition should be a Column"
214215
raise TypeError(msg)
215216
v = _get_expr(value)
216217
expr = self.expr.when(condition.expr, v)
217218
return Column(expr)
218219

219-
def otherwise(self, value: Any) -> "Column": # noqa: D102
220+
def otherwise(self, value: Union["Column", str]) -> "Column": # noqa: D102
220221
v = _get_expr(value)
221222
expr = self.expr.otherwise(v)
222223
return Column(expr)
@@ -229,7 +230,7 @@ def cast(self, dataType: Union[DataType, str]) -> "Column": # noqa: D102
229230
internal_type = dataType.duckdb_type
230231
return Column(self.expr.cast(internal_type))
231232

232-
def isin(self, *cols: Any) -> "Column": # noqa: D102
233+
def isin(self, *cols: Union[Iterable[Union["Column", str]], Union["Column", str]]) -> "Column": # noqa: D102
233234
if len(cols) == 1 and isinstance(cols[0], (list, set)):
234235
# Only one argument supplied, it's a list
235236
cols = cast("tuple", cols[0])

duckdb/experimental/spark/sql/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,7 +282,7 @@ def withColumnsRenamed(self, colsMap: dict[str, str]) -> "DataFrame":
282282
rel = self.relation.select(*cols)
283283
return DataFrame(rel, self.session)
284284

285-
def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame":
285+
def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any) -> "DataFrame": # noqa: ANN401
286286
"""Returns a new :class:`DataFrame`. Concise syntax for chaining custom transformations.
287287
288288
.. versionadded:: 3.0.0
@@ -342,7 +342,7 @@ def transform(self, func: Callable[..., "DataFrame"], *args: Any, **kwargs: Any)
342342
)
343343
return result
344344

345-
def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any) -> "DataFrame":
345+
def sort(self, *cols: Union[str, Column, list[Union[str, Column]]], **kwargs: Any) -> "DataFrame": # noqa: ANN401
346346
"""Returns a new :class:`DataFrame` sorted by the specified column(s).
347347
348348
Parameters

duckdb/experimental/spark/sql/functions.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def ucase(str: "ColumnOrName") -> Column:
9090
return upper(str)
9191

9292

93-
def when(condition: "Column", value: Any) -> Column: # noqa: D103
93+
def when(condition: "Column", value: Column | str) -> Column: # noqa: D103
9494
if not isinstance(condition, Column):
9595
msg = "condition should be a Column"
9696
raise TypeError(msg)
@@ -143,7 +143,7 @@ def array(*cols: Union["ColumnOrName", Union[list["ColumnOrName"], tuple["Column
143143
return _invoke_function_over_columns("list_value", *cols)
144144

145145

146-
def lit(col: Any) -> Column: # noqa: D103
146+
def lit(col: Any) -> Column: # noqa: D103, ANN401
147147
return col if isinstance(col, Column) else Column(ConstantExpression(col))
148148

149149

@@ -842,7 +842,7 @@ def collect_list(col: "ColumnOrName") -> Column:
842842
return array_agg(col)
843843

844844

845-
def array_append(col: "ColumnOrName", value: Any) -> Column:
845+
def array_append(col: "ColumnOrName", value: Column | str) -> Column:
846846
"""Collection function: returns an array of the elements in col1 along
847847
with the added element in col2 at the last of the array.
848848
@@ -876,7 +876,7 @@ def array_append(col: "ColumnOrName", value: Any) -> Column:
876876
return _invoke_function("list_append", _to_column_expr(col), _get_expr(value))
877877

878878

879-
def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Any) -> Column:
879+
def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: Column | str) -> Column:
880880
"""Collection function: adds an item into a given array at a specified array index.
881881
Array indices start at 1, or start from the end if index is negative.
882882
Index above array size appends the array, or prepends the array if index is negative,
@@ -969,7 +969,7 @@ def array_insert(arr: "ColumnOrName", pos: Union["ColumnOrName", int], value: An
969969
)
970970

971971

972-
def array_contains(col: "ColumnOrName", value: Any) -> Column:
972+
def array_contains(col: "ColumnOrName", value: Column | str) -> Column:
973973
"""Collection function: returns null if the array is null, true if the array contains the
974974
given value, and false otherwise.
975975
@@ -1937,7 +1937,7 @@ def array_compact(col: "ColumnOrName") -> Column:
19371937
)
19381938

19391939

1940-
def array_remove(col: "ColumnOrName", element: Any) -> Column:
1940+
def array_remove(col: "ColumnOrName", element: Any) -> Column: # noqa: ANN401
19411941
"""Collection function: Remove all elements that equal to element from the given array.
19421942
19431943
.. versionadded:: 2.4.0
@@ -5083,7 +5083,7 @@ def array_join(col: "ColumnOrName", delimiter: str, null_replacement: Optional[s
50835083
return _invoke_function("array_to_string", col, ConstantExpression(delimiter))
50845084

50855085

5086-
def array_position(col: "ColumnOrName", value: Any) -> Column:
5086+
def array_position(col: "ColumnOrName", value: Any) -> Column: # noqa: ANN401
50875087
"""Collection function: Locates the position of the first occurrence of the given value
50885088
in the given array. Returns null if either of the arguments are null.
50895089
@@ -5122,7 +5122,7 @@ def array_position(col: "ColumnOrName", value: Any) -> Column:
51225122
)
51235123

51245124

5125-
def array_prepend(col: "ColumnOrName", value: Any) -> Column:
5125+
def array_prepend(col: "ColumnOrName", value: Any) -> Column: # noqa: ANN401
51265126
"""Collection function: Returns an array containing element as
51275127
well as all elements from array. The new element is positioned
51285128
at the beginning of the array.

duckdb/experimental/spark/sql/session.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import uuid # noqa: D100
22
from collections.abc import Iterable, Sized
3-
from typing import TYPE_CHECKING, Any, Optional, Union
3+
from typing import TYPE_CHECKING, Any, NoReturn, Optional, Union
44

55
if TYPE_CHECKING:
66
from pandas.core.frame import DataFrame as PandasDataFrame
@@ -205,7 +205,7 @@ def range( # noqa: D102
205205

206206
return DataFrame(self.conn.table_function("range", parameters=[start, end, step]), self)
207207

208-
def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102
208+
def sql(self, sqlQuery: str, **kwargs: Any) -> DataFrame: # noqa: D102, ANN401
209209
if kwargs:
210210
raise NotImplementedError
211211
relation = self.conn.sql(sqlQuery)
@@ -246,7 +246,7 @@ def sparkContext(self) -> SparkContext: # noqa: D102
246246
return self._context
247247

248248
@property
249-
def streams(self) -> Any: # noqa: D102
249+
def streams(self) -> NoReturn: # noqa: D102
250250
raise ContributionsAcceptedError
251251

252252
@property
@@ -278,7 +278,10 @@ def getOrCreate(self) -> "SparkSession": # noqa: D102
278278
return SparkSession(context)
279279

280280
def config( # noqa: D102
281-
self, key: Optional[str] = None, value: Optional[Any] = None, conf: Optional[SparkConf] = None
281+
self,
282+
key: Optional[str] = None,
283+
value: Optional[Any] = None, # noqa: ANN401
284+
conf: Optional[SparkConf] = None,
282285
) -> "SparkSession.Builder":
283286
return self
284287

duckdb/experimental/spark/sql/types.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
# This code is based on code from Apache Spark under the license found in the LICENSE # noqa: D100
1+
# ruff: noqa: D100
2+
# This code is based on code from Apache Spark under the license found in the LICENSE
23
# file located in the 'spark' folder.
34

45
import calendar
@@ -11,6 +12,7 @@
1112
from typing import (
1213
Any,
1314
ClassVar,
15+
NoReturn,
1416
Optional,
1517
TypeVar,
1618
Union,
@@ -102,11 +104,11 @@ def needConversion(self) -> bool:
102104
"""
103105
return False
104106

105-
def toInternal(self, obj: Any) -> Any:
107+
def toInternal(self, obj: Any) -> Any: # noqa: ANN401
106108
"""Converts a Python object into an internal SQL object."""
107109
return obj
108110

109-
def fromInternal(self, obj: Any) -> Any:
111+
def fromInternal(self, obj: Any) -> Any: # noqa: ANN401
110112
"""Converts an internal SQL object into a native Python object."""
111113
return obj
112114

@@ -889,7 +891,7 @@ def simpleString(self) -> str: # noqa: D102
889891
def __repr__(self) -> str: # noqa: D105
890892
return "StructType([%s])" % ", ".join(str(field) for field in self)
891893

892-
def __contains__(self, item: Any) -> bool: # noqa: D105
894+
def __contains__(self, item: str) -> bool: # noqa: D105
893895
return item in self.names
894896

895897
def extract_types_and_names(self) -> tuple[list[str], list[str]]: # noqa: D102
@@ -1010,21 +1012,21 @@ def _cachedSqlType(cls) -> DataType:
10101012
cls._cached_sql_type = cls.sqlType() # type: ignore[attr-defined]
10111013
return cls._cached_sql_type # type: ignore[attr-defined]
10121014

1013-
def toInternal(self, obj: Any) -> Any:
1015+
def toInternal(self, obj: Any) -> Any: # noqa: ANN401
10141016
if obj is not None:
10151017
return self._cachedSqlType().toInternal(self.serialize(obj))
10161018

1017-
def fromInternal(self, obj: Any) -> Any:
1019+
def fromInternal(self, obj: Any) -> Any: # noqa: ANN401
10181020
v = self._cachedSqlType().fromInternal(obj)
10191021
if v is not None:
10201022
return self.deserialize(v)
10211023

1022-
def serialize(self, obj: Any) -> Any:
1024+
def serialize(self, obj: Any) -> NoReturn: # noqa: ANN401
10231025
"""Converts a user-type object into a SQL datum."""
10241026
msg = "UDT must implement toInternal()."
10251027
raise NotImplementedError(msg)
10261028

1027-
def deserialize(self, datum: Any) -> Any:
1029+
def deserialize(self, datum: Any) -> NoReturn: # noqa: ANN401
10281030
"""Converts a SQL datum into a user-type object."""
10291031
msg = "UDT must implement fromInternal()."
10301032
raise NotImplementedError(msg)
@@ -1132,7 +1134,7 @@ class Row(tuple):
11321134
def __new__(cls, *args: str) -> "Row": ...
11331135

11341136
@overload
1135-
def __new__(cls, **kwargs: Any) -> "Row": ...
1137+
def __new__(cls, **kwargs: Any) -> "Row": ... # noqa: ANN401
11361138

11371139
def __new__(cls, *args: Optional[str], **kwargs: Optional[Any]) -> "Row": # noqa: D102
11381140
if args and kwargs:
@@ -1179,7 +1181,7 @@ def asDict(self, recursive: bool = False) -> dict[str, Any]:
11791181

11801182
if recursive:
11811183

1182-
def conv(obj: Any) -> Any:
1184+
def conv(obj: Row | list | dict | object) -> list | dict | object:
11831185
if isinstance(obj, Row):
11841186
return obj.asDict(True)
11851187
elif isinstance(obj, list):
@@ -1193,22 +1195,22 @@ def conv(obj: Any) -> Any:
11931195
else:
11941196
return dict(zip(self.__fields__, self))
11951197

1196-
def __contains__(self, item: Any) -> bool: # noqa: D105
1198+
def __contains__(self, item: Any) -> bool: # noqa: D105, ANN401
11971199
if hasattr(self, "__fields__"):
11981200
return item in self.__fields__
11991201
else:
12001202
return super(Row, self).__contains__(item)
12011203

12021204
# let object acts like class
1203-
def __call__(self, *args: Any) -> "Row":
1205+
def __call__(self, *args: Any) -> "Row": # noqa: ANN401
12041206
"""Create new Row object."""
12051207
if len(args) > len(self):
12061208
raise ValueError(
12071209
"Can not create Row with fields %s, expected %d values but got %s" % (self, len(self), args)
12081210
)
12091211
return _create_row(self, args)
12101212

1211-
def __getitem__(self, item: Any) -> Any: # noqa: D105
1213+
def __getitem__(self, item: Any) -> Any: # noqa: D105, ANN401
12121214
if isinstance(item, (int, slice)):
12131215
return super(Row, self).__getitem__(item)
12141216
try:
@@ -1221,7 +1223,7 @@ def __getitem__(self, item: Any) -> Any: # noqa: D105
12211223
except ValueError:
12221224
raise ValueError(item)
12231225

1224-
def __getattr__(self, item: str) -> Any: # noqa: D105
1226+
def __getattr__(self, item: str) -> Any: # noqa: D105, ANN401
12251227
if item.startswith("__"):
12261228
raise AttributeError(item)
12271229
try:
@@ -1234,7 +1236,7 @@ def __getattr__(self, item: str) -> Any: # noqa: D105
12341236
except ValueError:
12351237
raise AttributeError(item)
12361238

1237-
def __setattr__(self, key: Any, value: Any) -> None: # noqa: D105
1239+
def __setattr__(self, key: Any, value: Any) -> None: # noqa: D105, ANN401
12381240
if key != "__fields__":
12391241
msg = "Row is read-only"
12401242
raise RuntimeError(msg)

0 commit comments

Comments
 (0)