Skip to content

Commit 1b3cd0f

Browse files
authored
Add random_split and Subset dataset (#29291) (#32090)
As the title
1 parent 62c2173 commit 1b3cd0f

File tree

3 files changed

+242
-19
lines changed

3 files changed

+242
-19
lines changed

python/paddle/fluid/dataloader/dataset.py

100644100755
Lines changed: 129 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
# Public API of this module. `random_split` and `Subset` are re-exported
# through `paddle.io` (see python/paddle/io/__init__.py).
__all__ = [
    "Dataset", "IterableDataset", "TensorDataset", "ComposeDataset",
    "ChainDataset", "random_split", "Subset"
]
2424

2525

@@ -405,3 +405,131 @@ def __iter__(self):
405405
for dataset in self.datasets:
406406
for sample in dataset:
407407
yield sample
408+
409+
410+
class Subset(Dataset):
    """
    Subset of a dataset at specified indices.

    Args:
        dataset (Dataset): The whole Dataset.
        indices (sequence): Indices in the whole set selected for subset.

    Returns:
        Dataset: A Dataset which is the subset of the original dataset.

    Example code:

        .. code-block:: python

            import paddle
            from paddle.io import Subset

            # example 1:
            a = paddle.io.Subset(dataset=range(1, 4), indices=[0, 2])
            print(list(a))
            # [1, 3]

            # example 2:
            b = paddle.io.Subset(dataset=range(1, 4), indices=[1, 1])
            print(list(b))
            # [2, 2]
    """

    def __init__(self, dataset, indices):
        # Only references are stored; items are fetched lazily in
        # __getitem__, so the underlying dataset is never copied.
        self.dataset = dataset
        self.indices = indices

    def __getitem__(self, idx):
        # Translate the subset-local index into an index of the full
        # dataset, then delegate the lookup.
        mapped_idx = self.indices[idx]
        return self.dataset[mapped_idx]

    def __len__(self):
        # The subset is exactly as long as its index list (duplicates in
        # `indices` are allowed and counted).
        return len(self.indices)
448+
449+
450+
def random_split(dataset, lengths, generator=None):
    """
    Randomly split a dataset into non-overlapping new datasets of given lengths.

    Args:
        dataset (Dataset): Dataset to be split.
        lengths (sequence): lengths of splits to be produced.
        generator (Generator, optional): Generator used for the random
            permutation. Default is None, in which case the default global
            generator is used.
            NOTE(review): this argument is currently accepted but not
            forwarded to ``paddle.randperm`` — confirm whether per-call
            generators are supported before documenting them as effective.

    Returns:
        Datasets: A list of subset Datasets, which are the non-overlapping
        subsets of the original Dataset.

    Raises:
        ValueError: If any length is negative, or if the lengths do not sum
            to the length of the input dataset.

    Example code:

        .. code-block:: python

            import paddle
            from paddle.io import random_split

            a_list = paddle.io.random_split(range(10), [3, 7])
            print(len(a_list))
            # 2

            for idx, v in enumerate(a_list[0]):
                print(idx, v)
            # output of the first subset, e.g.:
            # 0 1
            # 1 3
            # 2 9

            for idx, v in enumerate(a_list[1]):
                print(idx, v)
            # output of the second subset, e.g.:
            # 0 5
            # 1 7
            # ...
    """
    from itertools import accumulate

    # A negative length would pass the sum check (e.g. [6, -1] for a
    # 5-element dataset) and silently produce a malformed slice below,
    # so reject it explicitly.
    if any(length < 0 for length in lengths):
        raise ValueError("lengths of splits must be non-negative integers!")
    # Cannot verify that dataset is Sized
    total = sum(lengths)
    if total != len(dataset):  # type: ignore
        raise ValueError(
            "Sum of input lengths does not equal the length of the input dataset!"
        )
    # TODO(@Joejiong): support Variable or Tensor type with .tolist class member function.
    # For example var.item() and var.tolist()
    indices = paddle.randperm(total).numpy().tolist()
    # `accumulate` yields running end-offsets; each subset takes the
    # `length` permuted indices ending at its offset.
    return [
        Subset(dataset, indices[offset - length:offset])
        for offset, length in zip(accumulate(lengths), lengths)
    ]
505+
506+
507+
def _accumulate(iterable, fn=lambda x, y: x + y):
508+
"""
509+
Return running totals
510+
511+
Args:
512+
iterable: any iterable object for example dataset.
513+
y (x): one element in the iterable object.
514+
fn (x, y): Defaults to lambdax.
515+
516+
Yields:
517+
yields total from beginning iterator to current iterator.
518+
519+
Example code:
520+
521+
.. code-block:: python
522+
523+
_accumulate([1,2,3,4,5]) --> 1 3 6 10 15
524+
_accumulate([1,2,3,4,5], operator.mul) --> 1 2 6 24 120
525+
"""
526+
527+
it = iter(iterable)
528+
try:
529+
total = next(it)
530+
except StopIteration:
531+
return
532+
yield total
533+
for element in it:
534+
total = fn(total, element)
535+
yield total

python/paddle/fluid/tests/unittests/test_multiprocess_dataloader_dataset.py

100644100755
Lines changed: 110 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@
2020
import paddle
2121
import paddle.fluid as fluid
2222
from paddle.io import Dataset, IterableDataset, TensorDataset, \
23-
ComposeDataset, ChainDataset, DataLoader
24-
from paddle.fluid.dygraph.base import to_variable
23+
ComposeDataset, ChainDataset, DataLoader, random_split, Subset
2524

2625
IMAGE_SIZE = 32
2726

@@ -54,14 +53,14 @@ def __iter__(self):
5453

5554
class TestTensorDataset(unittest.TestCase):
5655
def run_main(self, num_workers, places):
57-
fluid.default_startup_program().random_seed = 1
58-
fluid.default_main_program().random_seed = 1
59-
place = fluid.CPUPlace()
56+
paddle.static.default_startup_program().random_seed = 1
57+
paddle.static.default_main_program().random_seed = 1
58+
place = paddle.CPUPlace()
6059
with fluid.dygraph.guard(place):
6160
input_np = np.random.random([16, 3, 4]).astype('float32')
62-
input = to_variable(input_np)
61+
input = paddle.to_tensor(input_np)
6362
label_np = np.random.random([16, 1]).astype('int32')
64-
label = to_variable(label_np)
63+
label = paddle.to_tensor(label_np)
6564

6665
dataset = TensorDataset([input, label])
6766
assert len(dataset) == 16
@@ -83,17 +82,17 @@ def run_main(self, num_workers, places):
8382
assert np.allclose(label.numpy(), label_np[i])
8483

8584
def test_main(self):
86-
places = [fluid.CPUPlace()]
87-
if fluid.core.is_compiled_with_cuda():
88-
places.append(fluid.CUDAPlace(0))
85+
places = [paddle.CPUPlace()]
86+
if paddle.is_compiled_with_cuda():
87+
places.append(paddle.CUDAPlace(0))
8988
for p in places:
9089
self.run_main(num_workers=0, places=p)
9190

9291

9392
class TestComposeDataset(unittest.TestCase):
9493
def test_main(self):
95-
fluid.default_startup_program().random_seed = 1
96-
fluid.default_main_program().random_seed = 1
94+
paddle.static.default_startup_program().random_seed = 1
95+
paddle.static.default_main_program().random_seed = 1
9796

9897
dataset1 = RandomDataset(10)
9998
dataset2 = RandomDataset(10)
@@ -110,10 +109,104 @@ def test_main(self):
110109
assert np.allclose(label2, label2_t)
111110

112111

112+
class TestRandomSplitApi(unittest.TestCase):
    """random_split must partition the dataset without overlap or loss."""

    def test_main(self):
        # Fix seeds so the permutation is reproducible across runs.
        paddle.static.default_startup_program().random_seed = 1
        paddle.static.default_main_program().random_seed = 1

        dataset1, dataset2 = paddle.io.random_split(range(5), [1, 4])

        self.assertTrue(len(dataset1) == 1)
        self.assertTrue(len(dataset2) == 4)

        # Remove each produced element once from the full element set;
        # a valid partition leaves nothing behind and never removes twice.
        remaining = list(range(5))
        for value in dataset1:
            remaining.remove(value)
        for value in dataset2:
            remaining.remove(value)

        self.assertTrue(len(remaining) == 0)
131+
132+
133+
class TestRandomSplitError(unittest.TestCase):
    """random_split must reject length specs inconsistent with the dataset."""

    def test_errors(self):
        paddle.static.default_startup_program().random_seed = 1
        paddle.static.default_main_program().random_seed = 1

        # None of these sums to len(range(5)) == 5.
        for bad_lengths in ([3, 8], [8], []):
            self.assertRaises(ValueError, paddle.io.random_split,
                              range(5), bad_lengths)
141+
142+
143+
class TestSubsetDataset(unittest.TestCase):
    # Splits a 5-sample TensorDataset into an even-index and an odd-index
    # Subset, then checks via DataLoader that the two subsets together
    # cover exactly the full dataset.

    def run_main(self, num_workers, places):
        paddle.static.default_startup_program().random_seed = 1
        paddle.static.default_main_program().random_seed = 1

        input_np = np.random.random([5, 3, 4]).astype('float32')
        input = paddle.to_tensor(input_np)
        label_np = np.random.random([5, 1]).astype('int32')
        label = paddle.to_tensor(label_np)

        dataset = TensorDataset([input, label])
        # Complementary index sets: together they cover all 5 samples.
        even_subset = paddle.io.Subset(dataset, [0, 2, 4])
        odd_subset = paddle.io.Subset(dataset, [1, 3])

        assert len(dataset) == 5

        def prepare_dataloader(dataset):
            # batch_size=1 so each yielded batch corresponds to one sample.
            return DataLoader(
                dataset,
                places=places,
                num_workers=num_workers,
                batch_size=1,
                drop_last=True)

        dataloader = prepare_dataloader(dataset)
        dataloader_even = prepare_dataloader(even_subset)
        dataloader_odd = prepare_dataloader(odd_subset)

        def assert_basic(input, label):
            # Shape/type sanity checks for a single-sample batch.
            assert len(input) == 1
            assert len(label) == 1
            assert input.shape == [1, 3, 4]
            assert label.shape == [1, 1]
            assert isinstance(input, paddle.Tensor)
            assert isinstance(label, paddle.Tensor)

        # Collect every label from the full dataset ...
        elements_list = list()
        for _, (input, label) in enumerate(dataloader()):
            assert_basic(input, label)
            elements_list.append(label)

        # ... remove the even-subset labels ...
        for _, (input, label) in enumerate(dataloader_even()):
            assert_basic(input, label)
            elements_list.remove(label)

        # ... so what remains must be exactly the odd-subset labels
        # (and in the same order, since both loaders iterate sequentially).
        odd_list = list()
        for _, (input, label) in enumerate(dataloader_odd()):
            assert_basic(input, label)
            odd_list.append(label)

        self.assertEqual(odd_list, elements_list)

    def test_main(self):
        paddle.static.default_startup_program().random_seed = 1
        paddle.static.default_main_program().random_seed = 1

        # Exercise CPU always; add CUDA place when available.
        places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_cuda():
            places.append(paddle.CUDAPlace(0))
        for p in places:
            self.run_main(num_workers=0, places=p)
204+
205+
113206
class TestChainDataset(unittest.TestCase):
114207
def run_main(self, num_workers, places):
115-
fluid.default_startup_program().random_seed = 1
116-
fluid.default_main_program().random_seed = 1
208+
paddle.static.default_startup_program().random_seed = 1
209+
paddle.static.default_main_program().random_seed = 1
117210

118211
dataset1 = RandomIterableDataset(10)
119212
dataset2 = RandomIterableDataset(10)
@@ -135,9 +228,9 @@ def run_main(self, num_workers, places):
135228
idx += 1
136229

137230
def test_main(self):
138-
places = [fluid.CPUPlace()]
139-
if fluid.core.is_compiled_with_cuda():
140-
places.append(fluid.CUDAPlace(0))
231+
places = [paddle.CPUPlace()]
232+
if paddle.is_compiled_with_cuda():
233+
places.append(paddle.CUDAPlace(0))
141234
for p in places:
142235
self.run_main(num_workers=0, places=p)
143236

python/paddle/io/__init__.py

100644100755
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,11 @@
2828
'SequenceSampler',
2929
'RandomSampler',
3030
'WeightedRandomSampler',
31+
'random_split',
32+
'Subset'
3133
]
3234

3335
from ..fluid.io import DataLoader
3436
from ..fluid.dataloader import Dataset, IterableDataset, BatchSampler, get_worker_info, \
3537
TensorDataset, Sampler, SequenceSampler, RandomSampler, DistributedBatchSampler, \
36-
ComposeDataset, ChainDataset, WeightedRandomSampler
38+
ComposeDataset, ChainDataset, WeightedRandomSampler, Subset, random_split

0 commit comments

Comments
 (0)