Skip to content

Commit b9035c4

Browse files
glemaitremfeurer
authored andcommitted
[MRG] EHN: support SparseDataFrame when creating a dataset (#583)
* EHN: support SparseDataFrame when creating a dataset * TST: check attributes inference dtype * PEP8 * EXA: add sparse dataframe in the example * Fix typos. * Fix typo. * Refactoring task.py (#588) * [MRG] EHN: inferred row_id_attribute from dataframe to create a dataset (#586) * EHN: inferred row_id_attribute from dataframe to create a dataset * reset the index of dataframe after inference * TST: check the size of the dataset * PEP8 * TST: check that an error is raised when row_id_attributes is not a known attribute * DOC: Update the docstring * PEP8 * add examples to the menu, remove double progress (#554) * PEP8 * PEP8
1 parent c69b0a6 commit b9035c4

File tree

3 files changed

+100
-4
lines changed

3 files changed

+100
-4
lines changed

examples/create_upload_tutorial.py

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
# * A list
2525
# * A pandas dataframe
2626
# * A sparse matrix
27+
# * A pandas sparse dataframe
2728

2829
############################################################################
2930
# Dataset is a numpy array
@@ -243,7 +244,7 @@
243244

244245
sparse_data = coo_matrix((
245246
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
246-
([0, 1, 1, 2, 2, 3, 3], [0, 1, 2, 0, 2, 0, 1]),
247+
([0, 1, 1, 2, 2, 3, 3], [0, 1, 2, 0, 2, 0, 1])
247248
))
248249

249250
column_names = [
@@ -273,3 +274,38 @@
273274

274275
upload_did = xor_dataset.publish()
275276
print('URL for dataset: %s/data/%d' % (openml.config.server, upload_did))
277+
278+
279+
############################################################################
280+
# Dataset is a pandas sparse dataframe
281+
# ====================================
282+
283+
sparse_data = coo_matrix((
284+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
285+
([0, 1, 1, 2, 2, 3, 3], [0, 1, 2, 0, 2, 0, 1])
286+
))
287+
column_names = ['input1', 'input2', 'y']
288+
df = pd.SparseDataFrame(sparse_data, columns=column_names)
289+
print(df.info())
290+
291+
xor_dataset = create_dataset(
292+
name="XOR",
293+
description='Dataset representing the XOR operation',
294+
creator=None,
295+
contributor=None,
296+
collection_date=None,
297+
language='English',
298+
licence=None,
299+
default_target_attribute='y',
300+
row_id_attribute=None,
301+
ignore_attribute=None,
302+
citation=None,
303+
attributes='auto',
304+
data=df,
305+
version_label='example',
306+
)
307+
308+
############################################################################
309+
310+
upload_did = xor_dataset.publish()
311+
print('URL for dataset: %s/data/%d' % (openml.config.server, upload_did))

openml/datasets/functions.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -502,8 +502,8 @@ def create_dataset(name, description, creator, contributor,
502502
if attributes == 'auto' or isinstance(attributes, dict):
503503
if not hasattr(data, "columns"):
504504
raise ValueError("Automatically inferring the attributes required "
505-
"a pandas DataFrame. A {!r} was given instead."
506-
.format(data))
505+
"a pandas DataFrame or SparseDataFrame. "
506+
"A {!r} was given instead.".format(data))
507507
# infer the type of data for each column of the DataFrame
508508
attributes_ = attributes_arff_from_df(data)
509509
if isinstance(attributes, dict):
@@ -525,7 +525,16 @@ def create_dataset(name, description, creator, contributor,
525525
.format(row_id_attribute, [attr[0] for attr in attributes_])
526526
)
527527

528-
data = data.values if hasattr(data, "columns") else data
528+
if hasattr(data, "columns"):
529+
if isinstance(data, pd.SparseDataFrame):
530+
data = data.to_coo()
531+
# liac-arff only support COO matrices with sorted rows
532+
row_idx_sorted = np.argsort(data.row)
533+
data.row = data.row[row_idx_sorted]
534+
data.col = data.col[row_idx_sorted]
535+
data.data = data.data[row_idx_sorted]
536+
else:
537+
data = data.values
529538

530539
if format is not None:
531540
warn("The format parameter will be deprecated in the future,"

tests/test_datasets/test_dataset_functions.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,7 @@ def test_data_status(self):
411411
self.assertEqual(result[did]['status'], 'active')
412412

413413
def test_attributes_arff_from_df(self):
414+
# DataFrame case
414415
df = pd.DataFrame(
415416
[[1, 1.0, 'xxx', 'A', True], [2, 2.0, 'yyy', 'B', False]],
416417
columns=['integer', 'floating', 'string', 'category', 'boolean']
@@ -422,6 +423,16 @@ def test_attributes_arff_from_df(self):
422423
('string', 'STRING'),
423424
('category', ['A', 'B']),
424425
('boolean', ['True', 'False'])])
426+
# SparseDataFrame case
427+
df = pd.SparseDataFrame([[1, 1.0],
428+
[2, 2.0],
429+
[0, 0]],
430+
columns=['integer', 'floating'],
431+
default_fill_value=0)
432+
df['integer'] = df['integer'].astype(np.int64)
433+
attributes = attributes_arff_from_df(df)
434+
self.assertEqual(attributes, [('integer', 'INTEGER'),
435+
('floating', 'REAL')])
425436

426437
def test_attributes_arff_from_df_mixed_dtype_categories(self):
427438
# liac-arff imposed categorical attributes to be of sting dtype. We
@@ -769,6 +780,46 @@ def test_create_dataset_pandas(self):
769780
"Uploaded ARFF does not match original one"
770781
)
771782

783+
# Check that SparseDataFrame are supported properly
784+
sparse_data = scipy.sparse.coo_matrix((
785+
[0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
786+
([0, 1, 1, 2, 2, 3, 3], [0, 1, 2, 0, 2, 0, 1])
787+
))
788+
column_names = ['input1', 'input2', 'y']
789+
df = pd.SparseDataFrame(sparse_data, columns=column_names)
790+
# meta-information
791+
description = 'Synthetic dataset created from a Pandas SparseDataFrame'
792+
dataset = openml.datasets.functions.create_dataset(
793+
name=name,
794+
description=description,
795+
creator=creator,
796+
contributor=None,
797+
collection_date=collection_date,
798+
language=language,
799+
licence=licence,
800+
default_target_attribute=default_target_attribute,
801+
row_id_attribute=None,
802+
ignore_attribute=None,
803+
citation=citation,
804+
attributes='auto',
805+
data=df,
806+
format=None,
807+
version_label='test',
808+
original_data_url=original_data_url,
809+
paper_url=paper_url
810+
)
811+
upload_did = dataset.publish()
812+
self.assertEqual(
813+
_get_online_dataset_arff(upload_did),
814+
dataset._dataset,
815+
"Uploaded ARFF does not match original one"
816+
)
817+
self.assertEqual(
818+
_get_online_dataset_format(upload_did),
819+
'sparse_arff',
820+
"Wrong format for dataset"
821+
)
822+
772823
# Check that we can overwrite the attributes
773824
data = [['a'], ['b'], ['c'], ['d'], ['e']]
774825
column_names = ['rnd_str']

0 commit comments

Comments
 (0)