|
19 | 19 | import bigframes.pandas as bpd
|
20 | 20 | from bigframes.testing.utils import assert_pandas_df_equal
|
21 | 21 |
|
# Shared two-level MultiIndex (levels "A" and "B") for the DataFrame.where() tests.
# Level "A" cycles through 0, 1, 2 three times; level "B" walks the letters "a".."i",
# giving nine (A, B) pairs in total.
_MULTI_INDEX = pandas.MultiIndex.from_tuples(
    list(zip([0, 1, 2] * 3, "abcdefghi")),
    names=["A", "B"],
)
| 37 | + |
22 | 38 |
|
23 | 39 | def test_multi_index_from_arrays():
|
24 | 40 | bf_idx = bpd.MultiIndex.from_arrays(
|
@@ -541,6 +557,140 @@ def test_multi_index_dataframe_join_on(scalars_dfs, how):
|
541 | 557 | assert_pandas_df_equal(bf_result, pd_result, ignore_order=True)
|
542 | 558 |
|
543 | 559 |
|
def test_multi_index_dataframe_where_series_cond_none_other(
    scalars_df_index, scalars_pandas_df_index
):
    """Compare where() with a Series condition and no `other` against pandas.

    Both frames are rebuilt from the raw values of the selected fixture
    columns on top of the module-level ``_MULTI_INDEX``; the test also checks
    that the result keeps a MultiIndex.
    """
    selected = ["int64_col", "float64_col"]
    bf_subset = scalars_df_index[selected]
    pd_subset = scalars_pandas_df_index[selected]

    # BigQuery DataFrames side: multi-index frame with a named columns axis.
    bf_frame = bpd.DataFrame(
        bf_subset.values, index=_MULTI_INDEX, columns=bf_subset.columns
    )
    bf_frame.columns.name = "test_name"

    # pandas side: same construction, same columns-axis name.
    pd_frame = pandas.DataFrame(
        pd_subset.values, index=_MULTI_INDEX, columns=pd_subset.columns
    )
    pd_frame.columns.name = "test_name"

    # Series condition, `other` left at its default.
    bf_mask = bf_frame["int64_col"] > 0
    pd_mask = pd_frame["int64_col"] > 0

    bf_result = bf_frame.where(bf_mask).to_pandas()
    pd_result = pd_frame.where(pd_mask)

    loose_checks = {"check_index_type": False, "check_dtype": False}
    pandas.testing.assert_frame_equal(bf_result, pd_result, **loose_checks)

    # The index must still be a MultiIndex after the operation.
    assert isinstance(bf_result.index, pandas.MultiIndex), "Expected a MultiIndex"
    assert isinstance(pd_result.index, pandas.MultiIndex), "Expected a MultiIndex"
| 594 | + |
| 595 | + |
def test_multi_index_dataframe_where_series_cond_dataframe_other(
    scalars_df_index, scalars_pandas_df_index
):
    """Compare where() with a Series condition and a DataFrame `other` against pandas.

    Both frames are rebuilt from the raw values of the selected fixture
    columns on top of the module-level ``_MULTI_INDEX``.
    """
    selected = ["int64_col", "int64_too"]
    bf_subset = scalars_df_index[selected]
    pd_subset = scalars_pandas_df_index[selected]

    # Multi-index frames for both libraries.
    bf_frame = bpd.DataFrame(
        bf_subset.values, index=_MULTI_INDEX, columns=bf_subset.columns
    )
    pd_frame = pandas.DataFrame(
        pd_subset.values, index=_MULTI_INDEX, columns=pd_subset.columns
    )

    # Series condition; replacement values come from a scaled copy of the frame.
    bf_mask = bf_frame["int64_col"] > 1000.0
    pd_mask = pd_frame["int64_col"] > 1000.0
    bf_replacement = bf_frame * 100.0
    pd_replacement = pd_frame * 100.0

    bf_result = bf_frame.where(bf_mask, bf_replacement).to_pandas()
    pd_result = pd_frame.where(pd_mask, pd_replacement)

    loose_checks = {"check_index_type": False, "check_dtype": False}
    pandas.testing.assert_frame_equal(bf_result, pd_result, **loose_checks)
| 627 | + |
| 628 | + |
def test_multi_index_dataframe_where_dataframe_cond_constant_other(
    scalars_df_index, scalars_pandas_df_index
):
    """Compare where() with a DataFrame condition and a scalar `other` against pandas.

    Both frames are rebuilt from the raw values of the selected fixture
    columns on top of the module-level ``_MULTI_INDEX``.
    """
    selected = ["int64_col", "float64_col"]
    bf_subset = scalars_df_index[selected]
    pd_subset = scalars_pandas_df_index[selected]

    # Multi-index frames for both libraries.
    bf_frame = bpd.DataFrame(
        bf_subset.values, index=_MULTI_INDEX, columns=bf_subset.columns
    )
    pd_frame = pandas.DataFrame(
        pd_subset.values, index=_MULTI_INDEX, columns=pd_subset.columns
    )

    # Element-wise DataFrame condition; the constant 0 is the replacement value.
    fill_value = 0
    bf_mask = bf_frame > 0
    pd_mask = pd_frame > 0

    bf_result = bf_frame.where(bf_mask, fill_value).to_pandas()
    pd_result = pd_frame.where(pd_mask, fill_value)

    loose_checks = {"check_index_type": False, "check_dtype": False}
    pandas.testing.assert_frame_equal(bf_result, pd_result, **loose_checks)
| 659 | + |
| 660 | + |
def test_multi_index_dataframe_where_dataframe_cond_dataframe_other(
    scalars_df_index, scalars_pandas_df_index
):
    """Compare where() with a DataFrame condition and a DataFrame `other` against pandas.

    Both frames are rebuilt from the raw values of the selected fixture
    columns on top of the module-level ``_MULTI_INDEX``.
    """
    selected = ["int64_col", "int64_too", "float64_col"]
    bf_subset = scalars_df_index[selected]
    pd_subset = scalars_pandas_df_index[selected]

    # Multi-index frames for both libraries.
    bf_frame = bpd.DataFrame(
        bf_subset.values, index=_MULTI_INDEX, columns=bf_subset.columns
    )
    pd_frame = pandas.DataFrame(
        pd_subset.values, index=_MULTI_INDEX, columns=pd_subset.columns
    )

    # Element-wise DataFrame condition; replacements come from the negated frame.
    bf_mask = bf_frame < 1000.0
    pd_mask = pd_frame < 1000.0
    bf_replacement = bf_frame * -1.0
    pd_replacement = pd_frame * -1.0

    bf_result = bf_frame.where(bf_mask, bf_replacement).to_pandas()
    pd_result = pd_frame.where(pd_mask, pd_replacement)

    loose_checks = {"check_index_type": False, "check_dtype": False}
    pandas.testing.assert_frame_equal(bf_result, pd_result, **loose_checks)
| 692 | + |
| 693 | + |
544 | 694 | @pytest.mark.parametrize(
|
545 | 695 | ("level",),
|
546 | 696 | [
|
|
0 commit comments