Skip to content

Commit 4185afe

Browse files
authored
feat: support multi index for dataframe where (#1881)
* feat: support multi index for dataframe where
* fix test
* fix
* resolve the comments
1 parent e43d15d commit 4185afe

File tree

4 files changed

+165
-15
lines changed

4 files changed

+165
-15
lines changed

bigframes/dataframe.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2741,9 +2741,9 @@ def where(self, cond, other=None):
27412741
if isinstance(other, bigframes.series.Series):
27422742
raise ValueError("Seires is not a supported replacement type!")
27432743

2744-
if self.columns.nlevels > 1 or self.index.nlevels > 1:
2744+
if self.columns.nlevels > 1:
27452745
raise NotImplementedError(
2746-
"The dataframe.where() method does not support multi-index and/or multi-column."
2746+
"The dataframe.where() method does not support multi-column."
27472747
)
27482748

27492749
aligned_block, (_, _) = self._block.join(cond._block, how="left")

tests/system/small/test_dataframe.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -375,15 +375,6 @@ def test_insert(scalars_dfs, loc, column, value, allow_duplicates):
375375
pd.testing.assert_frame_equal(bf_df.to_pandas(), pd_df, check_dtype=False)
376376

377377

378-
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
379-
# Condition is dataframe, other is None (as default).
380-
cond_bf = scalars_df_index["int64_col"] > 0
381-
cond_pd = scalars_pandas_df_index["int64_col"] > 0
382-
bf_result = scalars_df_index.where(cond_bf).to_pandas()
383-
pd_result = scalars_pandas_df_index.where(cond_pd)
384-
pandas.testing.assert_frame_equal(bf_result, pd_result)
385-
386-
387378
def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
388379
cond_bf = scalars_df_index["int64_col"] > 0
389380
cond_pd = scalars_pandas_df_index["int64_col"] > 0
@@ -395,8 +386,8 @@ def test_mask_series_cond(scalars_df_index, scalars_pandas_df_index):
395386
pandas.testing.assert_frame_equal(bf_result, pd_result)
396387

397388

398-
def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
399-
# Test when a dataframe has multi-index or multi-columns.
389+
def test_where_multi_column(scalars_df_index, scalars_pandas_df_index):
390+
# Test when a dataframe has multi-columns.
400391
columns = ["int64_col", "float64_col"]
401392
dataframe_bf = scalars_df_index[columns]
402393

@@ -409,10 +400,19 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
409400
dataframe_bf.where(cond_bf).to_pandas()
410401
assert (
411402
str(context.value)
412-
== "The dataframe.where() method does not support multi-index and/or multi-column."
403+
== "The dataframe.where() method does not support multi-column."
413404
)
414405

415406

407+
def test_where_series_cond(scalars_df_index, scalars_pandas_df_index):
408+
# Condition is a series, other is None (the default).
409+
cond_bf = scalars_df_index["int64_col"] > 0
410+
cond_pd = scalars_pandas_df_index["int64_col"] > 0
411+
bf_result = scalars_df_index.where(cond_bf).to_pandas()
412+
pd_result = scalars_pandas_df_index.where(cond_pd)
413+
pandas.testing.assert_frame_equal(bf_result, pd_result)
414+
415+
416416
def test_where_series_cond_const_other(scalars_df_index, scalars_pandas_df_index):
417417
# Condition is a series, other is a constant.
418418
columns = ["int64_col", "float64_col"]

tests/system/small/test_multiindex.py

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,22 @@
1919
import bigframes.pandas as bpd
2020
from bigframes.testing.utils import assert_pandas_df_equal
2121

22+
# Sample MultiIndex for testing DataFrames where() method.
23+
_MULTI_INDEX = pandas.MultiIndex.from_tuples(
24+
[
25+
(0, "a"),
26+
(1, "b"),
27+
(2, "c"),
28+
(0, "d"),
29+
(1, "e"),
30+
(2, "f"),
31+
(0, "g"),
32+
(1, "h"),
33+
(2, "i"),
34+
],
35+
names=["A", "B"],
36+
)
37+
2238

2339
def test_multi_index_from_arrays():
2440
bf_idx = bpd.MultiIndex.from_arrays(
@@ -541,6 +557,140 @@ def test_multi_index_dataframe_join_on(scalars_dfs, how):
541557
assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
542558

543559

560+
def test_multi_index_dataframe_where_series_cond_none_other(
561+
scalars_df_index, scalars_pandas_df_index
562+
):
563+
columns = ["int64_col", "float64_col"]
564+
565+
# Create multi-index dataframe.
566+
dataframe_bf = bpd.DataFrame(
567+
scalars_df_index[columns].values,
568+
index=_MULTI_INDEX,
569+
columns=scalars_df_index[columns].columns,
570+
)
571+
dataframe_pd = pandas.DataFrame(
572+
scalars_pandas_df_index[columns].values,
573+
index=_MULTI_INDEX,
574+
columns=scalars_pandas_df_index[columns].columns,
575+
)
576+
dataframe_bf.columns.name = "test_name"
577+
dataframe_pd.columns.name = "test_name"
578+
579+
# When condition is series and other is None.
580+
series_cond_bf = dataframe_bf["int64_col"] > 0
581+
series_cond_pd = dataframe_pd["int64_col"] > 0
582+
583+
bf_result = dataframe_bf.where(series_cond_bf).to_pandas()
584+
pd_result = dataframe_pd.where(series_cond_pd)
585+
pandas.testing.assert_frame_equal(
586+
bf_result,
587+
pd_result,
588+
check_index_type=False,
589+
check_dtype=False,
590+
)
591+
# Assert the index is still MultiIndex after the operation.
592+
assert isinstance(bf_result.index, pandas.MultiIndex), "Expected a MultiIndex"
593+
assert isinstance(pd_result.index, pandas.MultiIndex), "Expected a MultiIndex"
594+
595+
596+
def test_multi_index_dataframe_where_series_cond_dataframe_other(
597+
scalars_df_index, scalars_pandas_df_index
598+
):
599+
columns = ["int64_col", "int64_too"]
600+
601+
# Create multi-index dataframe.
602+
dataframe_bf = bpd.DataFrame(
603+
scalars_df_index[columns].values,
604+
index=_MULTI_INDEX,
605+
columns=scalars_df_index[columns].columns,
606+
)
607+
dataframe_pd = pandas.DataFrame(
608+
scalars_pandas_df_index[columns].values,
609+
index=_MULTI_INDEX,
610+
columns=scalars_pandas_df_index[columns].columns,
611+
)
612+
613+
# When condition is series and other is dataframe.
614+
series_cond_bf = dataframe_bf["int64_col"] > 1000.0
615+
series_cond_pd = dataframe_pd["int64_col"] > 1000.0
616+
dataframe_other_bf = dataframe_bf * 100.0
617+
dataframe_other_pd = dataframe_pd * 100.0
618+
619+
bf_result = dataframe_bf.where(series_cond_bf, dataframe_other_bf).to_pandas()
620+
pd_result = dataframe_pd.where(series_cond_pd, dataframe_other_pd)
621+
pandas.testing.assert_frame_equal(
622+
bf_result,
623+
pd_result,
624+
check_index_type=False,
625+
check_dtype=False,
626+
)
627+
628+
629+
def test_multi_index_dataframe_where_dataframe_cond_constant_other(
630+
scalars_df_index, scalars_pandas_df_index
631+
):
632+
columns = ["int64_col", "float64_col"]
633+
634+
# Create multi-index dataframe.
635+
dataframe_bf = bpd.DataFrame(
636+
scalars_df_index[columns].values,
637+
index=_MULTI_INDEX,
638+
columns=scalars_df_index[columns].columns,
639+
)
640+
dataframe_pd = pandas.DataFrame(
641+
scalars_pandas_df_index[columns].values,
642+
index=_MULTI_INDEX,
643+
columns=scalars_pandas_df_index[columns].columns,
644+
)
645+
646+
# When condition is dataframe and other is a constant.
647+
dataframe_cond_bf = dataframe_bf > 0
648+
dataframe_cond_pd = dataframe_pd > 0
649+
other = 0
650+
651+
bf_result = dataframe_bf.where(dataframe_cond_bf, other).to_pandas()
652+
pd_result = dataframe_pd.where(dataframe_cond_pd, other)
653+
pandas.testing.assert_frame_equal(
654+
bf_result,
655+
pd_result,
656+
check_index_type=False,
657+
check_dtype=False,
658+
)
659+
660+
661+
def test_multi_index_dataframe_where_dataframe_cond_dataframe_other(
662+
scalars_df_index, scalars_pandas_df_index
663+
):
664+
columns = ["int64_col", "int64_too", "float64_col"]
665+
666+
# Create multi-index dataframe.
667+
dataframe_bf = bpd.DataFrame(
668+
scalars_df_index[columns].values,
669+
index=_MULTI_INDEX,
670+
columns=scalars_df_index[columns].columns,
671+
)
672+
dataframe_pd = pandas.DataFrame(
673+
scalars_pandas_df_index[columns].values,
674+
index=_MULTI_INDEX,
675+
columns=scalars_pandas_df_index[columns].columns,
676+
)
677+
678+
# When condition is dataframe and other is dataframe.
679+
dataframe_cond_bf = dataframe_bf < 1000.0
680+
dataframe_cond_pd = dataframe_pd < 1000.0
681+
dataframe_other_bf = dataframe_bf * -1.0
682+
dataframe_other_pd = dataframe_pd * -1.0
683+
684+
bf_result = dataframe_bf.where(dataframe_cond_bf, dataframe_other_bf).to_pandas()
685+
pd_result = dataframe_pd.where(dataframe_cond_pd, dataframe_other_pd)
686+
pandas.testing.assert_frame_equal(
687+
bf_result,
688+
pd_result,
689+
check_index_type=False,
690+
check_dtype=False,
691+
)
692+
693+
544694
@pytest.mark.parametrize(
545695
("level",),
546696
[

tests/unit/test_dataframe_polars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ def test_where_series_multi_index(scalars_df_index, scalars_pandas_df_index):
364364
dataframe_bf.where(cond_bf).to_pandas()
365365
assert (
366366
str(context.value)
367-
== "The dataframe.where() method does not support multi-index and/or multi-column."
367+
== "The dataframe.where() method does not support multi-column."
368368
)
369369

370370

0 commit comments

Comments
 (0)