Skip to content

Commit fa30706

Browse files
authored
Merge branch 'main' into refac-pandas-concat
2 parents 1fa9f8f + 939e450 commit fa30706

File tree

7 files changed

+191
-133
lines changed

7 files changed

+191
-133
lines changed

narwhals/_arrow/series.py

Lines changed: 46 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1065,58 +1065,20 @@ def hist( # noqa: PLR0915
10651065

10661066
def _hist_from_bin_count(bin_count: int): # type: ignore[no-untyped-def] # noqa: ANN202
10671067
d = pc.min_max(self.native)
1068-
lower, upper = d["min"], d["max"]
1069-
pa_float = pa.type_for_alias("float")
1068+
lower, upper = d["min"].as_py(), d["max"].as_py()
10701069
if lower == upper:
1071-
range_: pa.Scalar[Any] = lit(1.0)
1072-
mid = lit(0.5)
1073-
width = pc.divide(range_, lit(bin_count))
1074-
lower = pc.subtract(lower, mid)
1075-
upper = pc.add(upper, mid)
1076-
else:
1077-
range_ = pc.subtract(upper, lower)
1078-
width = pc.divide(pc.cast(range_, pa_float), lit(float(bin_count)))
1079-
1080-
bin_proportions = pc.divide(pc.subtract(self.native, lower), width)
1081-
bin_indices = pc.floor(bin_proportions)
1082-
1083-
# shift bins so they are right-closed
1084-
bin_indices = pc.if_else(
1085-
pc.and_(
1086-
pc.equal(bin_indices, bin_proportions),
1087-
pc.greater(bin_indices, lit(0)),
1088-
),
1089-
pc.subtract(bin_indices, lit(1)),
1090-
bin_indices,
1091-
)
1092-
possible = pa.Table.from_arrays(
1093-
[pa.Array.from_pandas(np.arange(bin_count, dtype="int64"))], ["values"]
1094-
)
1095-
counts = ( # count bin id occurrences
1096-
pa.Table.from_arrays(
1097-
pc.value_counts(bin_indices).flatten(),
1098-
names=["values", "counts"],
1099-
)
1100-
# nan values are implicitly dropped in value_counts
1101-
.filter(~pc.field("values").is_nan())
1102-
.cast(pa.schema([("values", pa.int64()), ("counts", pa.int64())]))
1103-
# align bin ids to all possible bin ids (populate in missing bins)
1104-
.join(possible, keys="values", join_type="right outer")
1105-
.sort_by("values")
1106-
)
1107-
# empty bin intervals should have a 0 count
1108-
counts_coalesce = cast(
1109-
"ArrowArray", pc.coalesce(counts.column("counts"), lit(0))
1110-
)
1111-
counts = counts.set_column(0, "counts", counts_coalesce)
1112-
1113-
# extract left/right side of the intervals
1114-
bin_left = pc.add(lower, pc.multiply(counts.column("values"), width))
1115-
bin_right = pc.add(bin_left, width)
1116-
return counts.column("counts"), bin_right
1070+
lower -= 0.5
1071+
upper += 0.5
1072+
bins = np.linspace(lower, upper, bin_count + 1)
1073+
return _hist_from_bins(bins)
11171074

11181075
def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def] # noqa: ANN202
11191076
bin_indices = np.searchsorted(bins, self.native, side="left")
1077+
bin_indices = pc.if_else( # lowest bin is inclusive
1078+
pc.equal(self.native, lit(bins[0])), 1, bin_indices
1079+
)
1080+
1081+
# align unique categories and counts appropriately
11201082
obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
11211083
obj_cats = np.arange(1, len(bins))
11221084
counts = np.zeros_like(obj_cats)
@@ -1125,15 +1087,51 @@ def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def
11251087
bin_right = bins[1:]
11261088
return counts, bin_right
11271089

