1
- from itertools import count
2
1
from pandas .core .common import apply_if_callable
3
- from pandas . api . types import is_list_like
2
+ from typing import Any
4
3
import pandas_flavor as pf
5
4
import pandas as pd
6
-
5
+ from pandas .api .types import is_scalar
6
+ import warnings
7
7
from janitor .utils import check
8
8
9
+ warnings .simplefilter ("always" , DeprecationWarning )
10
+
9
11
10
12
@pf .register_dataframe_method
11
- def case_when (df : pd .DataFrame , * args , column_name : str ) -> pd .DataFrame :
13
+ def case_when (
14
+ df : pd .DataFrame , * args , default : Any = None , column_name : str
15
+ ) -> pd .DataFrame :
12
16
"""
13
17
Create a column based on a condition or multiple conditions.
14
18
@@ -33,8 +37,8 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
33
37
>>> df.case_when(
34
38
... ((df.a == 0) & (df.b != 0)) | (df.c == "wait"), df.a,
35
39
... (df.b == 0) & (df.a == 0), "x",
36
- ... df.c,
37
- ... column_name= "value",
40
+ ... default = df.c,
41
+ ... column_name = "value",
38
42
... )
39
43
a b c value
40
44
0 0 0 6 x
@@ -90,7 +94,7 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
90
94
:param df: A pandas DataFrame.
91
95
:param args: Variable argument of conditions and expected values.
92
96
Takes the form
93
- `condition0`, `value0`, `condition1`, `value1`, ..., `default` .
97
+ `condition0`, `value0`, `condition1`, `value1`, ... .
94
98
`condition` can be a 1-D boolean array, a callable, or a string.
95
99
If `condition` is a callable, it should evaluate
96
100
to a 1-D boolean array. The array should have the same length
@@ -99,84 +103,67 @@ def case_when(df: pd.DataFrame, *args, column_name: str) -> pd.DataFrame:
99
103
`result` can be a scalar, a 1-D array, or a callable.
100
104
If `result` is a callable, it should evaluate to a 1-D array.
101
105
For a 1-D array, it should have the same length as the DataFrame.
102
- The ` default` argument applies if none of `condition0`,
103
- `condition1`, ..., evaluates to `True`.
104
- Value can be a scalar, a callable, or a 1-D array. if `default` is a
105
- callable, it should evaluate to a 1-D array.
106
+ :param default: scalar, 1-D array or callable.
107
+ This is the element inserted in the output
108
+ when all conditions evaluate to False.
109
+ If callable, it should evaluate to a 1-D array.
106
110
The 1-D array should be the same length as the DataFrame.
111
+
107
112
:param column_name: Name of column to assign results to. A new column
108
113
is created, if it does not already exist in the DataFrame.
109
- :raises ValueError: If the condition fails to evaluate.
114
+ :raises ValueError: If a condition/value pair fails to evaluate.
110
115
:returns: A pandas DataFrame.
111
116
"""
112
- conditions , targets , default = _case_when_checks (df , args , column_name )
113
-
114
- if len (conditions ) == 1 :
115
- default = default .mask (conditions [0 ], targets [0 ])
116
- return df .assign (** {column_name : default })
117
-
118
- # ensures value assignment is on a first come basis
119
- conditions = conditions [::- 1 ]
120
- targets = targets [::- 1 ]
121
- for condition , value , index in zip (conditions , targets , count ()):
122
- try :
123
- default = default .mask (condition , value )
124
- # error `feedoff` idea from SO
125
- # https://stackoverflow.com/a/46091127/7175713
126
- except Exception as e :
127
- raise ValueError (
128
- f"condition{ index } and value{ index } failed to evaluate. "
129
- f"Original error message: { e } "
130
- ) from e
131
-
132
- return df .assign (** {column_name : default })
133
-
134
-
135
- def _case_when_checks (df : pd .DataFrame , args , column_name ):
136
- """
137
- Preliminary checks on the case_when function.
138
- """
139
- if len (args ) < 3 :
140
- raise ValueError (
141
- "At least three arguments are required for the `args` parameter."
142
- )
143
- if len (args ) % 2 != 1 :
117
+ # Preliminary checks on the case_when function.
118
+ # The bare minimum checks are done; the remaining checks
119
+ # are done within `pd.Series.mask`.
120
+ check ("column_name" , column_name , [str ])
121
+ len_args = len (args )
122
+ if len_args < 2 :
144
123
raise ValueError (
145
- "It seems the `default` argument is missing from the variable "
146
- "`args` parameter."
124
+ "At least two arguments are required for the `args` parameter"
147
125
)
148
126
149
- check ("column_name" , column_name , [str ])
150
-
151
- * args , default = args
127
+ if len_args % 2 :
128
+ if default is None :
129
+ warnings .warn (
130
+ "The last argument in the variable arguments "
131
+ "has been assigned as the default. "
132
+ "Note however that this will be deprecated "
133
+ "in a future release; use an even number "
134
+ "of boolean conditions and values, "
135
+ "and pass the default argument to the `default` "
136
+ "parameter instead." ,
137
+ DeprecationWarning ,
138
+ stacklevel = 2 ,
139
+ )
140
+ * args , default = args
141
+ else :
142
+ raise ValueError (
143
+ "The number of conditions and values do not match. "
144
+ f"There are { len_args - len_args // 2 } conditions "
145
+ f"and { len_args // 2 } values."
146
+ )
152
147
153
148
booleans = []
154
149
replacements = []
150
+
155
151
for index , value in enumerate (args ):
156
- if index % 2 == 0 :
157
- booleans . append (value )
158
- else :
152
+ if index % 2 :
153
+ if callable (value ):
154
+ value = apply_if_callable ( value , df )
159
155
replacements .append (value )
160
-
161
- conditions = []
162
- for condition in booleans :
163
- if callable (condition ):
164
- condition = apply_if_callable (condition , df )
165
- elif isinstance (condition , str ):
166
- condition = df .eval (condition )
167
- conditions .append (condition )
168
-
169
- targets = []
170
- for replacement in replacements :
171
- if callable (replacement ):
172
- replacement = apply_if_callable (replacement , df )
173
- targets .append (replacement )
156
+ else :
157
+ if callable (value ):
158
+ value = apply_if_callable (value , df )
159
+ elif isinstance (value , str ):
160
+ value = df .eval (value )
161
+ booleans .append (value )
174
162
175
163
if callable (default ):
176
164
default = apply_if_callable (default , df )
177
- if not is_list_like (default ):
165
+ if is_scalar (default ):
178
166
default = pd .Series ([default ]).repeat (len (df ))
179
- default .index = df .index
180
167
if not hasattr (default , "shape" ):
181
168
default = pd .Series ([* default ])
182
169
if isinstance (default , pd .Index ):
@@ -185,14 +172,26 @@ def _case_when_checks(df: pd.DataFrame, args, column_name):
185
172
arr_ndim = default .ndim
186
173
if arr_ndim != 1 :
187
174
raise ValueError (
188
- "The `default` argument should either be a 1-D array, a scalar, "
175
+ "The argument for the `default` parameter "
176
+ "should either be a 1-D array, a scalar, "
189
177
"or a callable that can evaluate to a 1-D array."
190
178
)
191
179
if not isinstance (default , pd .Series ):
192
180
default = pd .Series (default )
193
- if default .size != len (df ):
194
- raise ValueError (
195
- "The length of the `default` argument should be equal to the "
196
- "length of the DataFrame."
197
- )
198
- return conditions , targets , default
181
+ default .index = df .index
182
+ # actual computation
183
+ # reverse the pairs so earlier conditions are applied last and win (first match takes precedence)
184
+ booleans = booleans [::- 1 ]
185
+ replacements = replacements [::- 1 ]
186
+ for index , (condition , value ) in enumerate (zip (booleans , replacements )):
187
+ try :
188
+ default = default .mask (condition , value )
189
+ # exception chaining (`raise ... from`) idea from Stack Overflow:
190
+ # https://stackoverflow.com/a/46091127/7175713
191
+ except Exception as error :
192
+ raise ValueError (
193
+ f"condition{ index } and value{ index } failed to evaluate. "
194
+ f"Original error message: { error } "
195
+ ) from error
196
+
197
+ return df .assign (** {column_name : default })
0 commit comments