@@ -37,7 +37,7 @@ class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
     ('banana', 28)
     '''

-    dataframe: pd.DataFrame
+    dataframe: Optional[pd.DataFrame]
     length: int
     functions: Tuple[Callable[..., Any], ...]
     composed_fn: Callable[[pd.DataFrame, int], T]
@@ -331,14 +331,16 @@ def combine(datasets: List[Dataset]) -> Dataset[Tuple]:
     Zip multiple datasets together so that all combinations of examples
     are possible (i.e. the product) creating tuples like
     ``(example1, example2, ...)``.
+
+    The created dataset will not have a dataframe because combined
+    datasets are often very long and it is expensive to enumerate them.
     '''
     from_combine_mapping = Dataset.create_from_combine_mapping(datasets)
-
     return Dataset(
-        dataframe=pd.DataFrame(dict(dataset_index=range(len(datasets)))),
+        dataframe=None,
         length=np.prod(list(map(len, datasets))),
         functions=(
-            lambda index_dataframe, index: from_combine_mapping(index),
+            lambda _, index: from_combine_mapping(index),
             lambda *indices: tuple([
                 dataset[index] for dataset, index in zip(datasets, indices)
             ]),
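
As a sanity check on the new `combine` docstring, here is a minimal sketch of what "all combinations (i.e. the product)" means. Plain Python sequences stand in for the datasets; the real `Dataset` API is not used, and the exact iteration order of `combine` is not asserted here.

```python
import numpy as np
from itertools import product

# Illustrative stand-ins for two small datasets.
letters = ['a', 'b']
numbers = [1, 2, 3]

# combine's length is the product of the input lengths (np.prod in the diff above).
assert np.prod([len(letters), len(numbers)]) == 6

# Each item of the combined dataset is a tuple drawn from the Cartesian
# product of the inputs, analogous to itertools.product:
assert sorted(product(letters, numbers)) == [
    ('a', 1), ('a', 2), ('a', 3),
    ('b', 1), ('b', 2), ('b', 3),
]
```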
@@ -353,13 +355,31 @@ def zip(datasets: List[Dataset]) -> Dataset[Tuple]:
 
     The length of the created dataset is the minimum length of the zipped
     datasets.
+
+    The created dataset's dataframe is the concatenation of the input
+    datasets' dataframes. They are concatenated over the columns with an
+    added multiindex on the columns, like this:
+    ``pd.concat(dataframes, axis=1, keys=['dataset0', 'dataset1', ...])``
     '''
-    return Dataset(
-        dataframe=pd.DataFrame(dict(dataset_index=range(len(datasets)))),
-        length=min(map(len, datasets)),
-        functions=tuple([lambda index_dataframe, index: tuple(
+    length = min(map(len, datasets))
+    return (
+        Dataset.from_dataframe(
+            pd.concat(
+                [
+                    dataset.dataframe.iloc[:length].reset_index()
+                    for dataset in datasets
+                ],
+                axis=1,
+                keys=[
+                    f'dataset{dataset_index}'
+                    for dataset_index in range(len(datasets))
+                ],
+            ).assign(_index=list(range(length)))
+        )
+        .map(lambda row: row['_index'].iloc[0])
+        .map(lambda index: tuple(
             dataset[index] for dataset in datasets
-        )]),
+        ))
     )


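The `pd.concat` call above is what gives the zipped dataset a meaningful dataframe. A small standalone illustration of that pattern follows; the two input dataframes are made up, only the concat/assign/read-back shape mirrors the diff.

```python
import pandas as pd

# Made-up stand-ins for two datasets' dataframes.
fruit = pd.DataFrame(dict(name=['apple', 'banana', 'cherry'], number=[1, 2, 3]))
price = pd.DataFrame(dict(amount=[10.0, 28.0]))

length = min(len(fruit), len(price))  # zip truncates to the shortest dataset

zipped = pd.concat(
    [df.iloc[:length].reset_index() for df in (fruit, price)],
    axis=1,
    keys=['dataset0', 'dataset1'],
).assign(_index=list(range(length)))

# The columns now form a MultiIndex whose outer level records which dataset
# each column came from, e.g. ('dataset0', 'name') and ('dataset1', 'amount').
assert ('dataset0', 'name') in zipped.columns
assert ('dataset1', 'amount') in zipped.columns

# The added '_index' column is what the first map step reads back per row,
# mirroring `.map(lambda row: row['_index'].iloc[0])` in the diff.
row = zipped.iloc[0]
assert row['_index'].iloc[0] == 0
```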
@@ -420,6 +440,12 @@ def test_zip_dataset():
 
     assert dataset[3] == (3, 3)
 
+    for x, y in zip(
+        dataset.subset(lambda df: np.arange(len(df)) <= 2),
+        dataset,
+    ):
+        assert x == y
+
 
 def test_combine_dataset():
     from itertools import product