Skip to content

Commit 12aca86

Browse files
committed
Add comment
1 parent 2412f2f commit 12aca86

File tree

1 file changed

+51
-7
lines changed

1 file changed

+51
-7
lines changed

python/paddle/v2/fluid/tests/test_parallel_op.py

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,29 +4,49 @@
44

55

66
class BaseParallelForTest(unittest.TestCase):
7-
def main(self, callback, feed, fetch):
7+
def run_test(self, callback, feed, fetch):
8+
"""
9+
Run the unittest for parallel.for
10+
Args:
11+
callback(callable): A callable function returns a generator. There
12+
are two yields in the generator function. The first yield
13+
returns the data layers, and the second yield returns the loss.
14+
The modified data variables will be sent back during the first
15+
yield.
16+
17+
feed(dict): The executor feeding dictionary.
18+
fetch(list|basestr): The fetch name lists.
19+
20+
Returns:
21+
None
22+
23+
Raises:
24+
AssertionError when the computation of cpu, parallel.for in cpu,
25+
gpu, parallel.for in gpu are different.
26+
27+
"""
828
cpu = fluid.CPUPlace()
9-
result_cpu = self._main_impl_(
29+
result_cpu = self._run_test_impl_(
1030
callback=callback,
1131
feed=feed,
1232
fetch=fetch,
1333
place=cpu,
1434
use_parallel=False)
15-
result_cpu_parallel = self._main_impl_(
35+
result_cpu_parallel = self._run_test_impl_(
1636
callback=callback,
1737
feed=feed,
1838
fetch=fetch,
1939
place=cpu,
2040
use_parallel=True)
2141
if fluid.core.is_compile_gpu():
2242
gpu = fluid.CUDAPlace(0)
23-
result_gpu = self._main_impl_(
43+
result_gpu = self._run_test_impl_(
2444
callback=callback,
2545
feed=feed,
2646
fetch=fetch,
2747
place=gpu,
2848
use_parallel=False)
29-
result_gpu_parallel = self._main_impl_(
49+
result_gpu_parallel = self._run_test_impl_(
3050
callback=callback,
3151
feed=feed,
3252
fetch=fetch,
@@ -37,7 +57,17 @@ def main(self, callback, feed, fetch):
3757
else:
3858
self._assert_same_(fetch, result_cpu, result_cpu_parallel)
3959

40-
def _main_impl_(self, callback, feed, fetch, place, use_parallel=False):
60+
def _run_test_impl_(self, callback, feed, fetch, place, use_parallel=False):
61+
"""
62+
Run a single test, returns the fetch values
63+
Args:
64+
place(Place): the computation place.
65+
use_parallel(bool): Whether use parallel.for or not.
66+
67+
Returns:
68+
Fetched numpy arrays.
69+
70+
"""
4171
if isinstance(fetch, basestring):
4272
fetch = [fetch]
4373
main = fluid.Program()
@@ -77,6 +107,20 @@ def _main_impl_(self, callback, feed, fetch, place, use_parallel=False):
77107
return exe.run(main, feed=feed, fetch_list=fetch)
78108

79109
def _assert_same_(self, fetch, *args):
110+
"""
111+
Assert the return values of `run_test` are same.
112+
Args:
113+
fetch: Fetch list. Used for print error message
114+
*args: The fetch result lists of each situations.
115+
116+
Returns:
117+
None
118+
119+
Raises:
120+
AssertionError
121+
122+
"""
123+
80124
def _impl_(a, b, fetch_id, item_id):
81125
item_str = ['CPU', 'ParallelCPU', 'GPU', 'ParallelGPU']
82126
flag = numpy.allclose(a, b, rtol=0.1)
@@ -100,7 +144,7 @@ def __network__():
100144
loss = fluid.layers.mean(x=hidden)
101145
yield loss
102146

103-
self.main(
147+
self.run_test(
104148
callback=__network__,
105149
feed={
106150
'img': numpy.random.random(size=(128, 784)).astype('float32')

0 commit comments

Comments
 (0)