@@ -152,6 +152,16 @@ def check_simple_fc_convergence(self, use_cuda, use_reduce=False):
152
152
use_cuda = use_cuda ,
153
153
use_reduce = use_reduce )
154
154
155
+ def test_simple_fc (self ):
156
+ # use_cuda
157
+ self .check_simple_fc_convergence (True )
158
+ self .check_simple_fc_convergence (False )
159
+
160
+ def test_simple_fc_with_new_strategy (self ):
161
+ # use_cuda, use_reduce
162
+ self ._compare_reduce_and_allreduce (simple_fc_net , True )
163
+ self ._compare_reduce_and_allreduce (simple_fc_net , False )
164
+
155
165
def check_simple_fc_parallel_accuracy (self , use_cuda ):
156
166
if use_cuda and not core .is_compiled_with_cuda ():
157
167
return
@@ -178,6 +188,10 @@ def check_simple_fc_parallel_accuracy(self, use_cuda):
178
188
for p_l in parallel_last_loss :
179
189
self .assertAlmostEquals (p_l , single_last_loss [0 ], delta = 1e-6 )
180
190
191
+ def test_simple_fc_parallel_accuracy (self ):
192
+ self .check_simple_fc_parallel_accuracy (True )
193
+ self .check_simple_fc_parallel_accuracy (False )
194
+
181
195
def check_batchnorm_fc_convergence (self , use_cuda ):
182
196
if use_cuda and not core .is_compiled_with_cuda ():
183
197
return
@@ -192,31 +206,13 @@ def check_batchnorm_fc_convergence(self, use_cuda):
192
206
"label" : label },
193
207
use_cuda = use_cuda )
194
208
195
- def check_batchnorm_fc_convergence_use_reduce (self , use_cuda ):
196
- if use_cuda and not core .is_compiled_with_cuda ():
197
- return
198
- self .check_network_convergence (
199
- fc_with_batchnorm , use_cuda = use_cuda , use_reduce = False )
200
- """
201
- img, label = self._init_data()
202
-
203
- all_reduce_first_loss, all_reduce_last_loss = self.check_network_convergence(
204
- fc_with_batchnorm,
205
- feed_dict={"image": img,
206
- "label": label},
207
- use_cuda=use_cuda,
208
- use_reduce=False)
209
- reduce_first_loss, reduce_last_loss = self.check_network_convergence(
210
- fc_with_batchnorm,
211
- feed_dict={"image": img,
212
- "label": label},
213
- use_cuda=use_cuda,
214
- use_reduce=True)
215
- """
209
+ def test_batchnorm_fc (self ):
210
+ self .check_batchnorm_fc_convergence (True )
211
+ self .check_batchnorm_fc_convergence (False )
216
212
217
213
def test_batchnorm_fc_with_new_strategy (self ):
218
- self .check_batchnorm_fc_convergence_use_reduce ( True )
219
- # self.check_batchnorm_fc_convergence_use_reduce( False)
214
+ self ._compare_reduce_and_allreduce ( fc_with_batchnorm , True )
215
+ self ._compare_reduce_and_allreduce ( fc_with_batchnorm , False )
220
216
221
217
222
218
if __name__ == '__main__' :
0 commit comments