@@ -28,7 +28,7 @@ class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
28
28
... )
29
29
>>> dataset = (
30
30
... Dataset.from_subscriptable(fruit_and_cost)
31
- ... .map (lambda fruit, cost: (
31
+ ... .starmap (lambda fruit, cost: (
32
32
... fruit,
33
33
... cost * 2,
34
34
... ))
@@ -39,27 +39,12 @@ class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
39
39
40
40
dataframe : Optional [pd .DataFrame ]
41
41
length : int
42
- functions : Tuple [Callable [..., Any ], ...]
43
- composed_fn : Callable [[pd .DataFrame , int ], T ]
42
+ get_item : Callable [[pd .DataFrame , int ], T ]
44
43
45
44
class Config :
46
45
arbitrary_types_allowed = True
47
46
allow_mutation = False
48
47
49
- def __init__ (
50
- self ,
51
- dataframe : pd .DataFrame ,
52
- length : int ,
53
- functions : Tuple [Callable [..., Any ], ...],
54
- ):
55
- BaseModel .__init__ (
56
- self ,
57
- dataframe = dataframe ,
58
- length = length ,
59
- functions = functions ,
60
- composed_fn = tools .starcompose (* functions ),
61
- )
62
-
63
48
@staticmethod
64
49
def from_subscriptable (subscriptable ) -> Dataset :
65
50
'''
@@ -83,11 +68,11 @@ def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
83
68
return Dataset (
84
69
dataframe = dataframe ,
85
70
length = len (dataframe ),
86
- functions = tuple ([ lambda df , index : df .iloc [index ]]) ,
71
+ get_item = lambda df , index : df .iloc [index ],
87
72
)
88
73
89
74
def __getitem__ (self : Dataset [T ], index : int ) -> T :
90
- return self .composed_fn (self .dataframe , index )
75
+ return self .get_item (self .dataframe , index )
91
76
92
77
def __len__ (self ):
93
78
return self .length
@@ -115,25 +100,42 @@ def __eq__(self: Dataset[T], other: Dataset[R]) -> bool:
115
100
return True
116
101
117
102
def map (
118
- self : Dataset [T ], function : Callable [Union [[ T ], [...] ], R ]
103
+ self : Dataset [T ], function : Callable [[ T ], R ]
119
104
) -> Dataset [R ]:
120
105
'''
121
106
Creates a new dataset with the function added to the dataset pipeline.
122
- Returned tuples are expanded as \\ *args for the next mapped function.
123
107
124
108
>>> (
125
109
... Dataset.from_subscriptable([1, 2, 3])
126
- ... .map(lambda number: (number, number + 1))
127
- ... .map(lambda number, plus_one: number + plus_one)
110
+ ... .map(lambda number: number + 1)
128
111
... )[-1]
129
- 7
112
+ 4
130
113
'''
131
114
return Dataset (
132
115
dataframe = self .dataframe ,
133
116
length = self .length ,
134
- functions = self .functions + tuple ([function ]),
117
+ get_item = lambda dataframe , index : function (
118
+ self .get_item (dataframe , index )
119
+ ),
135
120
)
136
121
122
+ def starmap (
123
+ self : Dataset [T ], function : Callable [Union [..., R ]]
124
+ ) -> Dataset [R ]:
125
+ '''
126
+ Creates a new dataset with the function added to the dataset pipeline.
127
+ The function expects iterables that are expanded as \\ *args for the
128
+ mapped function.
129
+
130
+ >>> (
131
+ ... Dataset.from_subscriptable([1, 2, 3])
132
+ ... .map(lambda number: (number, number + 1))
133
+ ... .starmap(lambda number, plus_one: number + plus_one)
134
+ ... )[-1]
135
+ 7
136
+ '''
137
+ return self .map (tools .star (function ))
138
+
137
139
def subset (
138
140
self , mask_fn : Callable [
139
141
[pd .DataFrame ], Union [pd .Series , np .array , List [bool ]]
@@ -171,7 +173,7 @@ def subset(
171
173
return Dataset (
172
174
dataframe = self .dataframe .iloc [indices ],
173
175
length = len (indices ),
174
- functions = self .functions ,
176
+ get_item = self .get_item ,
175
177
)
176
178
177
179
def split (
@@ -224,7 +226,7 @@ def split(
224
226
split_name : Dataset (
225
227
dataframe = dataframe ,
226
228
length = len (dataframe ),
227
- functions = self .functions ,
229
+ get_item = self .get_item ,
228
230
)
229
231
for split_name , dataframe in split_dataframes (
230
232
self .dataframe ,
@@ -241,14 +243,13 @@ def zip_index(self: Dataset[T]) -> Dataset[Tuple[T, int]]:
241
243
Zip the output with its index. The output of the pipeline will be
242
244
a tuple ``(output, index)``.
243
245
'''
244
- composed_fn = self .composed_fn
245
246
return Dataset (
246
247
dataframe = self .dataframe ,
247
248
length = self .length ,
248
- functions = tuple ([ lambda dataframe , index : (
249
- composed_fn (dataframe , index ),
249
+ get_item = lambda dataframe , index : (
250
+ self . get_item (dataframe , index ),
250
251
index ,
251
- )]) ,
252
+ ),
252
253
)
253
254
254
255
@staticmethod
@@ -287,15 +288,14 @@ def concat(datasets: List[Dataset]) -> Dataset[R]:
287
288
'''
288
289
from_concat_mapping = Dataset .create_from_concat_mapping (datasets )
289
290
291
+ def get_item (dataframe , index ):
292
+ dataset_index , inner_index = from_concat_mapping (index )
293
+ return datasets [dataset_index ][inner_index ]
294
+
290
295
return Dataset (
291
- dataframe = pd . DataFrame ( dict ( dataset_index = range ( len ( datasets )))),
296
+ dataframe = None , # TODO: concat dataframes?
292
297
length = sum (map (len , datasets )),
293
- functions = (
294
- lambda index_dataframe , index : from_concat_mapping (index ),
295
- lambda dataset_index , inner_index : (
296
- datasets [dataset_index ][inner_index ]
297
- ),
298
- ),
298
+ get_item = get_item ,
299
299
)
300
300
301
301
@staticmethod
@@ -336,15 +336,17 @@ def combine(datasets: List[Dataset]) -> Dataset[Tuple]:
336
336
datasets are often very long and it is expensive to enumerate them.
337
337
'''
338
338
from_combine_mapping = Dataset .create_from_combine_mapping (datasets )
339
+
340
+ def get_item (dataframe , index ):
341
+ indices = from_combine_mapping (index )
342
+ return tuple ([
343
+ dataset [index ] for dataset , index in zip (datasets , indices )
344
+ ])
345
+
339
346
return Dataset (
340
347
dataframe = None ,
341
348
length = np .prod (list (map (len , datasets ))),
342
- functions = (
343
- lambda _ , index : from_combine_mapping (index ),
344
- lambda * indices : tuple ([
345
- dataset [index ] for dataset , index in zip (datasets , indices )
346
- ]),
347
- ),
349
+ get_item = get_item ,
348
350
)
349
351
350
352
@staticmethod
0 commit comments