Skip to content

Commit a0dd35a

Browse files
committed
Split map into map/starmap and compose functions without starcompose
1 parent c64f15c commit a0dd35a

File tree

2 files changed

+62
-48
lines changed

2 files changed

+62
-48
lines changed

datastream/dataset.py

Lines changed: 46 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
2828
... )
2929
>>> dataset = (
3030
... Dataset.from_subscriptable(fruit_and_cost)
31-
... .map(lambda fruit, cost: (
31+
... .starmap(lambda fruit, cost: (
3232
... fruit,
3333
... cost * 2,
3434
... ))
@@ -39,27 +39,12 @@ class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
3939

4040
dataframe: Optional[pd.DataFrame]
4141
length: int
42-
functions: Tuple[Callable[..., Any], ...]
43-
composed_fn: Callable[[pd.DataFrame, int], T]
42+
get_item: Callable[[pd.DataFrame, int], T]
4443

4544
class Config:
4645
arbitrary_types_allowed = True
4746
allow_mutation = False
4847

49-
def __init__(
50-
self,
51-
dataframe: pd.DataFrame,
52-
length: int,
53-
functions: Tuple[Callable[..., Any], ...],
54-
):
55-
BaseModel.__init__(
56-
self,
57-
dataframe=dataframe,
58-
length=length,
59-
functions=functions,
60-
composed_fn=tools.starcompose(*functions),
61-
)
62-
6348
@staticmethod
6449
def from_subscriptable(subscriptable) -> Dataset:
6550
'''
@@ -83,11 +68,11 @@ def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
8368
return Dataset(
8469
dataframe=dataframe,
8570
length=len(dataframe),
86-
functions=tuple([lambda df, index: df.iloc[index]]),
71+
get_item=lambda df, index: df.iloc[index],
8772
)
8873

8974
def __getitem__(self: Dataset[T], index: int) -> T:
90-
return self.composed_fn(self.dataframe, index)
75+
return self.get_item(self.dataframe, index)
9176

9277
def __len__(self):
9378
return self.length
@@ -115,25 +100,42 @@ def __eq__(self: Dataset[T], other: Dataset[R]) -> bool:
115100
return True
116101

117102
def map(
118-
self: Dataset[T], function: Callable[Union[[T], [...]], R]
103+
self: Dataset[T], function: Callable[[T], R]
119104
) -> Dataset[R]:
120105
'''
121106
Creates a new dataset with the function added to the dataset pipeline.
122-
Returned tuples are expanded as \\*args for the next mapped function.
123107
124108
>>> (
125109
... Dataset.from_subscriptable([1, 2, 3])
126-
... .map(lambda number: (number, number + 1))
127-
... .map(lambda number, plus_one: number + plus_one)
110+
... .map(lambda number: number + 1)
128111
... )[-1]
129-
7
112+
4
130113
'''
131114
return Dataset(
132115
dataframe=self.dataframe,
133116
length=self.length,
134-
functions=self.functions + tuple([function]),
117+
get_item=lambda dataframe, index: function(
118+
self.get_item(dataframe, index)
119+
),
135120
)
136121

122+
def starmap(
123+
self: Dataset[T], function: Callable[..., R]
124+
) -> Dataset[R]:
125+
'''
126+
Creates a new dataset with the function added to the dataset pipeline.
127+
The functions expects iterables that are expanded as \\*args for the
128+
mapped function.
129+
130+
>>> (
131+
... Dataset.from_subscriptable([1, 2, 3])
132+
... .map(lambda number: (number, number + 1))
133+
... .starmap(lambda number, plus_one: number + plus_one)
134+
... )[-1]
135+
7
136+
'''
137+
return self.map(tools.star(function))
138+
137139
def subset(
138140
self, mask_fn: Callable[
139141
[pd.DataFrame], Union[pd.Series, np.array, List[bool]]
@@ -171,7 +173,7 @@ def subset(
171173
return Dataset(
172174
dataframe=self.dataframe.iloc[indices],
173175
length=len(indices),
174-
functions=self.functions,
176+
get_item=self.get_item,
175177
)
176178

177179
def split(
@@ -224,7 +226,7 @@ def split(
224226
split_name: Dataset(
225227
dataframe=dataframe,
226228
length=len(dataframe),
227-
functions=self.functions,
229+
get_item=self.get_item,
228230
)
229231
for split_name, dataframe in split_dataframes(
230232
self.dataframe,
@@ -241,14 +243,13 @@ def zip_index(self: Dataset[T]) -> Dataset[Tuple[T, int]]:
241243
Zip the output with its index. The output of the pipeline will be
242244
a tuple ``(output, index)``.
243245
'''
244-
composed_fn = self.composed_fn
245246
return Dataset(
246247
dataframe=self.dataframe,
247248
length=self.length,
248-
functions=tuple([lambda dataframe, index: (
249-
composed_fn(dataframe, index),
249+
get_item=lambda dataframe, index: (
250+
self.get_item(dataframe, index),
250251
index,
251-
)]),
252+
),
252253
)
253254

254255
@staticmethod
@@ -287,15 +288,14 @@ def concat(datasets: List[Dataset]) -> Dataset[R]:
287288
'''
288289
from_concat_mapping = Dataset.create_from_concat_mapping(datasets)
289290

291+
def get_item(dataframe, index):
292+
dataset_index, inner_index = from_concat_mapping(index)
293+
return datasets[dataset_index][inner_index]
294+
290295
return Dataset(
291-
dataframe=pd.DataFrame(dict(dataset_index=range(len(datasets)))),
296+
dataframe=None, # TODO: concat dataframes?
292297
length=sum(map(len, datasets)),
293-
functions=(
294-
lambda index_dataframe, index: from_concat_mapping(index),
295-
lambda dataset_index, inner_index: (
296-
datasets[dataset_index][inner_index]
297-
),
298-
),
298+
get_item=get_item,
299299
)
300300

301301
@staticmethod
@@ -336,15 +336,17 @@ def combine(datasets: List[Dataset]) -> Dataset[Tuple]:
336336
datasets are often very long and it is expensive to enumerate them.
337337
'''
338338
from_combine_mapping = Dataset.create_from_combine_mapping(datasets)
339+
340+
def get_item(dataframe, index):
341+
indices = from_combine_mapping(index)
342+
return tuple([
343+
dataset[index] for dataset, index in zip(datasets, indices)
344+
])
345+
339346
return Dataset(
340347
dataframe=None,
341348
length=np.prod(list(map(len, datasets))),
342-
functions=(
343-
lambda _, index: from_combine_mapping(index),
344-
lambda *indices: tuple([
345-
dataset[index] for dataset, index in zip(datasets, indices)
346-
]),
347-
),
349+
get_item=get_item,
348350
)
349351

350352
@staticmethod

datastream/datastream.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ def zip(datastreams: List[Datastream]) -> Datastream[Tuple]:
117117
)
118118

119119
def map(
120-
self: Datastream[T], function: Callable[Union[[T], [...]], R]
120+
self: Datastream[T], function: Callable[[T], R]
121121
) -> Datastream[R]:
122122
'''
123123
Creates a new Datastream with a new mapped dataset. See
@@ -128,6 +128,18 @@ def map(
128128
self.sampler,
129129
)
130130

131+
def starmap(
132+
self: Datastream[T], function: Callable[..., R]
133+
) -> Datastream[R]:
134+
'''
135+
Creates a new Datastream with a new starmapped dataset. See
136+
:func:`Dataset.starmap` for details.
137+
'''
138+
return Datastream(
139+
self.dataset.starmap(function),
140+
self.sampler,
141+
)
142+
131143
def data_loader(
132144
self,
133145
n_batches_per_epoch: int = None,
@@ -306,7 +318,7 @@ def test_datastream_simple_weights():
306318
datastream = (
307319
Datastream(dataset)
308320
.zip_index()
309-
.map(lambda integer, index: dict(
321+
.starmap(lambda integer, index: dict(
310322
integer=integer,
311323
index=index,
312324
))
@@ -342,7 +354,7 @@ def test_merge_datastream_weights():
342354
for dataset in datasets
343355
])
344356
.zip_index()
345-
.map(lambda integer, index: dict(
357+
.starmap(lambda integer, index: dict(
346358
integer=integer,
347359
index=index,
348360
))
@@ -371,7 +383,7 @@ def test_multi_sample():
371383
.multi_sample(n_multi_sample)
372384
.sample_proportion(0.5)
373385
.zip_index()
374-
.map(lambda number, index: (number ** 0.5, index))
386+
.starmap(lambda number, index: (number ** 0.5, index))
375387
)
376388

377389
output = [

0 commit comments

Comments
 (0)