Skip to content

Commit 4aec00a

Browse files
authored
Remove nan-likes from category header (#1037)
* Remove nan-likes from category header Pandas does not accept None/nan as a category (note: of course it does allow nan-values in the data itself). However outside source (i.e. ARFF files) do allow nan as a category, so we must filter these. * Test output of _unpack_categories
1 parent 6c609b8 commit 4aec00a

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

openml/datasets/dataset.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -639,6 +639,11 @@ def _encode_if_category(column):
639639

640640
@staticmethod
641641
def _unpack_categories(series, categories):
642+
# nan-likes can not be explicitly specified as a category
643+
def valid_category(cat):
644+
return isinstance(cat, str) or (cat is not None and not np.isnan(cat))
645+
646+
filtered_categories = [c for c in categories if valid_category(c)]
642647
col = []
643648
for x in series:
644649
try:
@@ -647,7 +652,7 @@ def _unpack_categories(series, categories):
647652
col.append(np.nan)
648653
# We require two lines to create a series of categories as detailed here:
649654
# https://pandas.pydata.org/pandas-docs/version/0.24/user_guide/categorical.html#series-creation # noqa E501
650-
raw_cat = pd.Categorical(col, ordered=True, categories=categories)
655+
raw_cat = pd.Categorical(col, ordered=True, categories=filtered_categories)
651656
return pd.Series(raw_cat, index=series.index, name=series.name)
652657

653658
def get_data(

tests/test_datasets/test_dataset.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,17 @@ def test_init_string_validation(self):
5050
name="somename", description="a description", citation="Something by Müller"
5151
)
5252

53+
def test__unpack_categories_with_nan_likes(self):
54+
# unpack_categories decodes numeric categorical values according to the header
55+
# Containing a 'non' category in the header shouldn't lead to failure.
56+
categories = ["a", "b", None, float("nan"), np.nan]
57+
series = pd.Series([0, 1, None, float("nan"), np.nan, 1, 0])
58+
clean_series = OpenMLDataset._unpack_categories(series, categories)
59+
60+
expected_values = ["a", "b", np.nan, np.nan, np.nan, "b", "a"]
61+
self.assertListEqual(list(clean_series.values), expected_values)
62+
self.assertListEqual(list(clean_series.cat.categories.values), list("ab"))
63+
5364
def test_get_data_array(self):
5465
# Basic usage
5566
rval, _, categorical, attribute_names = self.dataset.get_data(dataset_format="array")

0 commit comments

Comments
 (0)