Skip to content

Commit fc5c84f

Browse files
Add column_names to IterableDataset (#5582)
* Add column_names to IterableDataset (#5383) * Add column_names property * Add multiple tests for this new property * Style --------- Co-authored-by: Mario Šaško <[email protected]>
1 parent 3e62699 commit fc5c84f

File tree

2 files changed

+37
-2
lines changed

2 files changed

+37
-2
lines changed

src/datasets/iterable_dataset.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1374,6 +1374,21 @@ def take(self, n) -> "IterableDataset":
13741374
token_per_repo_id=self._token_per_repo_id,
13751375
)
13761376

1377+
@property
1378+
def column_names(self) -> Optional[List[str]]:
1379+
"""Names of the columns in the dataset.
1380+
1381+
Example:
1382+
1383+
```py
1384+
>>> from datasets import load_dataset
1385+
>>> ds = load_dataset("rotten_tomatoes", split="validation", streaming=True)
1386+
>>> ds.column_names
1387+
['text', 'label']
1388+
```
1389+
"""
1390+
return list(self._info.features.keys()) if self._info.features is not None else None
1391+
13771392
def add_column(self, name: str, column: Union[list, np.array]) -> "IterableDataset":
13781393
"""Add column to Dataset.
13791394

tests/test_iterable_dataset.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,8 @@ def test_iterable_dataset_add_column(dataset_with_several_columns):
966966
assert list(new_dataset) == [
967967
{**example, "new_column": idx} for idx, example in enumerate(dataset_with_several_columns)
968968
]
969+
new_dataset = new_dataset._resolve_features()
970+
assert "new_column" in new_dataset.column_names
969971

970972

971973
def test_iterable_dataset_rename_column(dataset_with_several_columns):
@@ -974,9 +976,13 @@ def test_iterable_dataset_rename_column(dataset_with_several_columns):
974976
{("new_id" if k == "id" else k): v for k, v in example.items()} for example in dataset_with_several_columns
975977
]
976978
assert new_dataset.features is None
979+
assert new_dataset.column_names is None
977980
# rename the column if ds.features was not None
978981
new_dataset = dataset_with_several_columns._resolve_features().rename_column("id", "new_id")
979982
assert new_dataset.features is not None
983+
assert new_dataset.column_names is not None
984+
assert "id" not in new_dataset.column_names
985+
assert "new_id" in new_dataset.column_names
980986

981987

982988
def test_iterable_dataset_rename_columns(dataset_with_several_columns):
@@ -986,9 +992,13 @@ def test_iterable_dataset_rename_columns(dataset_with_several_columns):
986992
{column_mapping.get(k, k): v for k, v in example.items()} for example in dataset_with_several_columns
987993
]
988994
assert new_dataset.features is None
995+
assert new_dataset.column_names is None
989996
# rename the columns if ds.features was not None
990997
new_dataset = dataset_with_several_columns._resolve_features().rename_columns(column_mapping)
991998
assert new_dataset.features is not None
999+
assert new_dataset.column_names is not None
1000+
assert all(c not in new_dataset.column_names for c in ["id", "filepath"])
1001+
assert all(c in new_dataset.column_names for c in ["new_id", "filename"])
9921002

9931003

9941004
def test_iterable_dataset_remove_columns(dataset_with_several_columns):
@@ -1002,10 +1012,13 @@ def test_iterable_dataset_remove_columns(dataset_with_several_columns):
10021012
{k: v for k, v in example.items() if k != "id" and k != "filepath"} for example in dataset_with_several_columns
10031013
]
10041014
assert new_dataset.features is None
1015+
assert new_dataset.column_names is None
10051016
# remove the columns if ds.features was not None
10061017
new_dataset = dataset_with_several_columns._resolve_features().remove_columns(["id", "filepath"])
10071018
assert new_dataset.features is not None
1019+
assert new_dataset.column_names is not None
10081020
assert all(c not in new_dataset.features for c in ["id", "filepath"])
1021+
assert all(c not in new_dataset.column_names for c in ["id", "filepath"])
10091022

10101023

10111024
def test_iterable_dataset_select_columns(dataset_with_several_columns):
@@ -1019,10 +1032,12 @@ def test_iterable_dataset_select_columns(dataset_with_several_columns):
10191032
{k: v for k, v in example.items() if k in ("id", "filepath")} for example in dataset_with_several_columns
10201033
]
10211034
assert new_dataset.features is None
1022-
# remove the columns if ds.features was not None
1035+
# select the columns if ds.features was not None
10231036
new_dataset = dataset_with_several_columns._resolve_features().select_columns(["id", "filepath"])
10241037
assert new_dataset.features is not None
1038+
assert new_dataset.column_names is not None
10251039
assert all(c in new_dataset.features for c in ["id", "filepath"])
1040+
assert all(c in new_dataset.column_names for c in ["id", "filepath"])
10261041

10271042

10281043
def test_iterable_dataset_cast_column():
@@ -1046,12 +1061,16 @@ def test_iterable_dataset_cast():
10461061

10471062
def test_iterable_dataset_resolve_features():
10481063
ex_iterable = ExamplesIterable(generate_examples_fn, {})
1049-
dataset = IterableDataset(ex_iterable)._resolve_features()
1064+
dataset = IterableDataset(ex_iterable)
1065+
assert dataset.features is None
1066+
assert dataset.column_names is None
1067+
dataset = dataset._resolve_features()
10501068
assert dataset.features == Features(
10511069
{
10521070
"id": Value("int64"),
10531071
}
10541072
)
1073+
assert dataset.column_names == ["id"]
10551074

10561075

10571076
def test_iterable_dataset_resolve_features_keep_order():
@@ -1062,6 +1081,7 @@ def gen():
10621081
dataset = IterableDataset(ex_iterable)._resolve_features()
10631082
# columns appear in order of appearance in the dataset
10641083
assert list(dataset.features) == ["a", "c", "b"]
1084+
assert dataset.column_names == ["a", "c", "b"]
10651085

10661086

10671087
def test_iterable_dataset_with_features_fill_with_none():

0 commit comments

Comments
 (0)