Skip to content

Commit 9dd00c4

Browse files
authored
Lazy column (#7614)
* lazy column * docs * fix tests * fix tests * fix tests * again * again
1 parent 3573d75 commit 9dd00c4

File tree

10 files changed

+149
-97
lines changed

10 files changed

+149
-97
lines changed

docs/source/access.mdx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ You can combine row and column name indexing to return a specific value at a pos
5454
'the rock is destined to be the 21st century\'s new " conan " and that he\'s going to make a splash even greater than arnold schwarzenegger , jean-claud van damme or steven segal .'
5555
```
5656

57-
But it is important to remember that indexing order matters, especially when working with large audio and image datasets. Indexing by the column name returns all the values in the column first, then loads the value at that position. For large datasets, it may be slower to index by the column name first.
57+
Indexing order doesn't matter. Indexing by the column name first returns a [`Column`] object that you can index as usual with row indices as usual:
5858

5959
```py
6060
>>> import time
@@ -69,7 +69,7 @@ Elapsed time: 0.0031 seconds
6969
>>> text = dataset["text"][0]
7070
>>> end_time = time.time()
7171
>>> print(f"Elapsed time: {end_time - start_time:.4f} seconds")
72-
Elapsed time: 0.0094 seconds
72+
Elapsed time: 0.0042 seconds
7373
```
7474

7575
### Slicing

src/datasets/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
__version__ = "3.6.0.dev0"
1616

17-
from .arrow_dataset import Dataset
17+
from .arrow_dataset import Column, Dataset
1818
from .arrow_reader import ReadInstruction
1919
from .builder import ArrowBasedBuilder, BuilderConfig, DatasetBuilder, GeneratorBasedBuilder
2020
from .combine import concatenate_datasets, interleave_datasets
@@ -30,7 +30,7 @@
3030
get_dataset_infos,
3131
get_dataset_split_names,
3232
)
33-
from .iterable_dataset import IterableDataset
33+
from .iterable_dataset import IterableColumn, IterableDataset
3434
from .load import load_dataset, load_dataset_builder, load_from_disk
3535
from .splits import (
3636
NamedSplit,

src/datasets/arrow_dataset.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,48 @@ class NonExistentDatasetError(Exception):
627627
pass
628628

629629

630+
class Column(Sequence_):
631+
"""An iterable for a specific column of an [`Dataset`]."""
632+
633+
def __init__(self, source: Union["Dataset", "Column"], column_name: str):
634+
self.source = source
635+
self.column_name = column_name
636+
if not isinstance(source.features, dict) or column_name not in source.features:
637+
raise ValueError(f"Column '{column_name}' doesn't exist.")
638+
self.features = source.features[column_name]
639+
640+
def __iter__(self) -> Iterator[Any]:
641+
if isinstance(self.source, Dataset):
642+
source = self.source._fast_select_column(self.column_name)
643+
for example in source:
644+
yield example[self.column_name]
645+
646+
def __getitem__(self, key: Union[int, str, list[int]]) -> Any:
647+
if isinstance(key, str):
648+
return Column(self, key)
649+
elif isinstance(self.source, Dataset):
650+
return self.source._fast_select_column(self.column_name)[key][self.column_name]
651+
elif isinstance(key, int):
652+
return self.source[key][self.column_name]
653+
else:
654+
return [item[self.column_name] for item in self.source[key]]
655+
656+
def __len__(self) -> int:
657+
return len(self.source)
658+
659+
def __repr__(self):
660+
return "Column(" + repr(list(self[:5])) + ")"
661+
662+
def __str__(self):
663+
return "Column(" + str(list(self[:5])) + ")"
664+
665+
def __eq__(self, value):
666+
if isinstance(value, Column):
667+
return list(self) == list(value)
668+
else:
669+
return value == list(self)
670+
671+
630672
class Dataset(DatasetInfoMixin, IndexableMixin, TensorflowDatasetMixin):
631673
"""A Dataset backed by an Arrow table."""
632674

@@ -2354,6 +2396,13 @@ def select_columns(self, column_names: Union[str, list[str]], new_fingerprint: O
23542396
dataset._fingerprint = new_fingerprint
23552397
return dataset
23562398

2399+
@transmit_format
2400+
def _fast_select_column(self, column_name: str) -> "Dataset":
2401+
dataset = copy.copy(self)
2402+
dataset._data = dataset._data.select([column_name])
2403+
dataset._info = DatasetInfo(features=Features({column_name: self._info.features[column_name]}))
2404+
return dataset
2405+
23572406
def __len__(self):
23582407
"""Number of rows in the dataset.
23592408
@@ -2776,6 +2825,9 @@ def __getitem__(self, key: str) -> list: # noqa: F811
27762825

27772826
def __getitem__(self, key): # noqa: F811
27782827
"""Can be used to index columns (by string names) or rows (by integer index or iterable of indices or bools)."""
2828+
if isinstance(key, str):
2829+
if self._format_type is None or self._format_type not in ("arrow", "pandas", "polars"):
2830+
return Column(self, key)
27792831
return self._getitem(key)
27802832

27812833
def __getitems__(self, keys: list) -> list:

tests/features/test_array_xd.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,7 +173,7 @@ def get_dict_examples(self, shape_1, shape_2):
173173
}
174174

175175
def _check_getitem_output_type(self, dataset, shape_1, shape_2, first_matrix):
176-
matrix_column = dataset["matrix"]
176+
matrix_column = dataset["matrix"][:]
177177
self.assertIsInstance(matrix_column, list)
178178
self.assertIsInstance(matrix_column[0], list)
179179
self.assertIsInstance(matrix_column[0][0], list)
@@ -192,7 +192,7 @@ def _check_getitem_output_type(self, dataset, shape_1, shape_2, first_matrix):
192192
self.assertTupleEqual(np.array(matrix_field_of_first_two_examples).shape, (2, *shape_2))
193193

194194
with dataset.formatted_as("numpy"):
195-
self.assertTupleEqual(dataset["matrix"].shape, (2, *shape_2))
195+
self.assertTupleEqual(dataset["matrix"][:].shape, (2, *shape_2))
196196
self.assertEqual(dataset[0]["matrix"].shape, shape_2)
197197
self.assertTupleEqual(dataset[:2]["matrix"].shape, (2, *shape_2))
198198

tests/features/test_audio.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import pyarrow as pa
55
import pytest
66

7-
from datasets import Dataset, concatenate_datasets, load_dataset
7+
from datasets import Column, Dataset, concatenate_datasets, load_dataset
88
from datasets.features import Audio, Features, Sequence, Value
99

1010
from ..utils import (
@@ -292,7 +292,7 @@ def test_dataset_with_audio_feature_with_none():
292292
assert isinstance(batch["audio"], list) and all(item is None for item in batch["audio"])
293293
column = dset["audio"]
294294
assert len(column) == 1
295-
assert isinstance(column, list) and all(item is None for item in column)
295+
assert isinstance(column, Column) and all(item is None for item in column)
296296

297297
# nested tests
298298

tests/features/test_features.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pytest
1010

1111
from datasets import Array2D
12-
from datasets.arrow_dataset import Dataset
12+
from datasets.arrow_dataset import Column, Dataset
1313
from datasets.features import Audio, ClassLabel, Features, Image, LargeList, Sequence, Value
1414
from datasets.features.features import (
1515
_align_features,
@@ -492,7 +492,7 @@ def test_dataset_feature_with_none(feature):
492492
assert isinstance(batch["col"], list) and all(item is None for item in batch["col"])
493493
column = dset["col"]
494494
assert len(column) == 1
495-
assert isinstance(column, list) and all(item is None for item in column)
495+
assert isinstance(column, Column) and all(item is None for item in column)
496496

497497
# nested tests
498498

tests/features/test_image.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import pyarrow as pa
1010
import pytest
1111

12-
from datasets import Dataset, Features, Image, Sequence, Value, concatenate_datasets, load_dataset
12+
from datasets import Column, Dataset, Features, Image, Sequence, Value, concatenate_datasets, load_dataset
1313
from datasets.features.image import encode_np_array, image_to_bytes
1414

1515
from ..utils import require_pil
@@ -149,7 +149,7 @@ def test_dataset_with_image_feature(shared_datadir):
149149
assert batch["image"][0].mode == "RGB"
150150
column = dset["image"]
151151
assert len(column) == 1
152-
assert isinstance(column, list) and all(isinstance(item, PIL.Image.Image) for item in column)
152+
assert isinstance(column, Column) and all(isinstance(item, PIL.Image.Image) for item in column)
153153
assert os.path.samefile(column[0].filename, image_path)
154154
assert column[0].format == "JPEG"
155155
assert column[0].size == (640, 480)
@@ -182,7 +182,7 @@ def test_dataset_with_image_feature_from_pil_image(infer_feature, shared_datadir
182182
assert batch["image"][0].mode == "RGB"
183183
column = dset["image"]
184184
assert len(column) == 1
185-
assert isinstance(column, list) and all(isinstance(item, PIL.Image.Image) for item in column)
185+
assert isinstance(column, Column) and all(isinstance(item, PIL.Image.Image) for item in column)
186186
assert os.path.samefile(column[0].filename, image_path)
187187
assert column[0].format == "JPEG"
188188
assert column[0].size == (640, 480)
@@ -215,7 +215,7 @@ def test_dataset_with_image_feature_from_np_array():
215215
assert batch["image"][0].size == (640, 480)
216216
column = dset["image"]
217217
assert len(column) == 1
218-
assert isinstance(column, list) and all(isinstance(item, PIL.Image.Image) for item in column)
218+
assert isinstance(column, Column) and all(isinstance(item, PIL.Image.Image) for item in column)
219219
np.testing.assert_array_equal(np.array(column[0]), image_array)
220220
assert column[0].filename == ""
221221
assert column[0].format in ["PNG", "TIFF"]
@@ -250,7 +250,7 @@ def test_dataset_with_image_feature_tar_jpg(tar_jpg_path):
250250
assert batch["image"][0].mode == "RGB"
251251
column = dset["image"]
252252
assert len(column) == 1
253-
assert isinstance(column, list) and all(isinstance(item, PIL.Image.Image) for item in column)
253+
assert isinstance(column, Column) and all(isinstance(item, PIL.Image.Image) for item in column)
254254
assert column[0].filename == ""
255255
assert column[0].format == "JPEG"
256256
assert column[0].size == (640, 480)
@@ -271,7 +271,7 @@ def test_dataset_with_image_feature_with_none():
271271
assert isinstance(batch["image"], list) and all(item is None for item in batch["image"])
272272
column = dset["image"]
273273
assert len(column) == 1
274-
assert isinstance(column, list) and all(item is None for item in column)
274+
assert isinstance(column, Column) and all(item is None for item in column)
275275

276276
# nested tests
277277

@@ -527,8 +527,8 @@ def test_formatted_dataset_with_image_feature(shared_datadir):
527527
assert batch["image"].shape == (1, 480, 640, 3)
528528
column = dset["image"]
529529
assert len(column) == 2
530-
assert isinstance(column, np.ndarray)
531-
assert column.shape == (2, 480, 640, 3)
530+
assert isinstance(column[:], np.ndarray)
531+
assert column[:].shape == (2, 480, 640, 3)
532532

533533
with dset.formatted_as("pandas"):
534534
item = dset[0]

tests/features/test_video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from datasets import Dataset, Features, Video
3+
from datasets import Column, Dataset, Features, Video
44

55
from ..utils import require_torchvision
66

@@ -53,7 +53,7 @@ def test_dataset_with_video_feature(shared_datadir):
5353
assert isinstance(next(batch["video"][0])["data"], torch.Tensor)
5454
column = dset["video"]
5555
assert len(column) == 1
56-
assert isinstance(column, list) and all(isinstance(item, VideoReader) for item in column)
56+
assert isinstance(column, Column) and all(isinstance(item, VideoReader) for item in column)
5757
assert next(column[0])["data"].shape == (3, 50, 66)
5858
assert isinstance(next(column[0])["data"], torch.Tensor)
5959

0 commit comments

Comments
 (0)