1090+
counts: Sequence[int | float] | np.typing.ArrayLike
1091+
bin_right: Sequence[int | float] | np.typing.ArrayLike
1092+
1093+
data_count = pc.sum(
1094+
pc.invert(pc.or_(pc.is_nan(self.native), pc.is_null(self.native))).cast(
1095+
pa.uint8()
1096+
),
1097+
min_count=0,
1098+
)
11281099
if bins is not None:
11291100
if len(bins) < 2:
11301101
counts, bin_right = [], []
1102+
1103+
elif data_count == pa.scalar(0, type=pa.uint64()): # type:ignore[comparison-overlap]
1104+
counts = np.zeros(len(bins) - 1)
1105+
bin_right = bins[1:]
1106+
1107+
elif len(bins) == 2:
1108+
counts = [
1109+
pc.sum(
1110+
pc.and_(
1111+
pc.greater_equal(self.native, lit(float(bins[0]))),
1112+
pc.less_equal(self.native, lit(float(bins[1]))),
1113+
).cast(pa.uint8())
1114+
)
1115+
]
1116+
bin_right = [bins[-1]]
11311117
else:
11321118
counts, bin_right = _hist_from_bins(bins)
11331119

11341120
elif bin_count is not None:
11351121
if bin_count == 0:
11361122
counts, bin_right = [], []
1123+
elif data_count == pa.scalar(0, type=pa.uint64()): # type:ignore[comparison-overlap]
1124+
counts, bin_right = (
1125+
np.zeros(bin_count),
1126+
np.linspace(0, 1, bin_count + 1)[1:],
1127+
)
1128+
elif bin_count == 1:
1129+
d = pc.min_max(self.native)
1130+
lower, upper = d["min"], d["max"]
1131+
if lower == upper:
1132+
counts, bin_right = [data_count], [pc.add(upper, pa.scalar(0.5))]
1133+
else:
1134+
counts, bin_right = [data_count], [upper]
11371135
else:
11381136
counts, bin_right = _hist_from_bin_count(bin_count)
11391137

narwhals/_dask/expr.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -644,9 +644,18 @@ def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
644644
message=".*`meta` is not specified",
645645
category=UserWarning,
646646
)
647-
res_native = df.native.groupby(partition_by)[
648-
list(output_names)
649-
].transform(dask_function_name, **self._call_kwargs)
647+
grouped = df.native.groupby(partition_by)
648+
if dask_function_name == "size":
649+
if len(output_names) != 1: # pragma: no cover
650+
msg = "Safety check failed, please report a bug."
651+
raise AssertionError(msg)
652+
res_native = grouped.transform(
653+
dask_function_name, **self._call_kwargs
654+
).to_frame(output_names[0])
655+
else:
656+
res_native = grouped[list(output_names)].transform(
657+
dask_function_name, **self._call_kwargs
658+
)
650659
result_frame = df._with_native(
651660
res_native.rename(columns=dict(zip(output_names, aliases)))
652661
).native

narwhals/_pandas_like/expr.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def cum_sum(self: Self, *, reverse: bool) -> Self:
198198
def shift(self: Self, n: int) -> Self:
199199
return self._reuse_series("shift", call_kwargs={"n": n})
200200

