17
17
import numpy as np
18
18
import paddle .fluid as fluid
19
19
from paddle .fluid import core
20
- import paddle .compat as cpt
21
20
22
21
23
22
def get_random_images_and_labels (image_shape , label_shape ):
@@ -35,19 +34,6 @@ def __reader__():
35
34
return __reader__
36
35
37
36
38
- def sample_list_generator_creator (batch_size , batch_num ):
39
- def __reader__ ():
40
- for _ in range (batch_num ):
41
- sample_list = []
42
- for _ in range (batch_size ):
43
- image , label = get_random_images_and_labels ([784 ], [1 ])
44
- sample_list .append ([image , label ])
45
-
46
- yield sample_list
47
-
48
- return __reader__
49
-
50
-
51
37
def batch_generator_creator (batch_size , batch_num ):
52
38
def __reader__ ():
53
39
for _ in range (batch_num ):
@@ -62,8 +48,8 @@ class TestDygraphhDataLoader(unittest.TestCase):
62
48
def setUp (self ):
63
49
self .batch_size = 8
64
50
self .batch_num = 4
65
- self .epoch_num = 2
66
- self .capacity = 2
51
+ self .epoch_num = 1
52
+ self .capacity = 5
67
53
68
54
def test_single_process_reader (self ):
69
55
with fluid .dygraph .guard ():
@@ -95,20 +81,6 @@ def test_sample_genarator(self):
95
81
self .assertEqual (label .shape , [self .batch_size , 1 ])
96
82
self .assertEqual (relu .shape , [self .batch_size , 784 ])
97
83
98
- def test_sample_list_generator (self ):
99
- with fluid .dygraph .guard ():
100
- loader = fluid .io .DataLoader .from_generator (
101
- capacity = self .capacity , use_multiprocess = True )
102
- loader .set_sample_list_generator (
103
- sample_list_generator_creator (self .batch_size , self .batch_num ),
104
- places = fluid .CPUPlace ())
105
- for _ in range (self .epoch_num ):
106
- for image , label in loader ():
107
- relu = fluid .layers .relu (image )
108
- self .assertEqual (image .shape , [self .batch_size , 784 ])
109
- self .assertEqual (label .shape , [self .batch_size , 1 ])
110
- self .assertEqual (relu .shape , [self .batch_size , 784 ])
111
-
112
84
def test_batch_genarator (self ):
113
85
with fluid .dygraph .guard ():
114
86
loader = fluid .io .DataLoader .from_generator (
@@ -124,63 +96,5 @@ def test_batch_genarator(self):
124
96
self .assertEqual (relu .shape , [self .batch_size , 784 ])
125
97
126
98
127
- class TestDygraphhDataLoaderWithException (unittest .TestCase ):
128
- def setUp (self ):
129
- self .batch_num = 4
130
- self .capacity = 2
131
-
132
- def test_not_capacity (self ):
133
- with fluid .dygraph .guard ():
134
- with self .assertRaisesRegexp (ValueError ,
135
- "Please give value to capacity." ):
136
- fluid .io .DataLoader .from_generator ()
137
-
138
- def test_single_process_with_thread_expection (self ):
139
- def error_sample_genarator (batch_num ):
140
- def __reader__ ():
141
- for _ in range (batch_num ):
142
- yield [[[1 , 2 ], [1 ]]]
143
-
144
- return __reader__
145
-
146
- with fluid .dygraph .guard ():
147
- loader = fluid .io .DataLoader .from_generator (
148
- capacity = self .capacity , iterable = False , use_multiprocess = False )
149
- loader .set_batch_generator (
150
- error_sample_genarator (self .batch_num ), places = fluid .CPUPlace ())
151
- exception = None
152
- try :
153
- for _ in loader ():
154
- print ("test_single_process_with_thread_expection" )
155
- except core .EnforceNotMet as ex :
156
- self .assertIn ("Blocking queue is killed" ,
157
- cpt .get_exception_message (ex ))
158
- exception = ex
159
- self .assertIsNotNone (exception )
160
-
161
- def test_multi_process_with_thread_expection (self ):
162
- def error_sample_genarator (batch_num ):
163
- def __reader__ ():
164
- for _ in range (batch_num ):
165
- yield [[[1 , 2 ], [1 ]]]
166
-
167
- return __reader__
168
-
169
- with fluid .dygraph .guard ():
170
- loader = fluid .io .DataLoader .from_generator (
171
- capacity = self .capacity , use_multiprocess = True )
172
- loader .set_batch_generator (
173
- error_sample_genarator (self .batch_num ), places = fluid .CPUPlace ())
174
- exception = None
175
- try :
176
- for _ in loader ():
177
- print ("test_multi_process_with_thread_expection" )
178
- except core .EnforceNotMet as ex :
179
- self .assertIn ("Blocking queue is killed" ,
180
- cpt .get_exception_message (ex ))
181
- exception = ex
182
- self .assertIsNotNone (exception )
183
-
184
-
185
99
if __name__ == '__main__' :
186
100
unittest .main ()
0 commit comments