Skip to content

Commit b22a5e3

Browse files
Make Class Label Retrieval More Lenient (#1315)
* mark production tests * make production test run * fix test bug -1/N * add retry raise again after refactor * fix str dict representation * test: Fix non-writable home mocks * testing: not not a change * testing: trigger CI * typing: Update typing * ci: Update testing matrix * testing: Fixup run flow error check * ci: Manual dispatch, disable double testing * ci: Prevent further ci duplication * ci: Add concurrency checks to all * ci: Remove the max-parallel on test ci There are a lot less now and they cancel previous puhes in the same pr now so it shouldn't be a problem anymore * testing: Fix windows path generation * add pytest for server state * add assert cache state * some formatting * fix with cache fixture * finally remove th finally * doc: Fix link * update test matrix * doc: Update to just point to contributing * add linkcheck ignore for test server * add special case for class labels that are dtype string * fix bug and add test * formatting --------- Co-authored-by: eddiebergman <[email protected]>
1 parent 8665b34 commit b22a5e3

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

openml/datasets/dataset.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -908,8 +908,18 @@ def retrieve_class_labels(self, target_name: str = "class") -> None | list[str]:
908908
list
909909
"""
910910
for feature in self.features.values():
911-
if (feature.name == target_name) and (feature.data_type == "nominal"):
912-
return feature.nominal_values
911+
if feature.name == target_name:
912+
if feature.data_type == "nominal":
913+
return feature.nominal_values
914+
915+
if feature.data_type == "string":
916+
# Rel.: #1311
917+
# The target is invalid for a classification task if the feature type is string
918+
# and not nominal. For such miss-configured tasks, we silently fix it here as
919+
# we can safely interpreter string as nominal.
920+
df, *_ = self.get_data()
921+
return list(df[feature.name].unique())
922+
913923
return None
914924

915925
def get_features_by_type( # noqa: C901

tests/test_datasets/test_dataset_functions.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -626,11 +626,18 @@ def test__retrieve_class_labels(self):
626626
openml.config.set_root_cache_directory(self.static_cache_dir)
627627
labels = openml.datasets.get_dataset(2, download_data=False).retrieve_class_labels()
628628
assert labels == ["1", "2", "3", "4", "5", "U"]
629+
629630
labels = openml.datasets.get_dataset(2, download_data=False).retrieve_class_labels(
630631
target_name="product-type",
631632
)
632633
assert labels == ["C", "H", "G"]
633634

635+
# Test workaround for string-typed class labels
636+
custom_ds = openml.datasets.get_dataset(2, download_data=False)
637+
custom_ds.features[31].data_type = "string"
638+
labels = custom_ds.retrieve_class_labels(target_name=custom_ds.features[31].name)
639+
assert labels == ["COIL", "SHEET"]
640+
634641
def test_upload_dataset_with_url(self):
635642
dataset = OpenMLDataset(
636643
"%s-UploadTestWithURL" % self._get_sentinel(),

0 commit comments

Comments
 (0)