
Commit 939e450

Fix/hist new polars (#2374)
* enh hist to always include lowest bin
* fix hist(bin_count=…) when data is empty
* add shuffling to simple smoke tests
* add hist fastpaths for no data and minimal bins
* simplify hist hypothesis tests
* add coverage for unreachable code
* [pre-commit.ci] auto fixes from pre-commit.com hooks
  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent ee535b4 commit 939e450
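The two headline fixes are easiest to see with a small sketch. The helper below is illustrative only (plain NumPy, not the wrapper code): it mimics the Polars 1.27 binning rule that the backends now follow, where bins are right-closed but the lowest edge is also included, and it shows the new empty-data result for `hist(bin_count=...)`, with all-zero counts and breakpoints taken from `linspace(0, 1, bin_count + 1)[1:]`.

```python
import numpy as np


def polars_style_hist(values: np.ndarray, bins: np.ndarray) -> np.ndarray:
    """Illustrative helper: right-closed bins whose lowest edge is also inclusive."""
    # side="left" gives right-closed bins: x in (bins[i-1], bins[i]] -> index i
    idx = np.searchsorted(bins, values, side="left")
    # values equal to the lowest edge would fall outside bin 1; pull them back in
    idx[values == bins[0]] = 1
    # tally per bin id, padding bins that received no observations with 0
    return np.bincount(idx, minlength=len(bins))[1 : len(bins)]


# the lowest edge (1.0) is now counted in the first bin
assert polars_style_hist(
    np.array([1.0, 1.5, 2.0, 3.0]), np.array([1.0, 2.0, 3.0])
).tolist() == [3, 1]

# empty data with bin_count=4: all-zero counts, breakpoints on a unit grid
bin_count = 4
counts = np.zeros(bin_count, dtype=np.int64)
breakpoints = np.linspace(0, 1, bin_count + 1)[1:]
assert counts.tolist() == [0, 0, 0, 0]
assert breakpoints.tolist() == [0.25, 0.5, 0.75, 1.0]
```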

4 files changed: +154, -123 lines

narwhals/_arrow/series.py

Lines changed: 46 additions & 48 deletions
@@ -1065,58 +1065,20 @@ def hist(  # noqa: PLR0915
 
         def _hist_from_bin_count(bin_count: int):  # type: ignore[no-untyped-def]  # noqa: ANN202
             d = pc.min_max(self.native)
-            lower, upper = d["min"], d["max"]
-            pa_float = pa.type_for_alias("float")
+            lower, upper = d["min"].as_py(), d["max"].as_py()
             if lower == upper:
-                range_: pa.Scalar[Any] = lit(1.0)
-                mid = lit(0.5)
-                width = pc.divide(range_, lit(bin_count))
-                lower = pc.subtract(lower, mid)
-                upper = pc.add(upper, mid)
-            else:
-                range_ = pc.subtract(upper, lower)
-                width = pc.divide(pc.cast(range_, pa_float), lit(float(bin_count)))
-
-            bin_proportions = pc.divide(pc.subtract(self.native, lower), width)
-            bin_indices = pc.floor(bin_proportions)
-
-            # shift bins so they are right-closed
-            bin_indices = pc.if_else(
-                pc.and_(
-                    pc.equal(bin_indices, bin_proportions),
-                    pc.greater(bin_indices, lit(0)),
-                ),
-                pc.subtract(bin_indices, lit(1)),
-                bin_indices,
-            )
-            possible = pa.Table.from_arrays(
-                [pa.Array.from_pandas(np.arange(bin_count, dtype="int64"))], ["values"]
-            )
-            counts = (  # count bin id occurrences
-                pa.Table.from_arrays(
-                    pc.value_counts(bin_indices).flatten(),
-                    names=["values", "counts"],
-                )
-                # nan values are implicitly dropped in value_counts
-                .filter(~pc.field("values").is_nan())
-                .cast(pa.schema([("values", pa.int64()), ("counts", pa.int64())]))
-                # align bin ids to all possible bin ids (populate in missing bins)
-                .join(possible, keys="values", join_type="right outer")
-                .sort_by("values")
-            )
-            # empty bin intervals should have a 0 count
-            counts_coalesce = cast(
-                "ArrowArray", pc.coalesce(counts.column("counts"), lit(0))
-            )
-            counts = counts.set_column(0, "counts", counts_coalesce)
-
-            # extract left/right side of the intervals
-            bin_left = pc.add(lower, pc.multiply(counts.column("values"), width))
-            bin_right = pc.add(bin_left, width)
-            return counts.column("counts"), bin_right
+                lower -= 0.5
+                upper += 0.5
+            bins = np.linspace(lower, upper, bin_count + 1)
+            return _hist_from_bins(bins)
 
         def _hist_from_bins(bins: Sequence[int | float]):  # type: ignore[no-untyped-def]  # noqa: ANN202
             bin_indices = np.searchsorted(bins, self.native, side="left")
+            bin_indices = pc.if_else(  # lowest bin is inclusive
+                pc.equal(self.native, lit(bins[0])), 1, bin_indices
+            )
+
+            # align unique categories and counts appropriately
             obs_cats, obs_counts = np.unique(bin_indices, return_counts=True)
             obj_cats = np.arange(1, len(bins))
             counts = np.zeros_like(obj_cats)
@@ -1125,15 +1087,51 @@ def _hist_from_bins(bins: Sequence[int | float]): # type: ignore[no-untyped-def
             bin_right = bins[1:]
             return counts, bin_right
 
+        counts: Sequence[int | float] | np.typing.ArrayLike
+        bin_right: Sequence[int | float] | np.typing.ArrayLike
+
+        data_count = pc.sum(
+            pc.invert(pc.or_(pc.is_nan(self.native), pc.is_null(self.native))).cast(
+                pa.uint8()
+            ),
+            min_count=0,
+        )
         if bins is not None:
             if len(bins) < 2:
                 counts, bin_right = [], []
+
+            elif data_count == pa.scalar(0, type=pa.uint64()):  # type:ignore[comparison-overlap]
+                counts = np.zeros(len(bins) - 1)
+                bin_right = bins[1:]
+
+            elif len(bins) == 2:
+                counts = [
+                    pc.sum(
+                        pc.and_(
+                            pc.greater_equal(self.native, lit(float(bins[0]))),
+                            pc.less_equal(self.native, lit(float(bins[1]))),
+                        ).cast(pa.uint8())
+                    )
+                ]
+                bin_right = [bins[-1]]
             else:
                 counts, bin_right = _hist_from_bins(bins)
 
         elif bin_count is not None:
             if bin_count == 0:
                 counts, bin_right = [], []
+            elif data_count == pa.scalar(0, type=pa.uint64()):  # type:ignore[comparison-overlap]
+                counts, bin_right = (
+                    np.zeros(bin_count),
+                    np.linspace(0, 1, bin_count + 1)[1:],
+                )
+            elif bin_count == 1:
+                d = pc.min_max(self.native)
+                lower, upper = d["min"], d["max"]
+                if lower == upper:
+                    counts, bin_right = [data_count], [pc.add(upper, pa.scalar(0.5))]
+                else:
+                    counts, bin_right = [data_count], [upper]
             else:
                 counts, bin_right = _hist_from_bin_count(bin_count)
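One piece of the new Arrow code that is easy to exercise in isolation is the `data_count` fastpath above, which counts observations that are neither null nor NaN so that empty or all-null input can short-circuit before any binning. A standalone sketch of that computation (the `native` column here is an illustrative stand-in for `self.native`):

```python
import pyarrow as pa
import pyarrow.compute as pc

# stand-in for self.native: one valid float, one NaN, one null, one more valid float
native = pa.chunked_array([[1.0, float("nan"), None, 2.5]])

# count values that are neither NaN nor null; min_count=0 yields 0 instead of null
data_count = pc.sum(
    pc.invert(pc.or_(pc.is_nan(native), pc.is_null(native))).cast(pa.uint8()),
    min_count=0,
)
print(data_count.as_py())  # 2

# an empty column takes the zero-count branch added above
empty = pa.chunked_array([pa.array([], type=pa.float64())])
empty_count = pc.sum(
    pc.invert(pc.or_(pc.is_nan(empty), pc.is_null(empty))).cast(pa.uint8()),
    min_count=0,
)
assert empty_count == pa.scalar(0, type=pa.uint64())
```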

narwhals/_pandas_like/series.py

Lines changed: 25 additions & 9 deletions
@@ -966,34 +966,50 @@ def hist(
                 data["breakpoint"] = []
             data["count"] = []
             return PandasLikeDataFrame.from_native(ns.DataFrame(data), context=self)
-        elif self.native.count() < 1:
+
+        if self.native.count() < 1:
             if bins is not None:
                 data = {"breakpoint": bins[1:], "count": zeros(shape=len(bins) - 1)}
             else:
                 count = cast("int", bin_count)
-                data = {"breakpoint": linspace(0, 1, count), "count": zeros(shape=count)}
+                if bin_count == 1:
+                    data = {"breakpoint": [1.0], "count": [0]}
+                else:
+                    data = {
+                        "breakpoint": linspace(0, 1, count + 1)[1:],
+                        "count": zeros(shape=count),
+                    }
             if not include_breakpoint:
                 del data["breakpoint"]
             return PandasLikeDataFrame.from_native(ns.DataFrame(data), context=self)
 
-        elif bin_count is not None:  # use Polars binning behavior
+        if bin_count is not None:
+            # use Polars binning behavior
             lower, upper = self.native.min(), self.native.max()
-            pad_lowest_bin = False
             if lower == upper:
                 lower -= 0.5
                 upper += 0.5
-            else:
-                pad_lowest_bin = True
+
+            if bin_count == 1:
+                data = {
+                    "breakpoint": [upper],
+                    "count": [self.native.count()],
+                }
+                if not include_breakpoint:
+                    del data["breakpoint"]
+                return PandasLikeDataFrame.from_native(ns.DataFrame(data), context=self)
 
             bins = linspace(lower, upper, bin_count + 1)
-            if pad_lowest_bin and bins is not None:
-                bins[0] -= 0.001 * abs(bins[0]) if bins[0] != 0 else 0.001
             bin_count = None
 
         # pandas (2.2.*) .value_counts(bins=int) adjusts the lowest bin twice, result in improper counts.
         # pandas (2.2.*) .value_counts(bins=[...]) adjusts the lowest bin which should not happen since
         # the bins were explicitly passed in.
-        categories = ns.cut(self.native, bins=bins if bin_count is None else bin_count)
+        categories = ns.cut(
+            self.native,
+            bins=bins if bin_count is None else bin_count,
+            include_lowest=True,  # Polars 1.27.0 always includes the lowest bin
+        )
         # modin (0.32.0) .value_counts(...) silently drops bins with empty observations, .reindex
         # is necessary to restore these bins.
         result = categories.value_counts(dropna=True, sort=False).reindex(
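The pandas-like change is essentially one keyword: `include_lowest=True` on `ns.cut`, mirroring how Polars 1.27.0 always includes the lowest bin edge, plus the existing `value_counts(...).reindex(...)` to restore empty bins. A minimal standalone sketch with pandas directly (the data, bin edges, and the `fill_value=0` reindex arguments are illustrative, not the wrapper's exact code):

```python
import numpy as np
import pandas as pd

s = pd.Series([1.0, 1.0, 2.0, 2.5, 3.0])  # stand-in for self.native
bins = np.linspace(1.0, 3.0, 3)           # [1.0, 2.0, 3.0]

# include_lowest=True: Polars 1.27.0 always includes the lowest bin edge
categories = pd.cut(s, bins=bins, include_lowest=True)

# count per interval, then reindex so empty bins reappear with a 0 count
result = categories.value_counts(dropna=True, sort=False).reindex(
    categories.cat.categories, fill_value=0
)
print(result.to_list())  # [3, 2]: both 1.0 values now land in the first bin
```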

narwhals/_polars/series.py

Lines changed: 21 additions & 13 deletions
@@ -499,18 +499,26 @@ def hist(
                 data.append(pl.Series("breakpoint", [], dtype=pl.Float64))
             data.append(pl.Series("count", [], dtype=pl.UInt32))
             return PolarsDataFrame.from_native(pl.DataFrame(data), context=self)
-        elif (self._backend_version < (1, 15)) and self.native.count() < 1:
+
+        if self.native.count() < 1:
             data_dict: dict[str, Sequence[Any] | pl.Series]
             if bins is not None:
                 data_dict = {
                     "breakpoint": bins[1:],
                     "count": pl.zeros(n=len(bins) - 1, dtype=pl.Int64, eager=True),
                 }
-            elif bin_count is not None:
+            elif (bin_count is not None) and bin_count == 1:
+                data_dict = {"breakpoint": [1.0], "count": [0]}
+            elif (bin_count is not None) and bin_count > 1:
                 data_dict = {
-                    "breakpoint": pl.int_range(0, bin_count, eager=True) / bin_count,
+                    "breakpoint": pl.int_range(1, bin_count + 1, eager=True) / bin_count,
                     "count": pl.zeros(n=bin_count, dtype=pl.Int64, eager=True),
                 }
+            else:  # pragma: no cover
+                msg = (
+                    "congratulations, you entered unreachable code - please report a bug"
+                )
+                raise AssertionError(msg)
             if not include_breakpoint:
                 del data_dict["breakpoint"]
             return PolarsDataFrame.from_native(pl.DataFrame(data_dict), context=self)
@@ -519,25 +527,19 @@
         # polars <1.5 with bin_count=...
         # returns bins that range from -inf to +inf and has bin_count + 1 bins.
         # for compat: convert `bin_count=` call to `bins=`
-        if (
-            (self._backend_version < (1, 15))
-            and (bin_count is not None)
-            and (self.native.count() > 0)
+        if (self._backend_version < (1, 15)) and (
+            bin_count is not None
         ):  # pragma: no cover
             lower = cast("float", self.native.min())
             upper = cast("float", self.native.max())
-            pad_lowest_bin = False
             if lower == upper:
                 width = 1 / bin_count
                 lower -= 0.5
                 upper += 0.5
             else:
-                pad_lowest_bin = True
                 width = (upper - lower) / bin_count
 
             bins = (pl.int_range(0, bin_count + 1, eager=True) * width + lower).to_list()
-            if pad_lowest_bin:
-                bins[0] -= 0.001 * abs(bins[0]) if bins[0] != 0 else 0.001
             bin_count = None
 
         # Polars inconsistently handles NaN values when computing histograms
@@ -552,16 +554,22 @@
             include_category=False,
             include_breakpoint=include_breakpoint,
         )
+
         if not include_breakpoint:
             df.columns = ["count"]
 
+        if self._backend_version < (1, 0) and include_breakpoint:
+            df = df.rename({"break_point": "breakpoint"})
+
        # polars<1.15 implicitly adds -inf and inf to either end of bins
        if self._backend_version < (1, 15) and bins is not None:  # pragma: no cover
            r = pl.int_range(0, len(df))
            df = df.filter((r > 0) & (r < len(df) - 1))
 
-        if self._backend_version < (1, 0) and include_breakpoint:
-            df = df.rename({"break_point": "breakpoint"})
+        # polars<1.27 makes the lowest bin a left/right closed interval.
+        if self._backend_version < (1, 27) and bins is not None:
+            df[0, "count"] += (series == bins[0]).sum()
+
         return PolarsDataFrame.from_native(df, context=self)
 
     def to_polars(self: Self) -> pl.Series: