Skip to content

Commit 352977c

Browse files
samukwekusammychocoZeroto521root
authored
[ENH] explicit default parameter for case_when (#1165)
* add if_else function as standalone for use in pandas.assign * changelog * Update janitor/functions/case_when.py Co-authored-by: 40% <[email protected]> * Update janitor/functions/case_when.py Co-authored-by: 40% <[email protected]> * changelog * updates based on feedback * if_else addition to docs * drop if_else idea; make default parameter mandatory * changelog * add deprecation warning * single function for checks and computation - separation unnecessary * clean up tests * update tests Co-authored-by: sammychoco <[email protected]> Co-authored-by: 40% <[email protected]> Co-authored-by: root <root@45d364731ba2>
1 parent 27b3201 commit 352977c

File tree

4 files changed

+162
-110
lines changed

4 files changed

+162
-110
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
- [BUG] Avoid `change_type` mutating original `DataFrame`. PR #1162 @Zeroto521
2727
- [ENH] The parameter `column_name` of `change_type` totally supports inputing multi-column now. #1163 @Zeroto521
2828
- [ENH] Fix error when `sort_by_appearance=True` is combined with `dropna=True`. Issue #1168 @samukweku
29+
- [ENH] Add explicit default parameter to `case_when` function. Issue #1159 @samukweku
30+
2931

3032
## [v0.23.1] - 2022-05-03
3133

janitor/functions/case_when.py

Lines changed: 74 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1-
from itertools import count
21
from pandas.core.common import apply_if_callable
3-
from pandas.api.types import is_list_like
2+
from typing import Any
43
import pandas_flavor as pf
54
import pandas as pd
6-
5+
from pandas.api.types import is_scalar
6+
import warnings
77
from janitor.utils import check
88

9+
warnings.simplefilter("always", DeprecationWarning)
10+
911

1012
@pf.register_dataframe_method
11-
def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
13+
def case_when(
14+
df: pd.DataFrame, *args, default: Any = None, column_name: str
15+
) -> pd.DataFrame:
1216
"""
1317
Create a column based on a condition or multiple conditions.
1418
@@ -33,8 +37,8 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
3337
>>> df.case_when(
3438
... ((df.a == 0) & (df.b != 0)) | (df.c == "wait"), df.a,
3539
... (df.b == 0) & (df.a == 0), "x",
36-
... df.c,
37-
... column_name="value",
40+
... default = df.c,
41+
... column_name = "value",
3842
... )
3943
a b c value
4044
0 0 0 6 x
@@ -90,7 +94,7 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
9094
:param df: A pandas DataFrame.
9195
:param args: Variable argument of conditions and expected values.
9296
Takes the form
93-
`condition0`, `value0`, `condition1`, `value1`, ..., `default`.
97+
`condition0`, `value0`, `condition1`, `value1`, ... .
9498
`condition` can be a 1-D boolean array, a callable, or a string.
9599
If `condition` is a callable, it should evaluate
96100
to a 1-D boolean array. The array should have the same length
@@ -99,84 +103,67 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
99103
`result` can be a scalar, a 1-D array, or a callable.
100104
If `result` is a callable, it should evaluate to a 1-D array.
101105
For a 1-D array, it should have the same length as the DataFrame.
102-
The `default` argument applies if none of `condition0`,
103-
`condition1`, ..., evaluates to `True`.
104-
Value can be a scalar, a callable, or a 1-D array. if `default` is a
105-
callable, it should evaluate to a 1-D array.
106+
:param default: scalar, 1-D array or callable.
107+
This is the element inserted in the output
108+
when all conditions evaluate to False.
109+
If callable, it should evaluate to a 1-D array.
106110
The 1-D array should be the same length as the DataFrame.
111+
107112
:param column_name: Name of column to assign results to. A new column
108113
is created, if it does not already exist in the DataFrame.
109-
:raises ValueError: If the condition fails to evaluate.
114+
:raises ValueError: if condition/value fails to evaluate.
110115
:returns: A pandas DataFrame.
111116
"""
112-
conditions, targets, default = _case_when_checks(df, args, column_name)
113-
114-
if len(conditions) == 1:
115-
default = default.mask(conditions[0], targets[0])
116-
return df.assign(**{column_name: default})
117-
118-
# ensures value assignment is on a first come basis
119-
conditions = conditions[::-1]
120-
targets = targets[::-1]
121-
for condition, value, index in zip(conditions, targets, count()):
122-
try:
123-
default = default.mask(condition, value)
124-
# error `feedoff` idea from SO
125-
# https://stackoverflow.com/a/46091127/7175713
126-
except Exception as e:
127-
raise ValueError(
128-
f"condition{index} and value{index} failed to evaluate. "
129-
f"Original error message: {e}"
130-
) from e
131-
132-
return df.assign(**{column_name: default})
133-
134-
135-
def _case_when_checks(df: pd.DataFrame, args, column_name):
136-
"""
137-
Preliminary checks on the case_when function.
138-
"""
139-
if len(args) < 3:
140-
raise ValueError(
141-
"At least three arguments are required for the `args` parameter."
142-
)
143-
if len(args) % 2 != 1:
117+
# Preliminary checks on the case_when function.
118+
# The bare minimum checks are done; the remaining checks
119+
# are done within `pd.Series.mask`.
120+
check("column_name", column_name, [str])
121+
len_args = len(args)
122+
if len_args < 2:
144123
raise ValueError(
145-
"It seems the `default` argument is missing from the variable "
146-
"`args` parameter."
124+
"At least two arguments are required for the `args` parameter"
147125
)
148126

149-
check("column_name", column_name, [str])
150-
151-
*args, default = args
127+
if len_args % 2:
128+
if default is None:
129+
warnings.warn(
130+
"The last argument in the variable arguments "
131+
"has been assigned as the default. "
132+
"Note however that this will be deprecated "
133+
"in a future release; use an even number "
134+
"of boolean conditions and values, "
135+
"and pass the default argument to the `default` "
136+
"parameter instead.",
137+
DeprecationWarning,
138+
stacklevel=2,
139+
)
140+
*args, default = args
141+
else:
142+
raise ValueError(
143+
"The number of conditions and values do not match. "
144+
f"There are {len_args - len_args//2} conditions "
145+
f"and {len_args//2} values."
146+
)
152147

153148
booleans = []
154149
replacements = []
150+
155151
for index, value in enumerate(args):
156-
if index % 2 == 0:
157-
booleans.append(value)
158-
else:
152+
if index % 2:
153+
if callable(value):
154+
value = apply_if_callable(value, df)
159155
replacements.append(value)
160-
161-
conditions = []
162-
for condition in booleans:
163-
if callable(condition):
164-
condition = apply_if_callable(condition, df)
165-
elif isinstance(condition, str):
166-
condition = df.eval(condition)
167-
conditions.append(condition)
168-
169-
targets = []
170-
for replacement in replacements:
171-
if callable(replacement):
172-
replacement = apply_if_callable(replacement, df)
173-
targets.append(replacement)
156+
else:
157+
if callable(value):
158+
value = apply_if_callable(value, df)
159+
elif isinstance(value, str):
160+
value = df.eval(value)
161+
booleans.append(value)
174162

175163
if callable(default):
176164
default = apply_if_callable(default, df)
177-
if not is_list_like(default):
165+
if is_scalar(default):
178166
default = pd.Series([default]).repeat(len(df))
179-
default.index = df.index
180167
if not hasattr(default, "shape"):
181168
default = pd.Series([*default])
182169
if isinstance(default, pd.Index):
@@ -185,14 +172,26 @@ def _case_when_checks(df: pd.DataFrame, args, column_name):
185172
arr_ndim = default.ndim
186173
if arr_ndim != 1:
187174
raise ValueError(
188-
"The `default` argument should either be a 1-D array, a scalar, "
175+
"The argument for the `default` parameter "
176+
"should either be a 1-D array, a scalar, "
189177
"or a callable that can evaluate to a 1-D array."
190178
)
191179
if not isinstance(default, pd.Series):
192180
default = pd.Series(default)
193-
if default.size != len(df):
194-
raise ValueError(
195-
"The length of the `default` argument should be equal to the "
196-
"length of the DataFrame."
197-
)
198-
return conditions, targets, default
181+
default.index = df.index
182+
# actual computation
183+
# ensures value assignment is on a first come basis
184+
booleans = booleans[::-1]
185+
replacements = replacements[::-1]
186+
for index, (condition, value) in enumerate(zip(booleans, replacements)):
187+
try:
188+
default = default.mask(condition, value)
189+
# error `feedoff` idea from SO
190+
# https://stackoverflow.com/a/46091127/7175713
191+
except Exception as error:
192+
raise ValueError(
193+
f"condition{index} and value{index} failed to evaluate. "
194+
f"Original error message: {error}"
195+
) from error
196+
197+
return df.assign(**{column_name: default})

janitor/functions/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from janitor.utils import check_column
2424
import functools
2525

26+
warnings.simplefilter("always", DeprecationWarning)
27+
2628

2729
def unionize_dataframe_categories(
2830
*dataframes, column_names: Optional[Iterable[pd.CategoricalDtype]] = None

0 commit comments

Comments
 (0)