Skip to content

Commit 920a1e2

Browse files
committed
feature: implement method take for datastream
1 parent ae71922 commit 920a1e2

File tree

2 files changed

+14
-1
lines changed

2 files changed

+14
-1
lines changed

datastream/datastream.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ def __init__(
6565
)
6666
)
6767

68+
def __len__(self):
69+
return len(self.sampler)
70+
6871
@staticmethod
6972
def merge(datastreams_and_ns: Tuple[Union[
7073
Datastream[T],
@@ -213,6 +216,16 @@ def sample_proportion(
213216
self.sampler.sample_proportion(proportion),
214217
)
215218

219+
def take(
220+
self: Datastream[T],
221+
n_samples: int,
222+
) -> Datastream[T]:
223+
'''
224+
Like :func:`Datastream.sample_proportion` but specify the number of
225+
samples instead of a proportion.
226+
'''
227+
return self.sample_proportion(min(1, n_samples / len(self)))
228+
216229
def state_dict(self) -> Dict:
217230
'''Get state of datastream. Useful for checkpointing sample weights.'''
218231
return dict(sampler=self.sampler.state_dict())

datastream/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def __init__(self, length, proportion=1.0, replacement=False):
2424
replacement=replacement,
2525
sampler=torch.utils.data.WeightedRandomSampler(
2626
torch.ones(length).double(),
27-
num_samples=int(length * proportion),
27+
num_samples=max(int(length * proportion), 1),
2828
replacement=replacement,
2929
)
3030
)

0 commit comments

Comments
 (0)