Skip to content

Commit 43facfd

Browse files
authored
[Cherry-pick] Add DistributedBatchSampler and ColorJitter (#25242)
* add DistributedSampler and ColorJitter, test=develop
1 parent 693083a commit 43facfd

File tree

12 files changed

+564
-0
lines changed

12 files changed

+564
-0
lines changed

python/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ if (WITH_TESTING)
9696
add_subdirectory(paddle/fluid/tests)
9797
add_subdirectory(paddle/fluid/contrib/tests)
9898
add_subdirectory(paddle/fluid/contrib/slim/tests)
99+
add_subdirectory(paddle/incubate/hapi/tests)
99100
endif()
100101
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
101102
DESTINATION opt/paddle/share/wheels

python/paddle/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@
3535
batch = batch.batch
3636
import paddle.sysconfig
3737
import paddle.complex
38+
39+
from . import incubate
40+
from .incubate import hapi

python/paddle/incubate/__init__.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import hapi
16+
17+
# Export only the name this package actually binds. Extending __all__ with
# hapi.__all__ listed hapi's own submodule names, which are not attributes
# of this package, so `from paddle.incubate import *` raised AttributeError.
__all__ = ['hapi']
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import distributed
16+
from . import vision
17+
18+
__all__ = ['distributed', 'vision']
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import absolute_import
16+
from __future__ import division
17+
from __future__ import print_function
18+
19+
import math
20+
import numpy as np
21+
22+
from paddle.fluid.dygraph.parallel import ParallelEnv
23+
from paddle.fluid.io import BatchSampler
24+
25+
__all__ = ['DistributedBatchSampler']
26+
27+
28+
class DistributedBatchSampler(BatchSampler):
    """Batch sampler that restricts data loading to a per-process subset.

    Each process can pass its own DistributedBatchSampler instance to a
    DataLoader and will then load a subset of the original dataset that is
    exclusive to it. The number of ranks and the rank of the current
    process are read from ``ParallelEnv`` when the sampler is constructed.

    .. note::
        Dataset is assumed to be of constant size.

    Args:
        dataset(paddle.io.Dataset): a `paddle.io.Dataset` implementation, or
                     any other Python object implementing `__len__`, from
                     which the sampler gets the number of samples.
        batch_size(int): number of sample indices in a mini-batch.
        shuffle(bool): whether to shuffle the index order before generating
                     batch indices. Default False.
        drop_last(bool): whether to drop the last incomplete batch when the
                     per-rank sample count is not divisible by the batch
                     size. Default False.

    Examples:
        .. code-block:: python

            from paddle.incubate.hapi.distributed import DistributedBatchSampler

            class FakeDataset():
                def __init__(self):
                    pass

                def __getitem__(self, idx):
                    return idx,

                def __len__(self):
                    return 10

            train_dataset = FakeDataset()
            dist_train_sampler = DistributedBatchSampler(train_dataset, batch_size=4)
            for batch_indices in dist_train_sampler:
                # do something
                break
    """

    def __init__(self, dataset, batch_size, shuffle=False, drop_last=False):
        self.dataset = dataset

        assert isinstance(batch_size, int) and batch_size > 0, \
            "batch_size should be a positive integer"
        self.batch_size = batch_size
        assert isinstance(shuffle, bool), \
            "shuffle should be a boolean value"
        self.shuffle = shuffle
        assert isinstance(drop_last, bool), \
            "drop_last should be a boolean number"
        self.drop_last = drop_last

        # Query the parallel environment once and reuse it for both values.
        env = ParallelEnv()
        self.nranks = env.nranks
        self.local_rank = env.local_rank
        self.epoch = 0
        # Every rank gets the same number of samples; the index list is
        # padded up to total_size in __iter__ so that it divides evenly.
        self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.nranks))
        self.total_size = self.num_samples * self.nranks

    def _get_indices_by_batch_size(self, indices):
        """Return the slice of ``indices`` that belongs to this rank.

        Full "global batches" of ``batch_size * nranks`` indices are dealt
        out one ``batch_size`` chunk per rank; the remaining tail is split
        evenly between the ranks.
        """
        global_batch_size = self.batch_size * self.nranks
        last_batch_size = self.total_size % global_batch_size
        assert last_batch_size % self.nranks == 0
        last_local_batch_size = last_batch_size // self.nranks

        subsampled_indices = []
        full_end = len(indices) - last_batch_size
        for start in range(self.local_rank * self.batch_size, full_end,
                           global_batch_size):
            subsampled_indices.extend(indices[start:start + self.batch_size])

        # Share out the incomplete trailing global batch.
        tail = indices[full_end:]
        tail_begin = self.local_rank * last_local_batch_size
        subsampled_indices.extend(
            tail[tail_begin:tail_begin + last_local_batch_size])
        return subsampled_indices

    def __iter__(self):
        """Yield lists of dataset indices, one list per mini-batch."""
        indices = np.arange(len(self.dataset)).tolist()
        # Pad by repeating indices from the start until every rank can get
        # num_samples of them. A single slice is not enough when the
        # dataset is smaller than the amount of padding required, so keep
        # extending until the list is full.
        while indices and len(indices) < self.total_size:
            indices += indices[:self.total_size - len(indices)]
        assert len(indices) == self.total_size

        if self.shuffle:
            # Seeding with the epoch makes the permutation depend only on
            # self.epoch, so samplers at the same epoch shuffle identically.
            np.random.RandomState(self.epoch).shuffle(indices)
            self.epoch += 1

        if self.nranks > 1:
            indices = self._get_indices_by_batch_size(indices)

        assert len(indices) == self.num_samples

        batch_indices = []
        for idx in indices:
            batch_indices.append(idx)
            if len(batch_indices) == self.batch_size:
                yield batch_indices
                batch_indices = []
        if not self.drop_last and len(batch_indices) > 0:
            yield batch_indices

    def __len__(self):
        """Return the number of batches this rank yields per epoch."""
        num_samples = self.num_samples
        # Round up when the trailing incomplete batch is kept, down when
        # it is dropped.
        if not self.drop_last:
            num_samples += self.batch_size - 1
        return num_samples // self.batch_size

    def set_epoch(self, epoch):
        """Set the epoch number used to seed the next shuffle."""
        self.epoch = epoch
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
# Collect every test_*.py file in this directory, as paths relative to it.
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
# Remove the ".py" text so each list entry becomes a bare test name.
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")

