Skip to content

Commit d7e085f

Browse files
samedii authored and FelixAbrahamsson committed
improve: concat keeps dataframe
1 parent 57cb8f3 commit d7e085f

File tree

2 files changed

+44
-9
lines changed

2 files changed

+44
-9
lines changed

datastream/dataset.py

Lines changed: 43 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -413,15 +413,28 @@ def concat(datasets: List[Dataset]) -> Dataset[R]:
413413
'''
414414
from_concat_mapping = Dataset.create_from_concat_mapping(datasets)
415415

416-
def get_item(dataframe, index):
417-
dataset_index, inner_index = from_concat_mapping(index)
418-
return datasets[dataset_index][inner_index]
416+
if any([dataset.dataframe is None for dataset in datasets]):
419417

420-
return Dataset(
421-
dataframe=None, # TODO: concat dataframes?
422-
length=sum(map(len, datasets)),
423-
get_item=get_item,
424-
)
418+
def get_item(dataframe, index):
419+
dataset_index, inner_index = from_concat_mapping(index)
420+
return datasets[dataset_index][inner_index]
421+
422+
return Dataset(
423+
dataframe=None,
424+
length=sum(map(len, datasets)),
425+
get_item=get_item,
426+
)
427+
else:
428+
429+
def get_item(dataframe, index):
430+
dataset_index, _ = from_concat_mapping(index)
431+
return datasets[dataset_index].get_item(dataframe, index)
432+
433+
return Dataset(
434+
dataframe=pd.concat([dataset.dataframe for dataset in datasets]),
435+
length=sum(map(len, datasets)),
436+
get_item=get_item,
437+
)
425438

426439
@staticmethod
427440
def create_from_combine_mapping(datasets):
@@ -600,6 +613,28 @@ def test_concat_dataset():
600613
assert dataset[6] == 1
601614

602615

616+
def test_concat_heterogenous_datasets():
617+
dataset1 = Dataset.from_dataframe(
618+
pd.DataFrame(dict(a=[1], b=['a'])).set_index('a'),
619+
)
620+
dataset2 = Dataset.from_dataframe(
621+
pd.DataFrame(dict(a=[1], b=[1], c=[2])).set_index('a'),
622+
)
623+
dataset = (
624+
Dataset.concat([dataset1, dataset2])
625+
.map(lambda row: row['b'])
626+
)
627+
628+
assert list(dataset) == ['a', 1]
629+
630+
dataset_other_functions = Dataset.concat([
631+
dataset1.map(lambda row: row['b']),
632+
dataset2.map(lambda row: row['c']),
633+
])
634+
635+
assert list(dataset_other_functions) == ['a', 2]
636+
637+
603638
def test_zip_dataset():
604639
dataset = Dataset.zip([
605640
Dataset.from_subscriptable(list(range(5))),

datastream/datastream.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -236,7 +236,7 @@ def sample_proportion(
236236
def take(
237237
self: Datastream[T],
238238
n_samples: PositiveInt,
239-
) -> Datastream[T]:
239+
) -> Datastream[T]:
240240
'''
241241
Like :func:`Datastream.sample_proportion` but specify the number of
242242
samples instead of a proportion.

0 commit comments

Comments
 (0)