Skip to content

Commit 5c2a2c6

Browse files
rey-esptswast
andauthored
feat: allow case_when to change dtypes if case list contains the condition (True, some_default_value) (#1311)
* feat: support forecast_limit_lower_bound and forecast_limit_upper_bound in ARIMA_PLUS (and ARIMA_PLUS_XREG) models * update doc string * feat: allow case_when to change dtypes if case list contains the condition True * revert bigframes/ml/forecasting.py * revert bigframes/ml/utils.py * revert tests/system/large/ml/test_forecasting.py * Update third_party/bigframes_vendored/pandas/core/series.py Co-authored-by: Tim Sweña (Swast) <[email protected]> * Update third_party/bigframes_vendored/pandas/core/series.py * Update bigframes/series.py * Update tests/system/small/test_series.py --------- Co-authored-by: Tim Sweña (Swast) <[email protected]>
1 parent 16b357e commit 5c2a2c6

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

bigframes/series.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,19 @@ def between(self, left, right, inclusive="both"):
483483
)
484484

485485
def case_when(self, caselist) -> Series:
486-
cases = list(itertools.chain(*caselist, (True, self)))
486+
cases = []
487+
488+
for condition, output in itertools.chain(caselist, [(True, self)]):
489+
cases.append(condition)
490+
cases.append(output)
491+
# In pandas, the default value if no case matches is the original value.
492+
# This makes it impossible to change the type of the column, but if
493+
# the condition is always True, we know it will match and no subsequent
494+
# conditions matter (including the fallback to `self`). This break allows
495+
# the type to change (see: internal issue 349926559).
496+
if condition is True:
497+
break
498+
487499
return self._apply_nary_op(
488500
ops.case_when_op,
489501
cases,

tests/system/small/test_series.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2862,6 +2862,42 @@ def test_series_case_when(scalars_dfs_maybe_ordered):
28622862
)
28632863

28642864

2865+
def test_series_case_when_change_type(scalars_dfs_maybe_ordered):
2866+
pytest.importorskip(
2867+
"pandas",
2868+
minversion="2.2.0",
2869+
reason="case_when added in pandas 2.2.0",
2870+
)
2871+
scalars_df, scalars_pandas_df = scalars_dfs_maybe_ordered
2872+
2873+
bf_series = scalars_df["int64_col"]
2874+
pd_series = scalars_pandas_df["int64_col"]
2875+
2876+
# TODO(tswast): pandas case_when appears to assume True when a value is
2877+
# null. I suspect this should be considered a bug in pandas.
2878+
2879+
bf_conditions = [
2880+
((bf_series > 645).fillna(True), scalars_df["string_col"]),
2881+
((bf_series <= -100).fillna(True), pd.NA),
2882+
(True, "not_found"),
2883+
]
2884+
2885+
pd_conditions = [
2886+
((pd_series > 645).fillna(True), scalars_pandas_df["string_col"]),
2887+
((pd_series <= -100).fillna(True), pd.NA),
2888+
# pandas currently fails if both the condition and the value are literals.
2889+
([True] * len(pd_series), ["not_found"] * len(pd_series)),
2890+
]
2891+
2892+
bf_result = bf_series.case_when(bf_conditions).to_pandas()
2893+
pd_result = pd_series.case_when(pd_conditions)
2894+
2895+
pd.testing.assert_series_equal(
2896+
bf_result,
2897+
pd_result.astype("string[pyarrow]"),
2898+
)
2899+
2900+
28652901
def test_to_frame(scalars_dfs):
28662902
scalars_df, scalars_pandas_df = scalars_dfs
28672903

third_party/bigframes_vendored/pandas/core/series.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2648,6 +2648,21 @@ def case_when(
26482648
3 2
26492649
Name: c, dtype: Int64
26502650
2651+
If you'd like to change the type, add a case with the condition True at the end of the case list
2652+
2653+
>>> c.case_when(
2654+
... caselist=[
2655+
... (a.gt(0), 'a'), # condition, replacement
2656+
... (b.gt(0), 'b'),
2657+
... (True, 'c'),
2658+
... ]
2659+
... )
2660+
0 c
2661+
1 b
2662+
2 a
2663+
3 a
2664+
Name: c, dtype: string
2665+
26512666
**See also:**
26522667
26532668
- :func:`bigframes.pandas.Series.mask` : Replace values where the condition is True.

0 commit comments

Comments
 (0)