# Invoke py_test (defined elsewhere in the build) once per discovered test,
# passing the test name and its source file.
foreach(src ${TEST_OPS})
    py_test(${src} SRCS ${src}.py)
endforeach()
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import math
16+
import unittest
17+
18+
from paddle.incubate.hapi.distributed import DistributedBatchSampler
19+
20+
21+
class FakeDataset():
    """Minimal ten-item dataset stand-in whose items are their own indices."""

    def __len__(self):
        """Report a fixed size of ten samples."""
        return 10

    def __getitem__(self, index):
        """Return the requested index unchanged."""
        return index
30+
31+
32+
class TestDistributedBatchSampler(unittest.TestCase):
    """Checks the batches produced by DistributedBatchSampler."""

    def _make_rank_sampler(self, dataset, nranks, local_rank):
        """Build a sampler and override its rank layout to mimic one rank.

        The sampler reads its rank information at construction time, so the
        derived sizes are recomputed here after overriding ``nranks``.
        """
        sampler = DistributedBatchSampler(
            dataset, batch_size=4, shuffle=True, drop_last=True)
        sampler.nranks = nranks
        sampler.local_rank = local_rank
        sampler.num_samples = int(math.ceil(len(dataset) * 1.0 / nranks))
        sampler.total_size = sampler.num_samples * nranks
        return sampler

    def test_sampler(self):
        dataset = FakeDataset()
        sampler = DistributedBatchSampler(dataset, batch_size=1, shuffle=True)

        batches = list(sampler)

        # With batch_size=1 every batch holds exactly one valid index, and
        # the number of batches must agree with len(sampler).
        self.assertEqual(len(batches), len(sampler))
        for batch_idx in batches:
            self.assertEqual(len(batch_idx), 1)
            self.assertTrue(0 <= batch_idx[0] < len(dataset))

    def test_multiple_gpus_sampler(self):
        dataset = FakeDataset()
        sampler1 = self._make_rank_sampler(dataset, nranks=2, local_rank=0)
        sampler2 = self._make_rank_sampler(dataset, nranks=2, local_rank=1)

        batches1 = list(sampler1)
        batches2 = list(sampler2)

        for sampler, batches in ((sampler1, batches1), (sampler2, batches2)):
            self.assertEqual(len(batches), len(sampler))
            for batch_idx in batches:
                # drop_last=True: only full batches may be yielded.
                self.assertEqual(len(batch_idx), 4)

        # Both samplers shuffle with the same epoch seed, so the two ranks
        # must be handed disjoint parts of the same permutation.
        seen1 = set(idx for batch_idx in batches1 for idx in batch_idx)
        seen2 = set(idx for batch_idx in batches2 for idx in batch_idx)
        self.assertFalse(seen1 & seen2)
66+
67+
68+
if __name__ == '__main__':
69+
unittest.main()
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import numpy as np
17+
18+
from paddle.incubate.hapi.vision.transforms import transforms
19+
20+
21+
class TestTransforms(unittest.TestCase):
    """Smoke tests for the colour transforms in hapi.vision.transforms."""

    def do_transform(self, trans):
        """Apply each transform in ``trans`` in order to a random image.

        The image is a 400x300x3 uint8 array; the test passes as long as
        no transform raises.
        """
        fake_img = (np.random.random((400, 300, 3)) * 255).astype('uint8')
        for t in trans:
            fake_img = t(fake_img)

    def test_color_jitter(self):
        trans = [
            transforms.BrightnessTransform(0.0),
            transforms.HueTransform(0.0),
            transforms.SaturationTransform(0.0),
            transforms.ContrastTransform(0.0),
            transforms.ColorJitter(0.2, 0.2, 0.2, 0.2),
        ]
        self.do_transform(trans)

    def test_exception(self):
        # Each transform is expected to reject a negative value.
        with self.assertRaises(ValueError):
            transforms.ContrastTransform(-1.0)

        with self.assertRaises(ValueError):
            transforms.SaturationTransform(-1.0)

        with self.assertRaises(ValueError):
            transforms.HueTransform(-1.0)

        with self.assertRaises(ValueError):
            transforms.BrightnessTransform(-1.0)
49+
50+
51+
if __name__ == '__main__':
52+
unittest.main()
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import transforms
16+
from .transforms import *
17+
18+
__all__ = transforms.__all__
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from . import transforms
16+
17+
from .transforms import *
18+
19+
__all__ = transforms.__all__

0 commit comments

Comments
 (0)