@@ -73,9 +73,18 @@ def _transpiler_instance(self, config=None):
7373
7474 return self .transpiler
7575
76+ def transpiler_test_impl (self ):
77+ pass
7678
77- class TestBasicModel (TranspilerTest ):
7879 def test_transpiler (self ):
80+ main = fluid .Program ()
81+ startup = fluid .Program ()
82+ with fluid .program_guard (main , startup ):
83+ self .transpiler_test_impl ()
84+
85+
86+ class TestBasicModel (TranspilerTest ):
87+ def transpiler_test_impl (self ):
7988 pserver , startup = self .get_pserver (self .pserver1_ep )
8089 pserver2 , startup2 = self .get_pserver (self .pserver2_ep )
8190
@@ -123,7 +132,7 @@ def test_transpiler(self):
123132
124133
125134class TestBasicModelWithLargeBlockSize (TranspilerTest ):
126- def test_transpiler (self ):
135+ def transpiler_test_impl (self ):
127136 config = fluid .DistributeTranspilerConfig ()
128137 config .min_block_size = 1048576
129138
@@ -148,7 +157,7 @@ def test_transpiler(self):
148157 ["sum" , "scale" , "sgd" ])
149158 # confirm startup program
150159 self .assertEqual ([op .type for op in startup .global_block ().ops ],
151- ["fill_constant" , "fill_constant" , "fill_constant" ])
160+ ["fill_constant" , "fill_constant" ])
152161 # the variable #fc_w will be split into two blocks
153162 fc_w_var = startup2 .global_block ().var ("fc_w" )
154163 self .assertEqual (fc_w_var .shape , (1000L , 1000L ))
@@ -177,7 +186,7 @@ class TestNoSliceVar(TranspilerTest):
177186 def setUp (self ):
178187 super (TestNoSliceVar , self ).setUp ()
179188
180- def test_transpiler (self ):
189+ def transpiler_test_impl (self ):
181190 config = fluid .DistributeTranspilerConfig ()
182191 config .slice_var_up = False
183192
@@ -212,7 +221,7 @@ def net_conf(self):
212221 sgd_optimizer .minimize (avg_cost )
213222 return
214223
215- def test_transpiler (self ):
224+ def transpiler_test_impl (self ):
216225 pserver , startup = self .get_pserver (self .pserver1_ep )
217226 trainer = self .get_trainer ()
218227
@@ -242,7 +251,7 @@ def net_conf(self):
242251 sgd_optimizer .minimize (avg_cost )
243252 return
244253
245- def test_transpiler (self ):
254+ def transpiler_test_impl (self ):
246255 pserver , startup = self .get_pserver (self .pserver1_ep )
247256 trainer = self .get_trainer ()
248257
@@ -291,7 +300,7 @@ def net_conf(self):
291300 sgd_optimizer .minimize (avg_cost )
292301 return
293302
294- def test_transpiler (self ):
303+ def transpiler_test_impl (self ):
295304 pserver , startup = self .get_pserver (self .pserver1_ep )
296305 trainer = self .get_trainer ()
297306
@@ -326,7 +335,7 @@ def net_conf(self):
326335 sgd_optimizer .minimize (avg_cost )
327336 return
328337
329- def test_transpiler (self ):
338+ def transpiler_test_impl (self ):
330339 pserver , startup = self .get_pserver (self .pserver1_ep )
331340 trainer = self .get_trainer ()
332341
0 commit comments