Skip to content

Commit b1c7e8e

Browse files
authored
chore: bump pyarrow-stubs==19.2 (#2458)
1 parent c23e56c commit b1c7e8e

File tree

10 files changed

+51
-37
lines changed

10 files changed

+51
-37
lines changed

narwhals/_arrow/dataframe.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
255255
return self._with_native(self.native.slice(0, 0))
256256
if self._backend_version < (18,) and isinstance(rows, tuple):
257257
rows = list(rows)
258-
return self._with_native(self.native.take(rows)) # pyright: ignore[reportArgumentType]
258+
return self._with_native(self.native.take(rows))
259259

260260
def _gather_slice(self, rows: _SliceIndex | range) -> Self:
261261
start = rows.start or 0
@@ -302,8 +302,7 @@ def _select_multi_name(
302302
selector = cast("Sequence[str]", columns.to_pylist())
303303
else:
304304
selector = columns
305-
# TODO @dangotbanned: Fix upstream `pa.Table.select` https://github.com/zen-xu/pyarrow-stubs/blob/f899bb35e10b36f7906a728e9f8acf3e0a1f9f64/pyarrow-stubs/__lib_pxi/table.pyi#L597
306-
# NOTE: Investigate what `cython` actually checks
305+
# NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
307306
return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType]
308307

309308
@property
@@ -370,13 +369,9 @@ def with_columns(self: ArrowDataFrame, *exprs: ArrowExpr) -> ArrowDataFrame:
370369
col_name = col_value.name
371370
column = self._extract_comparand(col_value)
372371
native_frame = (
373-
native_frame.set_column(
374-
columns.index(col_name),
375-
field_=col_name,
376-
column=column, # type: ignore[arg-type]
377-
)
372+
native_frame.set_column(columns.index(col_name), col_name, column=column)
378373
if col_name in columns
379-
else native_frame.append_column(field_=col_name, column=column)
374+
else native_frame.append_column(col_name, column=column)
380375
)
381376

382377
return self._with_native(native_frame, validate_column_names=False)
@@ -708,9 +703,9 @@ def unique(
708703
subset = list(subset or self.columns)
709704

710705
if keep in {"any", "first", "last"}:
711-
agg_func_map = {"any": "min", "first": "min", "last": "max"}
706+
from narwhals._arrow.group_by import ArrowGroupBy
712707

713-
agg_func = agg_func_map[keep]
708+
agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
714709
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
715710
keep_idx_native = (
716711
self.native.append_column(col_token, pa.array(np.arange(len(self))))

narwhals/_arrow/expr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
161161
# TODO(marco): is there a way to do this efficiently without
162162
# doing 2 sorts? Here we're sorting the dataframe and then
163163
# again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
164-
sorting_indices = pc.sort_indices(df.get_column(token).native) # type: ignore[call-overload]
164+
sorting_indices = pc.sort_indices(df.get_column(token).native)
165165
return [s._with_native(s.native.take(sorting_indices)) for s in result]
166166
else:
167167

narwhals/_arrow/group_by.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
if TYPE_CHECKING:
2121
from narwhals._arrow.dataframe import ArrowDataFrame
2222
from narwhals._arrow.expr import ArrowExpr
23+
from narwhals._arrow.typing import AggregateOptions # type: ignore[attr-defined]
24+
from narwhals._arrow.typing import Aggregation # type: ignore[attr-defined]
2325
from narwhals._arrow.typing import Incomplete
2426
from narwhals._compliant.group_by import NarwhalsAggregation
27+
from narwhals.typing import UniqueKeepStrategy
2528

2629

27-
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):
28-
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = {
30+
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
31+
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
2932
"sum": "sum",
3033
"mean": "mean",
3134
"median": "approximate_median",
@@ -37,6 +40,11 @@ class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr"]):
3740
"n_unique": "count_distinct",
3841
"count": "count",
3942
}
43+
_REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
44+
"any": "min",
45+
"first": "min",
46+
"last": "max",
47+
}
4048

4149
def __init__(
4250
self,
@@ -54,7 +62,7 @@ def __init__(
5462

5563
def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
5664
self._ensure_all_simple(exprs)
57-
aggs: list[tuple[str, str, Any]] = []
65+
aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
5866
expected_pyarrow_column_names: list[str] = self._keys.copy()
5967
new_column_names: list[str] = self._keys.copy()
6068
exclude = (*self._keys, *self._output_key_names)

narwhals/_arrow/namespace.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,10 @@ def func(df: ArrowDataFrame) -> list[ArrowSeries]:
203203
context=self,
204204
)
205205

206-
# NOTE: Stub issue fixed in https://github.com/zen-xu/pyarrow-stubs/pull/203
207206
def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
208207
if self._backend_version >= (14,):
209-
return pa.concat_tables(dfs, promote_options="default") # type: ignore[arg-type]
210-
return pa.concat_tables(dfs, promote=True) # type: ignore[arg-type] # pragma: no cover
208+
return pa.concat_tables(dfs, promote_options="default")
209+
return pa.concat_tables(dfs, promote=True) # pragma: no cover
211210

212211
def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
213212
names = list(chain.from_iterable(df.column_names for df in dfs))
@@ -225,7 +224,7 @@ def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
225224
f" - dataframe {i}: {cols_current}\n"
226225
)
227226
raise TypeError(msg)
228-
return pa.concat_tables(dfs) # type: ignore[arg-type]
227+
return pa.concat_tables(dfs)
229228

230229
@property
231230
def selectors(self) -> ArrowSelectorNamespace:

narwhals/_arrow/series.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def __rmod__(self, other: Any) -> Self:
298298
return self._with_native(res)
299299

300300
def __invert__(self) -> Self:
301-
return self._with_native(pc.invert(self.native)) # type: ignore[call-overload]
301+
return self._with_native(pc.invert(self.native))
302302

303303
@property
304304
def _type(self) -> pa.DataType:
@@ -426,6 +426,7 @@ def _gather_slice(self, rows: _SliceIndex | range) -> Self:
426426
def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
427427
import numpy as np # ignore-banned-import
428428

429+
values_native: ArrayAny
429430
if isinstance(indices, int):
430431
indices_native = pa.array([indices])
431432
values_native = pa.array([values])
@@ -436,20 +437,25 @@ def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
436437
if isinstance(values, self.__class__):
437438
values_native = values.native.combine_chunks()
438439
else:
439-
values_native = pa.array(values)
440+
# NOTE: Requires fixes in https://github.com/zen-xu/pyarrow-stubs/pull/209
441+
pa_array: Incomplete = pa.array
442+
values_native = pa_array(values)
440443

441-
sorting_indices = pc.sort_indices(indices_native) # type: ignore[call-overload]
442-
indices_native = pc.take(indices_native, sorting_indices)
443-
values_native = pc.take(values_native, sorting_indices)
444+
sorting_indices = pc.sort_indices(indices_native)
445+
indices_native = indices_native.take(sorting_indices)
446+
values_native = values_native.take(sorting_indices)
444447

445448
mask: _1DArray = np.zeros(self.len(), dtype=bool)
446449
mask[indices_native] = True
447-
result = pc.replace_with_mask(
448-
self.native,
449-
cast("list[bool]", mask),
450-
values_native.take(indices_native),
450+
# NOTE: Multiple issues
451+
# - Missing `values` type
452+
# - `mask` accepts a `np.ndarray`, but not mentioned in stubs
453+
# - Missing `replacements` type
454+
# - Missing return type
455+
pc_replace_with_mask: Incomplete = pc.replace_with_mask
456+
return self._with_native(
457+
pc_replace_with_mask(self.native, mask, values_native.take(indices_native))
451458
)
452-
return self._with_native(result)
453459

454460
def to_list(self) -> list[Any]:
455461
return self.native.to_pylist()

narwhals/_arrow/typing.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
from typing_extensions import TypeAlias
1616

1717
import pyarrow as pa
18+
from pyarrow.__lib_pxi.table import AggregateOptions # noqa: F401
19+
from pyarrow.__lib_pxi.table import Aggregation # noqa: F401
1820
from pyarrow._stubs_typing import ( # pyright: ignore[reportMissingModuleSource]
1921
Indices, # noqa: F401
2022
)

narwhals/_arrow/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,8 @@ def narwhals_to_native_dtype(dtype: DType | type[DType], version: Version) -> pa
186186
if isinstance_or_issubclass(dtype, dtypes.Categorical):
187187
return pa.dictionary(pa.uint32(), pa.string())
188188
if isinstance_or_issubclass(dtype, dtypes.Datetime):
189-
return pa.timestamp(dtype.time_unit, tz=dtype.time_zone) # pyright: ignore[reportArgumentType]
189+
unit = dtype.time_unit
190+
return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
190191
if isinstance_or_issubclass(dtype, dtypes.Duration):
191192
return pa.duration(dtype.time_unit)
192193
if isinstance_or_issubclass(dtype, dtypes.Date):
@@ -278,15 +279,18 @@ def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar) -> Any:
278279

279280
if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
280281
divided = pc.divide_checked(left, right)
282+
# TODO @dangotbanned: Use a `TypeVar` in guards
283+
# Narrowing to a `Union` isn't interacting well with the rest of the stubs
284+
# https://github.com/zen-xu/pyarrow-stubs/pull/215
281285
if pa.types.is_signed_integer(divided.type):
282-
# GH 56676
286+
div_type = cast("pa._lib.Int64Type", divided.type)
283287
has_remainder = pc.not_equal(pc.multiply(divided, right), left)
284288
has_one_negative_operand = pc.less(
285-
pc.bit_wise_xor(left, right), lit(0, type=divided.type)
289+
pc.bit_wise_xor(left, right), lit(0, div_type)
286290
)
287291
result = pc.if_else(
288292
pc.and_(has_remainder, has_one_negative_operand),
289-
pc.subtract(divided, lit(1, type=divided.type)),
293+
pc.subtract(divided, lit(1, div_type)),
290294
divided,
291295
)
292296
else:

narwhals/_compliant/group_by.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -195,9 +195,9 @@ def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
195195

196196

197197
class EagerGroupBy(
198-
DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, str],
198+
DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
199199
DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra],
200-
Protocol38[CompliantDataFrameT, EagerExprT_contra],
200+
Protocol38[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
201201
): ...
202202

203203

narwhals/_pandas_like/group_by.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from narwhals._pandas_like.expr import PandasLikeExpr
2222

2323

24-
class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr"]):
24+
class PandasLikeGroupBy(EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", str]):
2525
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]] = {
2626
"sum": "sum",
2727
"mean": "mean",

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ typing = [
6868
"typing_extensions",
6969
"mypy~=1.15.0",
7070
"pyright",
71-
"pyarrow-stubs==19.1",
71+
"pyarrow-stubs==19.2",
7272
"sqlframe",
7373
"polars==1.25.2",
7474
"uv",

0 commit comments

Comments
 (0)