Skip to content

Commit 959258e

Browse files
committed
Add some examples in docstrings
1 parent a0dd35a commit 959258e

File tree

2 files changed

+60
-20
lines changed

2 files changed

+60
-20
lines changed

datastream/dataset.py

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class Dataset(BaseModel, torch.utils.data.Dataset, Generic[T]):
1818
'''
1919
A ``Dataset[T]`` is a mapping that allows pipelining of functions in a
20-
readable syntax returning an item of type ``T``.
20+
readable syntax returning an example of type ``T``.
2121
2222
>>> from datastream import Dataset
2323
>>> fruit_and_cost = (
@@ -64,14 +64,28 @@ def from_subscriptable(subscriptable) -> Dataset:
6464

6565
@staticmethod
6666
def from_dataframe(dataframe: pd.DataFrame) -> Dataset[pd.Series]:
67-
'''Create ``Dataset`` based on ``pandas.DataFrame``.'''
67+
'''
68+
Create ``Dataset`` based on ``pandas.DataFrame``.
69+
:func:`Dataset.__getitem__` will return a row from the dataframe and
70+
:func:`Dataset.map` should be given a function that takes a row from
71+
the dataframe as input.
72+
73+
>>> (
74+
... Dataset.from_dataframe(pd.DataFrame(dict(
75+
... number=[1, 2, 3]
76+
... )))
77+
... .map(lambda row: row['number'] + 1)
78+
... )[-1]
79+
4
80+
'''
6881
return Dataset(
6982
dataframe=dataframe,
7083
length=len(dataframe),
7184
get_item=lambda df, index: df.iloc[index],
7285
)
7386

7487
def __getitem__(self: Dataset[T], index: int) -> T:
88+
'''Get an example ``T`` from the ``Dataset[T]``'''
7589
return self.get_item(self.dataframe, index)
7690

7791
def __len__(self):
@@ -124,8 +138,8 @@ def starmap(
124138
) -> Dataset[R]:
125139
'''
126140
Creates a new dataset with the function added to the dataset pipeline.
127-
The functions expects iterables that are expanded as \\*args for the
128-
mapped function.
141+
The dataset's pipeline should return an iterable that will be
142+
expanded as \\*args to the mapped function.
129143
130144
>>> (
131145
... Dataset.from_subscriptable([1, 2, 3])
@@ -145,12 +159,15 @@ def subset(
145159
Select a subset of the dataset using a function that receives the
146160
source dataframe as input and is expected to return a boolean mask.
147161
162+
Note that this function can still be called after multiple operations
163+
such as mapping functions, since it uses the source dataframe.
164+
148165
>>> (
149166
... Dataset.from_dataframe(pd.DataFrame(dict(
150167
... number=[1, 2, 3]
151168
... )))
152-
... .subset(lambda df: df['number'] <= 2)
153169
... .map(lambda row: row['number'])
170+
... .subset(lambda dataframe: dataframe['number'] <= 2)
154171
... )[-1]
155172
2
156173
'''
@@ -198,21 +215,22 @@ def split(
198215
* Adapt after removing examples from dataset
199216
* Adapt to new stratification
200217
201-
>>> split_file = Path('doctest_split_dataset.json')
202218
>>> split_datasets = (
203219
... Dataset.from_dataframe(pd.DataFrame(dict(
204220
... index=np.arange(100),
205-
... number=np.random.randn(100),
221+
... number=np.arange(100),
206222
... )))
223+
... .map(lambda row: row['number'])
207224
... .split(
208225
... key_column='index',
209226
... proportions=dict(train=0.8, test=0.2),
210-
... filepath=split_file,
227+
... seed=700,
211228
... )
212229
... )
213230
>>> len(split_datasets['train'])
214231
80
215-
>>> split_file.unlink() # clean up after doctest
232+
>>> split_datasets['test'][0]
233+
3
216234
'''
217235
if filepath is not None:
218236
filepath = Path(filepath)
@@ -242,6 +260,12 @@ def zip_index(self: Dataset[T]) -> Dataset[Tuple[T, int]]:
242260
'''
243261
Zip the output with its index. The output of the pipeline will be
244262
a tuple ``(output, index)``.
263+
264+
>>> (
265+
... Dataset.from_subscriptable([4, 5, 6])
266+
... .zip_index()
267+
... )[0]
268+
(4, 0)
245269
'''
246270
return Dataset(
247271
dataframe=self.dataframe,
@@ -362,6 +386,12 @@ def zip(datasets: List[Dataset]) -> Dataset[Tuple]:
362386
datasets' dataframes. It is concatenated over columns with an added
363387
multiindex column like this:
364388
``pd.concat(dataframes, axis=1, keys=['dataset0', 'dataset1', ...])``
389+
390+
>>> Dataset.zip([
391+
... Dataset.from_subscriptable([1, 2, 3]),
392+
... Dataset.from_subscriptable([4, 5, 6, 7]),
393+
... ])[-1]
394+
(3, 6)
365395
'''
366396
length = min(map(len, datasets))
367397
return (

datastream/datastream.py

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,12 @@ class Datastream(BaseModel, Generic[T]):
3535
drawn before allowing replacement can be changed with
3636
:func:`Datastream.sample_proportion`.
3737
38-
>>> from datastream import Dataset, Datastream
39-
>>> data_loader = (
40-
... Datastream(Dataset.from_subscriptable([1, 2, 3]))
41-
... .data_loader(batch_size=16, n_batches_per_epoch=100)
42-
... )
43-
>>> len(next(iter(data_loader)))
44-
16
38+
>>> data_loader = (
39+
... Datastream(Dataset.from_subscriptable([1, 2, 3]))
40+
... .data_loader(batch_size=16, n_batches_per_epoch=100)
41+
... )
42+
>>> len(next(iter(data_loader)))
43+
16
4544
'''
4645

4746
dataset: Dataset[T]
@@ -145,7 +144,16 @@ def data_loader(
145144
n_batches_per_epoch: int = None,
146145
**kwargs
147146
) -> torch.utils.data.DataLoader:
148-
'''Get ``torch.utils.data.DataLoader`` for use in pytorch pipeline.'''
147+
'''
148+
Get ``torch.utils.data.DataLoader`` for use in pytorch pipeline.
149+
150+
>>> data_loader = (
151+
... Datastream(Dataset.from_subscriptable([5, 5, 5]))
152+
... .data_loader(batch_size=5, n_batches_per_epoch=10)
153+
... )
154+
>>> list(data_loader)[0]
155+
tensor([5, 5, 5, 5, 5])
156+
'''
149157
if n_batches_per_epoch is None:
150158
sampler = self.sampler
151159
else:
@@ -163,8 +171,10 @@ def zip_index(self: Datastream[T]) -> Datastream[Tuple[T, int]]:
163171
Zip the output with its underlying `Dataset` index. The output of the
164172
pipeline will be a tuple ``(output, index)``.
165173
166-
This method is used when you want modify your sample weights during
167-
training.
174+
This method is useful when you want to modify your sample weights during
175+
training since that requires the index of the example.
176+
177+
See :func:`Dataset.zip_index` for more details.
168178
'''
169179
return Datastream(
170180
self.dataset.zip_index(),
@@ -203,7 +213,7 @@ def sample_proportion(
203213
)
204214

205215
def state_dict(self) -> Dict:
206-
'''Get state of datastream. Useful for checkpointing.'''
216+
'''Get state of datastream. Useful for checkpointing sample weights.'''
207217
return dict(sampler=self.sampler.state_dict())
208218

209219
def load_state_dict(self, state_dict: Dict):

0 commit comments

Comments
 (0)