4
4
5
5
6
6
class BaseParallelForTest (unittest .TestCase ):
7
- def main (self , callback , feed , fetch ):
7
+ def run_test (self , callback , feed , fetch ):
8
+ """
9
+ Run the unittest for parallel.for
10
+ Args:
11
+ callback(callable): A callable function returns a generator. There
12
+ are two yields in the generator function. The first yield
13
+ returns the data layers, and the second yield returns the loss.
14
+ The modified data variables will be sent back during the first
15
+ yield.
16
+
17
+ feed(dict): The executor feeding dictionary.
18
+ fetch(list|basestr): The fetch name lists.
19
+
20
+ Returns:
21
+ None
22
+
23
+ Raises:
24
+ AssertionError when the computation of cpu, parallel.for in cpu,
25
+ gpu, parallel.for in gpu are different.
26
+
27
+ """
8
28
cpu = fluid .CPUPlace ()
9
- result_cpu = self ._main_impl_ (
29
+ result_cpu = self ._run_test_impl_ (
10
30
callback = callback ,
11
31
feed = feed ,
12
32
fetch = fetch ,
13
33
place = cpu ,
14
34
use_parallel = False )
15
- result_cpu_parallel = self ._main_impl_ (
35
+ result_cpu_parallel = self ._run_test_impl_ (
16
36
callback = callback ,
17
37
feed = feed ,
18
38
fetch = fetch ,
19
39
place = cpu ,
20
40
use_parallel = True )
21
41
if fluid .core .is_compile_gpu ():
22
42
gpu = fluid .CUDAPlace (0 )
23
- result_gpu = self ._main_impl_ (
43
+ result_gpu = self ._run_test_impl_ (
24
44
callback = callback ,
25
45
feed = feed ,
26
46
fetch = fetch ,
27
47
place = gpu ,
28
48
use_parallel = False )
29
- result_gpu_parallel = self ._main_impl_ (
49
+ result_gpu_parallel = self ._run_test_impl_ (
30
50
callback = callback ,
31
51
feed = feed ,
32
52
fetch = fetch ,
@@ -37,7 +57,17 @@ def main(self, callback, feed, fetch):
37
57
else :
38
58
self ._assert_same_ (fetch , result_cpu , result_cpu_parallel )
39
59
40
- def _main_impl_ (self , callback , feed , fetch , place , use_parallel = False ):
60
+ def _run_test_impl_ (self , callback , feed , fetch , place , use_parallel = False ):
61
+ """
62
+ Run a single test, returns the fetch values
63
+ Args:
64
+ place(Place): the computation place.
65
+ use_parallel(bool): Whether use parallel.for or not.
66
+
67
+ Returns:
68
+ Fetched numpy arrays.
69
+
70
+ """
41
71
if isinstance (fetch , basestring ):
42
72
fetch = [fetch ]
43
73
main = fluid .Program ()
@@ -77,6 +107,20 @@ def _main_impl_(self, callback, feed, fetch, place, use_parallel=False):
77
107
return exe .run (main , feed = feed , fetch_list = fetch )
78
108
79
109
def _assert_same_ (self , fetch , * args ):
110
+ """
111
+ Assert the return values of `run_test` are same.
112
+ Args:
113
+ fetch: Fetch list. Used for print error message
114
+ *args: The fetch result lists of each situations.
115
+
116
+ Returns:
117
+ None
118
+
119
+ Raises:
120
+ AssertionError
121
+
122
+ """
123
+
80
124
def _impl_ (a , b , fetch_id , item_id ):
81
125
item_str = ['CPU' , 'ParallelCPU' , 'GPU' , 'ParallelGPU' ]
82
126
flag = numpy .allclose (a , b , rtol = 0.1 )
@@ -100,7 +144,7 @@ def __network__():
100
144
loss = fluid .layers .mean (x = hidden )
101
145
yield loss
102
146
103
- self .main (
147
+ self .run_test (
104
148
callback = __network__ ,
105
149
feed = {
106
150
'img' : numpy .random .random (size = (128 , 784 )).astype ('float32' )
0 commit comments