
Commit 696db49

glemaitre authored and mfeurer committed
[MRG] EHN: inferred row_id_attribute from dataframe to create a dataset (#586)
* EHN: inferred row_id_attribute from dataframe to create a dataset
* reset the index of dataframe after inference
* TST: check the size of the dataset
* PEP8
* TST: check that an error is raised when row_id_attributes is not a known attribute
* DOC: Update the docstring
* PEP8
1 parent 6c75554 commit 696db49
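
The change in a nutshell: `create_dataset` no longer requires `row_id_attribute`; when `data` is a pandas dataframe and the argument is omitted, the named dataframe index is used instead. A minimal usage sketch of that behaviour, assembled from the test fixtures in the diff below (the metadata strings are placeholders and a configured OpenML server is assumed for publishing):

    import pandas as pd
    import openml

    # Dataframe whose index carries the row identifier; naming the index is
    # what triggers the inference added by this commit.
    df = pd.DataFrame({'rnd_str': ['a', 'b', 'c'],
                       'integer': [1, 2, 3],
                       'target': [0, 1, 0]})
    df.index.name = 'index_name'

    dataset = openml.datasets.functions.create_dataset(
        name='Pandas_testing_dataset',
        description='Synthetic dataset created from a Pandas DataFrame',
        creator='OpenML tester',
        contributor=None,
        collection_date='01-01-2018',
        language='English',
        licence='MIT',
        default_target_attribute='target',
        ignore_attribute=None,
        citation='None',
        attributes='auto',
        data=df,
        row_id_attribute=None,  # omitted on purpose: inferred from df.index.name
        format=None,
        version_label='test',
        original_data_url='http://openml.github.io/openml-python',
        paper_url='http://openml.github.io/openml-python',
    )
    print(dataset.row_id_attribute)  # 'index_name'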

File tree

2 files changed: +134 -9 lines changed


openml/datasets/functions.py

Lines changed: 35 additions & 9 deletions
@@ -417,8 +417,9 @@ def attributes_arff_from_df(df):
 def create_dataset(name, description, creator, contributor,
                    collection_date, language,
                    licence, attributes, data,
-                   default_target_attribute, row_id_attribute,
-                   ignore_attribute, citation, format=None,
+                   default_target_attribute,
+                   ignore_attribute, citation,
+                   row_id_attribute=None, format=None,
                    original_data_url=None, paper_url=None,
                    update_comment=None, version_label=None):
     """Create a dataset.
@@ -433,11 +434,6 @@ def create_dataset(name, description, creator, contributor,
         Name of the dataset.
     description : str
         Description of the dataset.
-    format : str, optional
-        Format of the dataset which can be either 'arff' or 'sparse_arff'.
-        By default, the format is automatically inferred.
-        .. deprecated: 0.8
-            ``format`` is deprecated in 0.8 and will be removed in 0.10.
     creator : str
         The person who created the dataset.
     contributor : str
@@ -463,14 +459,25 @@ def create_dataset(name, description, creator, contributor,
     default_target_attribute : str
         The default target attribute, if it exists.
         Can have multiple values, comma separated.
-    row_id_attribute : str
-        The attribute that represents the row-id column, if present in the dataset.
     ignore_attribute : str | list
         Attributes that should be excluded in modelling, such as identifiers and indexes.
     citation : str
         Reference(s) that should be cited when building on this data.
     version_label : str, optional
         Version label provided by user, can be a date, hash, or some other type of id.
+    row_id_attribute : str, optional
+        The attribute that represents the row-id column, if present in the
+        dataset. If ``data`` is a dataframe and ``row_id_attribute`` is not
+        specified, the index of the dataframe will be used as the
+        ``row_id_attribute``. If the name of the index is ``None``, it will
+        be discarded.
+        .. versionadded: 0.8
+            Inference of ``row_id_attribute`` from a dataframe.
+    format : str, optional
+        Format of the dataset which can be either 'arff' or 'sparse_arff'.
+        By default, the format is automatically inferred.
+        .. deprecated: 0.8
+            ``format`` is deprecated in 0.8 and will be removed in 0.10.
     original_data_url : str, optional
         For derived data, the url to the original dataset.
     paper_url : str, optional
@@ -483,6 +490,15 @@ def create_dataset(name, description, creator, contributor,
     class:`openml.OpenMLDataset`
         Dataset description."""

+    if isinstance(data, (pd.DataFrame, pd.SparseDataFrame)):
+        # infer the row id from the index of the dataset
+        if row_id_attribute is None:
+            row_id_attribute = data.index.name
+        # When calling data.values, the index will be skipped. We need to reset
+        # the index such that it is part of the data.
+        if data.index.name is not None:
+            data = data.reset_index()
+
     if attributes == 'auto' or isinstance(attributes, dict):
         if not hasattr(data, "columns"):
             raise ValueError("Automatically inferring the attributes required "
@@ -499,6 +515,16 @@ def create_dataset(name, description, creator, contributor,
     else:
         attributes_ = attributes

+    if row_id_attribute is not None:
+        is_row_id_an_attribute = any([attr[0] == row_id_attribute
+                                      for attr in attributes_])
+        if not is_row_id_an_attribute:
+            raise ValueError(
+                "'row_id_attribute' should be one of the data attribute. "
+                " Got '{}' while candidates are {}."
+                .format(row_id_attribute, [attr[0] for attr in attributes_])
+            )
+
     data = data.values if hasattr(data, "columns") else data

     if format is not None:
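
In isolation, the inference block added above amounts to the following pandas-only sketch (hypothetical column names, no openml calls): an explicit `row_id_attribute` is left untouched, a named index is promoted to a regular column so it survives the later `data.values` call, and the resulting name is then validated against the inferred attribute list (raising ValueError if it is not one of them).

    import pandas as pd

    df = pd.DataFrame({'integer': [1, 2, 3], 'target': [0, 1, 0]})
    df.index.name = 'index_name'

    row_id_attribute = df.index.name   # inferred: 'index_name'
    df = df.reset_index()              # the named index becomes a regular column
    print(list(df.columns))            # ['index_name', 'integer', 'target']
    print(df.values.shape)             # (3, 3): the row id survives the .values call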

tests/test_datasets/test_dataset_functions.py

Lines changed: 99 additions & 0 deletions
@@ -2,6 +2,7 @@
 import os
 import sys
 import random
+from itertools import product
 if sys.version_info[0] >= 3:
     from unittest import mock
 else:
@@ -803,6 +804,104 @@ def test_create_dataset_pandas(self):
         self.assertTrue(
             '@ATTRIBUTE rnd_str {a, b, c, d, e, f, g}' in downloaded_data)

+    def test_create_dataset_row_id_attribute_error(self):
+        # meta-information
+        name = 'Pandas_testing_dataset'
+        description = 'Synthetic dataset created from a Pandas DataFrame'
+        creator = 'OpenML tester'
+        collection_date = '01-01-2018'
+        language = 'English'
+        licence = 'MIT'
+        default_target_attribute = 'target'
+        citation = 'None'
+        original_data_url = 'http://openml.github.io/openml-python'
+        paper_url = 'http://openml.github.io/openml-python'
+        # Check that the index name is well inferred.
+        data = [['a', 1, 0],
+                ['b', 2, 1],
+                ['c', 3, 0],
+                ['d', 4, 1],
+                ['e', 5, 0]]
+        column_names = ['rnd_str', 'integer', 'target']
+        df = pd.DataFrame(data, columns=column_names)
+        # affecting row_id_attribute to an unknown column should raise an error
+        err_msg = ("should be one of the data attribute.")
+        with pytest.raises(ValueError, match=err_msg):
+            openml.datasets.functions.create_dataset(
+                name=name,
+                description=description,
+                creator=creator,
+                contributor=None,
+                collection_date=collection_date,
+                language=language,
+                licence=licence,
+                default_target_attribute=default_target_attribute,
+                ignore_attribute=None,
+                citation=citation,
+                attributes='auto',
+                data=df,
+                row_id_attribute='unknown_row_id',
+                format=None,
+                version_label='test',
+                original_data_url=original_data_url,
+                paper_url=paper_url
+            )
+
+    def test_create_dataset_row_id_attribute_inference(self):
+        # meta-information
+        name = 'Pandas_testing_dataset'
+        description = 'Synthetic dataset created from a Pandas DataFrame'
+        creator = 'OpenML tester'
+        collection_date = '01-01-2018'
+        language = 'English'
+        licence = 'MIT'
+        default_target_attribute = 'target'
+        citation = 'None'
+        original_data_url = 'http://openml.github.io/openml-python'
+        paper_url = 'http://openml.github.io/openml-python'
+        # Check that the index name is well inferred.
+        data = [['a', 1, 0],
+                ['b', 2, 1],
+                ['c', 3, 0],
+                ['d', 4, 1],
+                ['e', 5, 0]]
+        column_names = ['rnd_str', 'integer', 'target']
+        df = pd.DataFrame(data, columns=column_names)
+        row_id_attr = [None, 'integer']
+        df_index_name = [None, 'index_name']
+        expected_row_id = [None, 'index_name', 'integer', 'integer']
+        for output_row_id, (row_id, index_name) in zip(expected_row_id,
+                                                       product(row_id_attr,
+                                                               df_index_name)):
+            df.index.name = index_name
+            dataset = openml.datasets.functions.create_dataset(
+                name=name,
+                description=description,
+                creator=creator,
+                contributor=None,
+                collection_date=collection_date,
+                language=language,
+                licence=licence,
+                default_target_attribute=default_target_attribute,
+                ignore_attribute=None,
+                citation=citation,
+                attributes='auto',
+                data=df,
+                row_id_attribute=row_id,
+                format=None,
+                version_label='test',
+                original_data_url=original_data_url,
+                paper_url=paper_url
+            )
+            self.assertEqual(dataset.row_id_attribute, output_row_id)
+            upload_did = dataset.publish()
+            arff_dataset = arff.loads(_get_online_dataset_arff(upload_did))
+            arff_data = np.array(arff_dataset['data'], dtype=object)
+            # if we set the name of the index then the index will be added to
+            # the data
+            expected_shape = (5, 3) if index_name is None else (5, 4)
+            self.assertEqual(arff_data.shape, expected_shape)
+
     def test_create_dataset_attributes_auto_without_df(self):
         # attributes cannot be inferred without passing a dataframe
         data = np.array([[1, 2, 3],
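
For reference, the zip/product pairing in test_create_dataset_row_id_attribute_inference walks through four cases; spelled out as a standalone sketch (values copied from the test), it shows that an explicit row_id_attribute always wins, an inferred index name is used otherwise, and an unnamed index yields no row id:

    from itertools import product

    row_id_attr = [None, 'integer']
    df_index_name = [None, 'index_name']
    expected_row_id = [None, 'index_name', 'integer', 'integer']

    for expected, (row_id, index_name) in zip(expected_row_id,
                                              product(row_id_attr, df_index_name)):
        # prints: (None, None) -> None, (None, 'index_name') -> 'index_name',
        #         ('integer', None) -> 'integer', ('integer', 'index_name') -> 'integer'
        print(row_id, index_name, '->', expected)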
