Skip to content

Commit 8aff128

Browse files
authored
feat: add support for the 'right' parameter in 'pandas.cut' (#1496)
1 parent c382a44 commit 8aff128

File tree

5 files changed

+171
-64
lines changed

5 files changed

+171
-64
lines changed

bigframes/core/compile/aggregate_compiler.py

Lines changed: 40 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -364,8 +364,12 @@ def _(
364364

365365
if op.labels is False:
366366
for this_bin in range(op.bins - 1):
367+
if op.right:
368+
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
369+
else:
370+
case_expr = x < (col_min + (this_bin + 1) * bin_width)
367371
out = out.when(
368-
x <= (col_min + (this_bin + 1) * bin_width),
372+
case_expr,
369373
compile_ibis_types.literal_to_ibis_scalar(
370374
this_bin, force_dtype=pd.Int64Dtype()
371375
),
@@ -375,32 +379,49 @@ def _(
375379
interval_struct = None
376380
adj = (col_max - col_min) * 0.001
377381
for this_bin in range(op.bins):
378-
left_edge = (
379-
col_min + this_bin * bin_width - (0 if this_bin > 0 else adj)
380-
)
381-
right_edge = col_min + (this_bin + 1) * bin_width
382-
interval_struct = ibis_types.struct(
383-
{
384-
"left_exclusive": left_edge,
385-
"right_inclusive": right_edge,
386-
}
387-
)
382+
left_edge_adj = adj if this_bin == 0 and op.right else 0
383+
right_edge_adj = adj if this_bin == op.bins - 1 and not op.right else 0
384+
385+
left_edge = col_min + this_bin * bin_width - left_edge_adj
386+
right_edge = col_min + (this_bin + 1) * bin_width + right_edge_adj
387+
388+
if op.right:
389+
interval_struct = ibis_types.struct(
390+
{
391+
"left_exclusive": left_edge,
392+
"right_inclusive": right_edge,
393+
}
394+
)
395+
else:
396+
interval_struct = ibis_types.struct(
397+
{
398+
"left_inclusive": left_edge,
399+
"right_exclusive": right_edge,
400+
}
401+
)
388402

389403
if this_bin < op.bins - 1:
390-
out = out.when(
391-
x <= (col_min + (this_bin + 1) * bin_width),
392-
interval_struct,
393-
)
404+
if op.right:
405+
case_expr = x <= (col_min + (this_bin + 1) * bin_width)
406+
else:
407+
case_expr = x < (col_min + (this_bin + 1) * bin_width)
408+
out = out.when(case_expr, interval_struct)
394409
else:
395410
out = out.when(x.notnull(), interval_struct)
396411
else: # Interpret as intervals
397412
for interval in op.bins:
398413
left = compile_ibis_types.literal_to_ibis_scalar(interval[0])
399414
right = compile_ibis_types.literal_to_ibis_scalar(interval[1])
400-
condition = (x > left) & (x <= right)
401-
interval_struct = ibis_types.struct(
402-
{"left_exclusive": left, "right_inclusive": right}
403-
)
415+
if op.right:
416+
condition = (x > left) & (x <= right)
417+
interval_struct = ibis_types.struct(
418+
{"left_exclusive": left, "right_inclusive": right}
419+
)
420+
else:
421+
condition = (x >= left) & (x < right)
422+
interval_struct = ibis_types.struct(
423+
{"left_inclusive": left, "right_exclusive": right}
424+
)
404425
out = out.when(condition, interval_struct)
405426
return out.end()
406427

bigframes/core/reshape/tile.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
from __future__ import annotations
1616

1717
import typing
18-
from typing import Iterable, Optional, Union
1918

2019
import bigframes_vendored.constants as constants
2120
import bigframes_vendored.pandas.core.reshape.tile as vendored_pandas_tile
@@ -33,26 +32,34 @@
3332

3433
def cut(
3534
x: bigframes.series.Series,
36-
bins: Union[
35+
bins: typing.Union[
3736
int,
3837
pd.IntervalIndex,
39-
Iterable,
38+
typing.Iterable,
4039
],
4140
*,
42-
labels: Union[Iterable[str], bool, None] = None,
41+
right: typing.Optional[bool] = True,
42+
labels: typing.Union[typing.Iterable[str], bool, None] = None,
4343
) -> bigframes.series.Series:
4444
if isinstance(bins, int) and bins <= 0:
4545
raise ValueError("`bins` should be a positive integer.")
4646

47-
if isinstance(bins, Iterable):
47+
# TODO: Check `right` does not apply for IntervalIndex.
48+
49+
if isinstance(bins, typing.Iterable):
4850
if isinstance(bins, pd.IntervalIndex):
51+
# TODO: test an empty interval index
4952
as_index: pd.IntervalIndex = bins
5053
bins = tuple((bin.left.item(), bin.right.item()) for bin in bins)
54+
# To maintain consistency with pandas' behavior
55+
right = True
5156
elif len(list(bins)) == 0:
5257
raise ValueError("`bins` iterable should have at least one item")
5358
elif isinstance(list(bins)[0], tuple):
5459
as_index = pd.IntervalIndex.from_tuples(list(bins))
5560
bins = tuple(bins)
61+
# To maintain consistency with pandas' behavior
62+
right = True
5663
elif pd.api.types.is_number(list(bins)[0]):
5764
bins_list = list(bins)
5865
if len(bins_list) < 2:
@@ -82,7 +89,8 @@ def cut(
8289
)
8390

8491
return x._apply_window_op(
85-
agg_ops.CutOp(bins, labels=labels), window_spec=window_specs.unbound()
92+
agg_ops.CutOp(bins, right=right, labels=labels),
93+
window_spec=window_specs.unbound(),
8694
)
8795

8896

@@ -93,7 +101,7 @@ def qcut(
93101
x: bigframes.series.Series,
94102
q: typing.Union[int, typing.Sequence[float]],
95103
*,
96-
labels: Optional[bool] = None,
104+
labels: typing.Optional[bool] = None,
97105
duplicates: typing.Literal["drop", "error"] = "error",
98106
) -> bigframes.series.Series:
99107
if isinstance(q, int) and q <= 0:

bigframes/operations/aggregations.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
339339
class CutOp(UnaryWindowOp):
340340
# TODO: Unintuitive, refactor into multiple ops?
341341
bins: typing.Union[int, Iterable]
342+
right: Optional[bool]
342343
labels: Optional[bool]
343344

344345
@property
@@ -357,10 +358,19 @@ def output_type(self, *input_types: dtypes.ExpressionType) -> dtypes.ExpressionT
357358
)
358359
pa_type = pa.struct(
359360
[
360-
pa.field("left_exclusive", interval_dtype, nullable=True),
361-
pa.field("right_inclusive", interval_dtype, nullable=True),
361+
pa.field(
362+
"left_exclusive" if self.right else "left_inclusive",
363+
interval_dtype,
364+
nullable=True,
365+
),
366+
pa.field(
367+
"right_inclusive" if self.right else "right_exclusive",
368+
interval_dtype,
369+
nullable=True,
370+
),
362371
]
363372
)
373+
364374
return pd.ArrowDtype(pa_type)
365375

366376
@property

tests/system/small/test_pandas.py

Lines changed: 75 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -387,33 +387,52 @@ def test_merge_series(scalars_dfs, merge_how):
387387
assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
388388

389389

390-
def test_cut(scalars_dfs):
390+
@pytest.mark.parametrize(
391+
("right"),
392+
[
393+
pytest.param(True),
394+
pytest.param(False),
395+
],
396+
)
397+
def test_cut(scalars_dfs, right):
391398
scalars_df, scalars_pandas_df = scalars_dfs
392399

393-
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, labels=False)
394-
bf_result = bpd.cut(scalars_df["float64_col"], 5, labels=False)
400+
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, labels=False, right=right)
401+
bf_result = bpd.cut(scalars_df["float64_col"], 5, labels=False, right=right)
395402

396403
# make sure the result is a supported dtype
397404
assert bf_result.dtype == bpd.Int64Dtype()
398405
pd_result = pd_result.astype("Int64")
399406
pd.testing.assert_series_equal(bf_result.to_pandas(), pd_result)
400407

401408

402-
def test_cut_default_labels(scalars_dfs):
409+
@pytest.mark.parametrize(
410+
("right"),
411+
[
412+
pytest.param(True),
413+
pytest.param(False),
414+
],
415+
)
416+
def test_cut_default_labels(scalars_dfs, right):
403417
scalars_df, scalars_pandas_df = scalars_dfs
404418

405-
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5)
406-
bf_result = bpd.cut(scalars_df["float64_col"], 5).to_pandas()
419+
pd_result = pd.cut(scalars_pandas_df["float64_col"], 5, right=right)
420+
bf_result = bpd.cut(scalars_df["float64_col"], 5, right=right).to_pandas()
407421

408422
# Convert to match data format
423+
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
424+
if pd_interval.closed == "left":
425+
left_key = "left_inclusive"
426+
right_key = "right_exclusive"
427+
else:
428+
left_key = "left_exclusive"
429+
right_key = "right_inclusive"
409430
pd_result_converted = pd.Series(
410431
[
411-
{"left_exclusive": interval.left, "right_inclusive": interval.right}
432+
{left_key: interval.left, right_key: interval.right}
412433
if pd.notna(val)
413434
else pd.NA
414-
for val, interval in zip(
415-
pd_result, pd_result.cat.categories[pd_result.cat.codes]
416-
)
435+
for val, interval in zip(pd_result, pd_interval)
417436
],
418437
name=pd_result.name,
419438
)
@@ -424,28 +443,35 @@ def test_cut_default_labels(scalars_dfs):
424443

425444

426445
@pytest.mark.parametrize(
427-
("breaks",),
446+
("breaks", "right"),
428447
[
429-
([0, 5, 10, 15, 20, 100, 1000],), # ints
430-
([0.5, 10.5, 15.5, 20.5, 100.5, 1000.5],), # floats
431-
([0, 5, 10.5, 15.5, 20, 100, 1000.5],), # mixed
448+
pytest.param([0, 5, 10, 15, 20, 100, 1000], True, id="int_right"),
449+
pytest.param([0, 5, 10, 15, 20, 100, 1000], False, id="int_left"),
450+
pytest.param([0.5, 10.5, 15.5, 20.5, 100.5, 1000.5], False, id="float_left"),
451+
pytest.param([0, 5, 10.5, 15.5, 20, 100, 1000.5], True, id="mixed_right"),
432452
],
433453
)
434-
def test_cut_numeric_breaks(scalars_dfs, breaks):
454+
def test_cut_numeric_breaks(scalars_dfs, breaks, right):
435455
scalars_df, scalars_pandas_df = scalars_dfs
436456

437-
pd_result = pd.cut(scalars_pandas_df["float64_col"], breaks)
438-
bf_result = bpd.cut(scalars_df["float64_col"], breaks).to_pandas()
457+
pd_result = pd.cut(scalars_pandas_df["float64_col"], breaks, right=right)
458+
bf_result = bpd.cut(scalars_df["float64_col"], breaks, right=right).to_pandas()
439459

440460
# Convert to match data format
461+
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
462+
if pd_interval.closed == "left":
463+
left_key = "left_inclusive"
464+
right_key = "right_exclusive"
465+
else:
466+
left_key = "left_exclusive"
467+
right_key = "right_inclusive"
468+
441469
pd_result_converted = pd.Series(
442470
[
443-
{"left_exclusive": interval.left, "right_inclusive": interval.right}
471+
{left_key: interval.left, right_key: interval.right}
444472
if pd.notna(val)
445473
else pd.NA
446-
for val, interval in zip(
447-
pd_result, pd_result.cat.categories[pd_result.cat.codes]
448-
)
474+
for val, interval in zip(pd_result, pd_interval)
449475
],
450476
name=pd_result.name,
451477
)
@@ -476,29 +502,47 @@ def test_cut_errors(scalars_dfs, bins):
476502

477503

478504
@pytest.mark.parametrize(
479-
("bins",),
505+
("bins", "right"),
480506
[
481-
([(-5, 2), (2, 3), (-3000, -10)],),
482-
(pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)]),),
507+
pytest.param([(-5, 2), (2, 3), (-3000, -10)], True, id="tuple_right"),
508+
pytest.param([(-5, 2), (2, 3), (-3000, -10)], False, id="tuple_left"),
509+
pytest.param(
510+
pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)]),
511+
True,
512+
id="interval_right",
513+
),
514+
pytest.param(
515+
pd.IntervalIndex.from_tuples([(1, 2), (2, 3), (4, 5)]),
516+
False,
517+
id="interval_left",
518+
),
483519
],
484520
)
485-
def test_cut_with_interval(scalars_dfs, bins):
521+
def test_cut_with_interval(scalars_dfs, bins, right):
486522
scalars_df, scalars_pandas_df = scalars_dfs
487-
bf_result = bpd.cut(scalars_df["int64_too"], bins, labels=False).to_pandas()
523+
bf_result = bpd.cut(
524+
scalars_df["int64_too"], bins, labels=False, right=right
525+
).to_pandas()
488526

489527
if isinstance(bins, list):
490528
bins = pd.IntervalIndex.from_tuples(bins)
491-
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False)
529+
pd_result = pd.cut(scalars_pandas_df["int64_too"], bins, labels=False, right=right)
492530

493531
# Convert to match data format
532+
pd_interval = pd_result.cat.categories[pd_result.cat.codes]
533+
if pd_interval.closed == "left":
534+
left_key = "left_inclusive"
535+
right_key = "right_exclusive"
536+
else:
537+
left_key = "left_exclusive"
538+
right_key = "right_inclusive"
539+
494540
pd_result_converted = pd.Series(
495541
[
496-
{"left_exclusive": interval.left, "right_inclusive": interval.right}
542+
{left_key: interval.left, right_key: interval.right}
497543
if pd.notna(val)
498544
else pd.NA
499-
for val, interval in zip(
500-
pd_result, pd_result.cat.categories[pd_result.cat.codes]
501-
)
545+
for val, interval in zip(pd_result, pd_interval)
502546
],
503547
name=pd_result.name,
504548
)

0 commit comments

Comments
 (0)