Skip to content

Commit c64f15c

Browse files
committed
Zip dataframe
1 parent e5fc0a9 commit c64f15c

File tree

2 files changed

+36
-10
lines changed

2 files changed

+36
-10
lines changed

datastream/dataset.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
3737
('banana', 28)
3838
'''
3939

40-
dataframe: pd.DataFrame
40+
dataframe: Optional[pd.DataFrame]
4141
length: int
4242
functions: Tuple[Callable[..., Any], ...]
4343
composed_fn: Callable[[pd.DataFrame, int], T]
@@ -331,14 +331,16 @@ def combine(datasets: List[Dataset]) -> Dataset[Tuple]:
331331
Zip multiple datasets together so that all combinations of examples
332332
are possible (i.e. the product) creating tuples like
333333
``(example1, example2, ...)``.
334+
335+
The created dataset will not have a dataframe because combined
336+
datasets are often very long and it is expensive to enumerate them.
334337
'''
335338
from_combine_mapping = Dataset.create_from_combine_mapping(datasets)
336-
337339
return Dataset(
338-
dataframe=pd.DataFrame(dict(dataset_index=range(len(datasets)))),
340+
dataframe=None,
339341
length=np.prod(list(map(len, datasets))),
340342
functions=(
341-
lambda index_dataframe, index: from_combine_mapping(index),
343+
lambda _, index: from_combine_mapping(index),
342344
lambda *indices: tuple([
343345
dataset[index] for dataset, index in zip(datasets, indices)
344346
]),
@@ -353,13 +355,31 @@ def zip(datasets: List[Dataset]) -> Dataset[Tuple]:
353355
354356
The length of the created dataset is the minimum length of the zipped
355357
datasets.
358+
359+
The created dataset's dataframe is the concatenation of the input
360+
datasets' dataframes. It is concatenated over columns with an added
361+
multiindex column like this:
362+
``pd.concat(dataframes, axis=1, keys=['dataset0', 'dataset1', ...])``
356363
'''
357-
return Dataset(
358-
dataframe=pd.DataFrame(dict(dataset_index=range(len(datasets)))),
359-
length=min(map(len, datasets)),
360-
functions=tuple([lambda index_dataframe, index: tuple(
364+
length = min(map(len, datasets))
365+
return (
366+
Dataset.from_dataframe(
367+
pd.concat(
368+
[
369+
dataset.dataframe.iloc[:length].reset_index()
370+
for dataset in datasets
371+
],
372+
axis=1,
373+
keys=[
374+
f'dataset{dataset_index}'
375+
for dataset_index in range(len(datasets))
376+
],
377+
).assign(_index=list(range(length)))
378+
)
379+
.map(lambda row: row['_index'].iloc[0])
380+
.map(lambda index: tuple(
361381
dataset[index] for dataset in datasets
362-
)]),
382+
))
363383
)
364384

365385

@@ -420,6 +440,12 @@ def test_zip_dataset():
420440

421441
assert dataset[3] == (3, 3)
422442

443+
for x, y in zip(
444+
dataset.subset(lambda df: np.arange(len(df)) <= 2),
445+
dataset,
446+
):
447+
assert x == y
448+
423449

424450
def test_combine_dataset():
425451
from itertools import product

datastream/datastream.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def test_datastream_zip():
271271

272272
def test_datastream_merge_zip_merge():
273273
'''
274-
repeating because it only sometimes recreated an error that occured
274+
Repeating because it only sometimes recreated an error that occurred
275275
when using mixup/mixmatch
276276
'''
277277

0 commit comments

Comments
 (0)