Skip to content

Commit a8324dd

Browse files
authored
[ENH] select columns by level (#1111)
* update tests; update logic for slicing when columns are not unique * skeleton for level parameter * skeleton for level parameter * clean up select_columns.py * changelog * avoid copying the dataframe, by making changes on the columns only for level parameter * cleanup _select_column_names * add explanation for level parameter logic * simply if-else logic in select_columns * use set_axis to avoid modifying the original dataframe * reuse variable name * clean up a bit * add more comments * simplify search for strings * fix error message * update comments * return early * fail loudly if search not found * cleanup * cleanup * fail loudly * fix logic for string columns - check the underlying dtype * fix logic for string columns - check the underlying dtype * improve string/categorical column logic * add tests for categorical columns
1 parent d5bea6a commit a8324dd

File tree

5 files changed

+180
-137
lines changed

5 files changed

+180
-137
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
- [DOC] Updated developer guide docs.
66
- [ENH] Allow column selection/renaming within conditional_join. #1102 @samukweku.
77
- [ENH] New decorator `deprecated_kwargs` for breaking API. #1103 @Zeroto521
8-
- [ENH] Extend select_columns to support non-string columns. #1105 @samukweku
8+
- [ENH] Extend select_columns to support non-string columns. Also allow selection on MultiIndex columns via level parameter. #1105 @samukweku
99
- [ENH] Performance improvement for groupby_topk. #1093 @samukweku
1010
- [EHN] `min_max_scale` drop `old_min` and `old_max` to fit sklearn's method API. Issue #1068 @Zeroto521
1111

janitor/functions/select_columns.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
"""Implementation of select_columns"""
2+
from typing import Optional, Union
23
import pandas_flavor as pf
34
import pandas as pd
4-
5-
from janitor.utils import deprecated_alias
5+
from pandas.api.types import is_list_like
6+
from janitor.utils import deprecated_alias, check
67

78
from janitor.functions.utils import _select_column_names
8-
from pandas.api.types import is_list_like
99

1010

1111
@pf.register_dataframe_method
1212
@deprecated_alias(search_cols="search_column_names")
1313
def select_columns(
1414
df: pd.DataFrame,
1515
*args,
16+
level: Optional[Union[int, str]] = None,
1617
invert: bool = False,
1718
) -> pd.DataFrame:
1819
"""
@@ -48,6 +49,8 @@ def select_columns(
4849
a callable which is applicable to each Series in the DataFrame,
4950
or variable arguments of all the aforementioned.
5051
A sequence of booleans is also acceptable.
52+
:param level: Determines which level in the columns should be used for the
53+
column selection.
5154
:param invert: Whether or not to invert the selection.
5255
This will result in the selection of the complement of the columns
5356
provided.
@@ -62,8 +65,26 @@ def select_columns(
6265
search_column_names.extend(arg)
6366
else:
6467
search_column_names.append(arg)
65-
full_column_list = _select_column_names(search_column_names, df)
66-
68+
if level is not None:
69+
# goal here is to capture the original columns
70+
# trim the df.columns to the specified level only,
71+
# and apply the selection (_select_column_names)
72+
# to get the relevant column labels.
73+
# note that no level is dropped; if there are three levels,
74+
# then three levels are returned, with the specified labels
75+
# selected/deselected.
76+
# A copy of the dataframe is made via set_axis,
77+
# to avoid mutating the original dataframe.
78+
df_columns = df.columns
79+
check("level", level, [int, str])
80+
full_column_list = df_columns.get_level_values(level)
81+
full_column_list = _select_column_names(
82+
search_column_names, df.set_axis(full_column_list, axis=1)
83+
)
84+
full_column_list = df_columns.isin(full_column_list, level=level)
85+
full_column_list = df_columns[full_column_list]
86+
else:
87+
full_column_list = _select_column_names(search_column_names, df)
6788
if invert:
6889
return df.drop(columns=full_column_list)
6990
return df.loc[:, full_column_list]

janitor/functions/utils.py

Lines changed: 39 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@
1212
union_categoricals,
1313
is_scalar,
1414
is_list_like,
15+
is_datetime64_dtype,
16+
is_string_dtype,
17+
is_categorical_dtype,
1518
)
1619
import numpy as np
1720
from multipledispatch import dispatch
@@ -212,7 +215,16 @@ def _select_column_names(columns_to_select, df):
212215
"""
213216
if columns_to_select in df.columns:
214217
return [columns_to_select]
215-
raise KeyError(f"No match was returned for '{columns_to_select}'.")
218+
raise KeyError(f"No match was returned for {columns_to_select}.")
219+
220+
221+
def _is_str_or_cat(df_columns):
222+
"""Check if the column is a string or categorical with strings."""
223+
if is_string_dtype(df_columns):
224+
return True
225+
if is_categorical_dtype(df_columns):
226+
return is_string_dtype(df_columns.categories)
227+
return False
216228

217229

218230
@_select_column_names.register(str) # noqa: F811
@@ -221,32 +233,26 @@ def _column_sel_dispatch(columns_to_select, df): # noqa: F811
221233
Base function for column selection.
222234
Applies only to strings.
223235
It is also applicable to shell-like glob strings,
224-
specifically, the `*`.
236+
which are supported by `fnmatch`.
225237
A list/pandas Index of matching column names is returned.
226238
"""
227239
df_columns = df.columns
228-
if pd.api.types.is_string_dtype(df_columns):
229-
if (
230-
"*" in columns_to_select
231-
): # shell-style glob string (e.g., `*_thing_*`)
232-
return fnmatch.filter(df_columns, columns_to_select)
240+
241+
if _is_str_or_cat(df_columns):
233242
if columns_to_select in df_columns:
234243
return [columns_to_select]
235-
raise KeyError(f"No match was returned for '{columns_to_select}'.")
236-
if pd.api.types.is_datetime64_any_dtype(df_columns):
237-
if not df_columns.is_monotonic_increasing:
238-
raise ValueError(
239-
"The column is a DatetimeIndex and should be "
240-
"monotonic increasing."
241-
)
244+
outcome = fnmatch.filter(df_columns, columns_to_select)
245+
if not outcome:
246+
raise KeyError(f"No match was returned for '{columns_to_select}'.")
247+
return outcome
248+
249+
if is_datetime64_dtype(df_columns):
242250
timestamp = df_columns.get_loc(columns_to_select)
243-
if isinstance(timestamp, slice):
251+
if not isinstance(timestamp, int):
244252
return df_columns[timestamp]
245253
return [df_columns[timestamp]]
246-
raise KeyError(
247-
f"String('{columns_to_select}') can be applied "
248-
"only to string/datetime columns."
249-
)
254+
255+
raise KeyError(f"No match was returned for '{columns_to_select}'.")
250256

251257

252258
@_select_column_names.register(re.Pattern) # noqa: F811
@@ -257,15 +263,16 @@ def _column_sel_dispatch(columns_to_select, df): # noqa: F811
257263
`re.compile` is required for the regular expression.
258264
A pandas Index of matching column names is returned.
259265
"""
260-
if pd.api.types.is_string_dtype(df.columns):
261-
bools = df.columns.str.contains(
266+
df_columns = df.columns
267+
268+
if _is_str_or_cat(df_columns):
269+
bools = df_columns.str.contains(
262270
columns_to_select, na=False, regex=True
263271
)
264-
return df.columns[bools]
265-
raise KeyError(
266-
f"Regular expressions('{columns_to_select}') "
267-
"can be applied only to string columns."
268-
)
272+
if not bools.any():
273+
raise KeyError(f"No match was returned for {columns_to_select}.")
274+
return df_columns[bools]
275+
raise KeyError(f"No match was returned for {columns_to_select}.")
269276

270277

271278
@_select_column_names.register(slice) # noqa: F811
@@ -291,13 +298,12 @@ def _column_sel_dispatch(columns_to_select, df): # noqa: F811
291298
step_check = None
292299
method = None
293300

294-
if not df_columns.is_unique:
301+
if not df_columns.is_unique and not df_columns.is_monotonic_increasing:
295302
raise ValueError(
296-
"The column labels are not unique. "
297-
"Kindly ensure the labels are unique "
298-
"to ensure the correct output."
303+
"Non-unique column labels should be monotonic increasing."
299304
)
300-
is_date_column = pd.api.types.is_datetime64_any_dtype(df_columns)
305+
306+
is_date_column = is_datetime64_dtype(df_columns)
301307
if is_date_column:
302308
if not df_columns.is_monotonic_increasing:
303309
raise ValueError(
@@ -377,6 +383,8 @@ def _column_sel_dispatch(columns_to_select, df): # noqa: F811
377383
raise TypeError(
378384
"The output of the applied callable should be a boolean array."
379385
)
386+
if not filtered_columns.any():
387+
raise KeyError(f"No match was returned for {columns_to_select}.")
380388

381389
return df.columns[filtered_columns]
382390

tests/functions/test_select_columns.py

Lines changed: 50 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import pandas as pd
22
import pytest
3+
import re
34
from pandas.testing import assert_frame_equal
45

56

@@ -51,36 +52,6 @@ def test_select_column_names_missing_columns(dataframe, columns):
5152
dataframe.select_columns(columns)
5253

5354

54-
@pytest.mark.functions
55-
@pytest.mark.parametrize(
56-
"columns",
57-
[
58-
pytest.param(
59-
"a",
60-
marks=pytest.mark.xfail(
61-
reason="`select_columns` now accepts strings"
62-
),
63-
),
64-
pytest.param(
65-
("a", "Bell__Chart"),
66-
marks=pytest.mark.xfail(
67-
reason="`select_columns` converts list-like into lists"
68-
),
69-
),
70-
pytest.param(
71-
{"a", "Bell__Chart"},
72-
marks=pytest.mark.xfail(
73-
reason="`select_columns` converts list-like into lists"
74-
),
75-
),
76-
],
77-
)
78-
def test_select_column_names_input(dataframe, columns):
79-
"""Check that passing an iterable that is not a list raises TypeError."""
80-
with pytest.raises(TypeError):
81-
dataframe.select_columns(columns)
82-
83-
8455
@pytest.mark.functions
8556
@pytest.mark.parametrize(
8657
"invert,expected",
@@ -116,20 +87,60 @@ def columns(x):
11687
assert_frame_equal(df, dataframe[expected])
11788

11889

119-
@pytest.mark.xfail(reason="Allow tuples which are acceptable in MultiIndex.")
120-
def test_MultiIndex():
121-
"""
122-
Raise ValueError if columns is a MultiIndex.
123-
"""
124-
df = pd.DataFrame(
90+
@pytest.fixture
91+
def df_tuple():
92+
"pytest fixture."
93+
frame = pd.DataFrame(
12594
{
12695
"A": {0: "a", 1: "b", 2: "c"},
12796
"B": {0: 1, 1: 3, 2: 5},
12897
"C": {0: 2, 1: 4, 2: 6},
12998
}
13099
)
100+
frame.columns = [list("ABC"), list("DEF")]
101+
return frame
102+
103+
104+
def test_multiindex(df_tuple):
105+
"""
106+
Test output for a MultiIndex and tuple passed.
107+
"""
108+
assert_frame_equal(
109+
df_tuple.select_columns(("A", "D")), df_tuple.loc[:, [("A", "D")]]
110+
)
131111

132-
df.columns = [list("ABC"), list("DEF")]
133112

134-
with pytest.raises(ValueError):
135-
df.select_columns("A")
113+
def test_level_callable(df_tuple):
114+
"""
115+
Test output if level is supplied for a callable.
116+
"""
117+
expected = df_tuple.select_columns(
118+
lambda df: df.name.startswith("A"), level=0
119+
)
120+
actual = df_tuple.xs("A", axis=1, drop_level=False, level=0)
121+
assert_frame_equal(actual, expected)
122+
123+
124+
def test_level_regex(df_tuple):
125+
"""
126+
Test output if level is supplied for a regex.
127+
"""
128+
expected = df_tuple.select_columns(re.compile("D"), level=1)
129+
actual = df_tuple.xs("D", axis=1, drop_level=False, level=1)
130+
assert_frame_equal(actual, expected)
131+
132+
133+
def test_level_slice(df_tuple):
134+
"""
135+
Test output if level is supplied for a slice.
136+
"""
137+
expected = df_tuple.select_columns(slice("F", "D"), level=1)
138+
assert_frame_equal(df_tuple, expected)
139+
140+
141+
def test_level_str(df_tuple):
142+
"""
143+
Test output if level is supplied for a string.
144+
"""
145+
expected = df_tuple.select_columns("A", level=0, invert=True)
146+
assert_frame_equal(df_tuple.drop(columns="A", axis=1, level=0), expected)

0 commit comments

Comments
 (0)