 class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
     '''
     A ``Dataset[T]`` is a mapping that allows pipelining of functions in a
-    readable syntax returning an item of type ``T``.
+    readable syntax returning an example of type ``T``.
 
     >>> from datastream import Dataset
     >>> fruit_and_cost = (
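
The class doctest is cut off at the hunk boundary; below is a minimal sketch of the pipelining style it illustrates. The data values are assumptions, not the contents of the truncated doctest:

    from datastream import Dataset

    # Hypothetical data standing in for the truncated doctest.
    fruit_and_cost = Dataset.from_subscriptable([('apple', 5), ('pear', 7)])

    # Each step returns a new Dataset; the pipelined functions are applied
    # when an example is indexed.
    cost = fruit_and_cost.starmap(lambda fruit, cost: cost)
    assert cost[0] == 5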
@@ -64,14 +64,28 @@ def from_subscriptable(subscriptable) -> Dataset:
 
     @staticmethod
     def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
-        '''Create ``Dataset`` based on ``pandas.DataFrame``.'''
+        '''
+        Create ``Dataset`` based on ``pandas.DataFrame``.
+        :func:`Dataset.__getitem__` will return a row from the dataframe and
+        :func:`Dataset.map` should be given a function that takes a row from
+        the dataframe as input.
+
+        >>> (
+        ...     Dataset.from_dataframe(pd.DataFrame(dict(
+        ...         number=[1, 2, 3]
+        ...     )))
+        ...     .map(lambda row: row['number'] + 1)
+        ... )[-1]
+        4
+        '''
         return Dataset(
             dataframe=dataframe,
             length=len(dataframe),
             get_item=lambda df, index: df.iloc[index],
         )
 
     def __getitem__(self: Dataset[T], index: int) -> T:
+        '''Get an example ``T`` from the ``Dataset[T]``'''
         return self.get_item(self.dataframe, index)
 
     def __len__(self):
@@ -124,8 +138,8 @@ def starmap(
     ) -> Dataset[R]:
         '''
         Creates a new dataset with the function added to the dataset pipeline.
-        The functions expects iterables that are expanded as \\*args for the
-        mapped function.
+        The dataset's pipeline should return an iterable that will be
+        expanded as \\*args to the mapped function.
 
         >>> (
         ...     Dataset.from_subscriptable([1, 2, 3])
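
The starmap doctest is also truncated by the hunk boundary; here is a minimal sketch of the intended pattern, assuming ``starmap`` unpacks each pipeline output as positional arguments, as the docstring describes:

    from datastream import Dataset

    products = (
        Dataset.from_subscriptable([1, 2, 3])
        .map(lambda number: (number, number + 1))  # pipeline now yields tuples
        .starmap(lambda number, successor: number * successor)
    )
    assert products[-1] == 3 * 4  # (3, 4) expanded as *args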
@@ -145,12 +159,15 @@ def subset(
         Select a subset of the dataset using a function that receives the
         source dataframe as input and is expected to return a boolean mask.
 
+        Note that this function can still be called after other operations
+        such as mapping, since it uses the source dataframe.
+
         >>> (
         ...     Dataset.from_dataframe(pd.DataFrame(dict(
         ...         number=[1, 2, 3]
         ...     )))
-        ...     .subset(lambda df: df['number'] <= 2)
         ...     .map(lambda row: row['number'])
+        ...     .subset(lambda dataframe: dataframe['number'] <= 2)
         ... )[-1]
         2
         '''
@@ -198,21 +215,22 @@ def split(
         * Adapt after removing examples from dataset
         * Adapt to new stratification
 
-        >>> split_file = Path('doctest_split_dataset.json')
         >>> split_datasets = (
         ...     Dataset.from_dataframe(pd.DataFrame(dict(
         ...         index=np.arange(100),
-        ...         number=np.random.randn(100),
+        ...         number=np.arange(100),
         ...     )))
+        ...     .map(lambda row: row['number'])
         ...     .split(
         ...         key_column='index',
         ...         proportions=dict(train=0.8, test=0.2),
-        ...         filepath=split_file,
+        ...         seed=700,
         ...     )
         ... )
         >>> len(split_datasets['train'])
         80
-        >>> split_file.unlink()  # clean up after doctest
+        >>> split_datasets['test'][0]
+        3
         '''
         if filepath is not None:
             filepath = Path(filepath)
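
The doctest now seeds the split rather than writing it to disk. A hedged sketch of the file-backed variant that the removed lines exercised, assuming ``filepath`` persists the split (the code after the docstring wraps it in ``Path``) so that repeated runs reuse the same assignment:

    from pathlib import Path
    import numpy as np
    import pandas as pd
    from datastream import Dataset

    split_file = Path('split_dataset.json')
    split_datasets = (
        Dataset.from_dataframe(pd.DataFrame(dict(
            index=np.arange(100),
            number=np.arange(100),
        )))
        .split(
            key_column='index',
            proportions=dict(train=0.8, test=0.2),
            filepath=split_file,  # persists the split between runs
        )
    )
    assert len(split_datasets['train']) == 80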
@@ -242,6 +260,12 @@ def zip_index(self: Dataset[T]) -> Dataset[Tuple[T, int]]:
         '''
         Zip the output with its index. The output of the pipeline will be
         a tuple ``(output, index)``.
+
+        >>> (
+        ...     Dataset.from_subscriptable([4, 5, 6])
+        ...     .zip_index()
+        ... )[0]
+        (4, 0)
         '''
         return Dataset(
             dataframe=self.dataframe,
@@ -362,6 +386,12 @@ def zip(datasets: List[Dataset]) -> Dataset[Tuple]:
         datasets' dataframes. It is concatenated over columns with an added
         multiindex column like this:
         ``pd.concat(dataframes, axis=1, keys=['dataset0', 'dataset1', ...])``
+
+        >>> Dataset.zip([
+        ...     Dataset.from_subscriptable([1, 2, 3]),
+        ...     Dataset.from_subscriptable([4, 5, 6, 7]),
+        ... ])[-1]
+        (3, 6)
         '''
         length = min(map(len, datasets))
         return (
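
A small sketch of inspecting the combined dataframe that the docstring describes; it assumes the ``dataframe`` attribute is accessible on the zipped ``Dataset``, as the ``Dataset(dataframe=...)`` constructor calls in this diff suggest:

    from datastream import Dataset

    zipped = Dataset.zip([
        Dataset.from_subscriptable([1, 2, 3]),
        Dataset.from_subscriptable([4, 5, 6, 7]),
    ])
    # The zipped length is the shorter of the two datasets, and the
    # underlying dataframes are concatenated over columns with multiindex
    # keys 'dataset0', 'dataset1', ...
    assert len(zipped) == 3
    print(zipped.dataframe.columns)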