201-
def over(
201+
def over( # noqa: PLR0915
202202
self: Self,
203203
partition_by: Sequence[str],
204204
order_by: Sequence[str] | None,
@@ -265,21 +265,25 @@ def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
265265
elif reverse:
266266
columns = list(set(partition_by).union(output_names))
267267
df = df[columns][::-1]
268+
grouped = df._native_frame.groupby(partition_by)
268269
if function_name.startswith("rolling"):
269-
rolling = df._native_frame.groupby(partition_by)[
270-
list(output_names)
271-
].rolling(**pandas_kwargs)
270+
rolling = grouped[list(output_names)].rolling(**pandas_kwargs)
272271
assert pandas_function_name is not None # help mypy # noqa: S101
273272
if pandas_function_name in {"std", "var"}:
274273
res_native = getattr(rolling, pandas_function_name)(
275274
ddof=self._call_kwargs["ddof"]
276275
)
277276
else:
278277
res_native = getattr(rolling, pandas_function_name)()
278+
elif function_name == "len":
279+
if len(output_names) != 1: # pragma: no cover
280+
msg = "Safety check failed, please report a bug."
281+
raise AssertionError(msg)
282+
res_native = grouped.transform("size").to_frame(aliases[0])
279283
else:
280-
res_native = df._native_frame.groupby(partition_by)[
281-
list(output_names)
282-
].transform(pandas_function_name, **pandas_kwargs)
284+
res_native = grouped[list(output_names)].transform(
285+
pandas_function_name, **pandas_kwargs
286+
)
283287
result_frame = df._with_native(res_native).rename(
284288
dict(zip(output_names, aliases))
285289
)

narwhals/_pandas_like/series.py

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -966,34 +966,50 @@ def hist(
966966
data["breakpoint"] = []
967967
data["count"] = []
968968
return PandasLikeDataFrame.from_native(ns.DataFrame(data), context=self)
969-
elif self.native.count() < 1:
969+
970+
if self.native.count() < 1:
970971
if bins is not None:
971972
data = {"breakpoint": bins[1:], "count": zeros(shape=len(bins) - 1)}
972973
else:
973974
count = cast("int", bin_count)
974-
data = {"breakpoint": linspace(0, 1, count), "count": zeros(shape=count)}
975+
if bin_count == 1:
976+
data = {"breakpoint": [1.0], "count": [0]}
977+
else:
978+
data = {
979+
"breakpoint": linspace(0, 1, count + 1)[1:],
980+
"count": zeros(shape=count),
981+
}
975982
if not include_breakpoint:
976983
del data["breakpoint"]
977984
return PandasLikeDataFrame.from_native(ns.DataFrame(data), context=self)
978985

979-
elif bin_count is not None: # use Polars binning behavior
986+
if bin_count is not None:
987+
# use Polars binning behavior
980988
lower, upper = self.native.min(), self.native.max()
981-
pad_lowest_bin = False
982989
if lower == upper:
983990
lower -= 0.5
984991
upper += 0.5
985-
else:
986-
pad_lowest_bin = True
992+
993+
if bin_count == 1:
994+
data = {
995+
"breakpoint": [upper],
996+
"count": [self.native.count()],
997+
}
998+
if not include_breakpoint:
999+
del data["breakpoint"]
1000+
return PandasLikeDataFrame.from_native(ns.DataFrame(data), context=self)
9871001

9881002
bins = linspace(lower, upper, bin_count + 1)
989-
if pad_lowest_bin and bins is not None:
990-
bins[0] -= 0.001 * abs(bins[0]) if bins[0] != 0 else 0.001
9911003
bin_count = None
9921004

9931005
# pandas (2.2.*) .value_counts(bins=int) adjusts the lowest bin twice, result in improper counts.
9941006
# pandas (2.2.*) .value_counts(bins=[...]) adjusts the lowest bin which should not happen since
9951007
# the bins were explicitly passed in.
996-
categories = ns.cut(self.native, bins=bins if bin_count is None else bin_count)
1008+
categories = ns.cut(
1009+
self.native,
1010+
bins=bins if bin_count is None else bin_count,
1011+
include_lowest=True, # Polars 1.27.0 always includes the lowest bin
1012+
)
9971013
# modin (0.32.0) .value_counts(...) silently drops bins with empty observations, .reindex
9981014
# is necessary to restore these bins.
9991015
result = categories.value_counts(dropna=True, sort=False).reindex(

narwhals/_polars/series.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -499,18 +499,26 @@ def hist(
499499
data.append(pl.Series("breakpoint", [], dtype=pl.Float64))
500500
data.append(pl.Series("count", [], dtype=pl.UInt32))
501501
return PolarsDataFrame.from_native(pl.DataFrame(data), context=self)
502-
elif (self._backend_version < (1, 15)) and self.native.count() < 1:
502+
503+
if self.native.count() < 1:
503504
data_dict: dict[str, Sequence[Any] | pl.Series]
504505
if bins is not None:
505506
data_dict = {
506507
"breakpoint": bins[1:],
507508
"count": pl.zeros(n=len(bins) - 1, dtype=pl.Int64, eager=True),
508509
}
509-
elif bin_count is not None:
510+
elif (bin_count is not None) and bin_count == 1:
511+
data_dict = {"breakpoint": [1.0], "count": [0]}
512+
elif (bin_count is not None) and bin_count > 1:
510513
data_dict = {
511-
"breakpoint": pl.int_range(0, bin_count, eager=True) / bin_count,
514+
"breakpoint": pl.int_range(1, bin_count + 1, eager=True) / bin_count,
512515
"count": pl.zeros(n=bin_count, dtype=pl.Int64, eager=True),
513516
}
517+
else: # pragma: no cover
518+
msg = (
519+
"congratulations, you entered unreachable code - please report a bug"
520+
)
521+
raise AssertionError(msg)
514522
if not include_breakpoint:
515523
del data_dict["breakpoint"]
516524
return PolarsDataFrame.from_native(pl.DataFrame(data_dict), context=self)
@@ -519,25 +527,19 @@ def hist(
519527
# polars <1.5 with bin_count=...
520528
# returns bins that range from -inf to +inf and has bin_count + 1 bins.
521529
# for compat: convert `bin_count=` call to `bins=`
522-
if (
523-
(self._backend_version < (1, 15))
524-
and (bin_count is not None)
525-
and (self.native.count() > 0)
530+
if (self._backend_version < (1, 15)) and (
531+
bin_count is not None
526532
): # pragma: no cover
527533
lower = cast("float", self.native.min())
528534
upper = cast("float", self.native.max())
529-
pad_lowest_bin = False
530535
if lower == upper:
531536
width = 1 / bin_count
532537
lower -= 0.5
533538
upper += 0.5
534539
else:
535-
pad_lowest_bin = True
536540
width = (upper - lower) / bin_count
537541

538542
bins = (pl.int_range(0, bin_count + 1, eager=True) * width + lower).to_list()
539-
if pad_lowest_bin:
540-
bins[0] -= 0.001 * abs(bins[0]) if bins[0] != 0 else 0.001
541543
bin_count = None
542544

543545
# Polars inconsistently handles NaN values when computing histograms
@@ -552,16 +554,22 @@ def hist(
552554
include_category=False,
553555
include_breakpoint=include_breakpoint,
554556
)
557+
555558
if not include_breakpoint:
556559
df.columns = ["count"]
557560

561+
if self._backend_version < (1, 0) and include_breakpoint:
562+
df = df.rename({"break_point": "breakpoint"})
563+
558564
# polars<1.15 implicitly adds -inf and inf to either end of bins
559565
if self._backend_version < (1, 15) and bins is not None: # pragma: no cover
560566
r = pl.int_range(0, len(df))
561567
df = df.filter((r > 0) & (r < len(df) - 1))
562568

563-
if self._backend_version < (1, 0) and include_breakpoint:
564-
df = df.rename({"break_point": "breakpoint"})
569+
# polars<1.27 makes the lowest bin a left/right closed interval.
570+
if self._backend_version < (1, 27) and bins is not None:
571+
df[0, "count"] += (series == bins[0]).sum()
572+
565573
return PolarsDataFrame.from_native(df, context=self)
566574

567575
def to_polars(self: Self) -> pl.Series:

tests/expr_and_series/over_test.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -428,3 +428,17 @@ def test_over_without_partition_by(
428428
)
429429
expected = {"a": [1, 2, -1], "b": [1, 3, 4], "i": [0, 1, 2]}
430430
assert_equal_data(result, expected)
431+
432+
433+
def test_len_over_2369(constructor: Constructor, request: pytest.FixtureRequest) -> None:
434+
if "duckdb" in str(constructor) and DUCKDB_VERSION < (1, 3):
435+
pytest.skip()
436+
if "pandas" in str(constructor) and PANDAS_VERSION < (1, 5):
437+
pytest.skip()
438+
if any(x in str(constructor) for x in ("modin",)):
439+
# https://github.com/modin-project/modin/issues/7508
440+
request.applymarker(pytest.mark.xfail)
441+
df = nw.from_native(constructor({"a": [1, 2, 4], "b": ["x", "x", "y"]}))
442+
result = df.with_columns(a_len_per_group=nw.len().over("b")).sort("a")
443+
expected = {"a": [1, 2, 4], "b": ["x", "x", "y"], "a_len_per_group": [2, 2, 1]}
444+
assert_equal_data(result, expected)

0 commit comments

Comments (0)