@@ -27,12 +27,15 @@ class TestNetWithDtype(unittest.TestCase):
2727 def set_network (self ):
2828 self .dtype = "float64"
2929 self .init_dtype ()
30- self .x = fluid .layers .data (name = 'x' , shape = [13 ], dtype = self .dtype )
31- self .y = fluid .layers .data (name = 'y' , shape = [1 ], dtype = self .dtype )
32- y_predict = fluid .layers .fc (input = self .x , size = 1 , act = None )
30+ main = fluid .Program ()
31+ with fluid .program_guard (main ):
32+ self .x = fluid .layers .data (name = 'x' , shape = [13 ], dtype = self .dtype )
33+ self .y = fluid .layers .data (name = 'y' , shape = [1 ], dtype = self .dtype )
34+ y_predict = fluid .layers .fc (input = self .x , size = 1 , act = None )
3335
34- cost = fluid .layers .square_error_cost (input = y_predict , label = self .y )
35- avg_cost = fluid .layers .mean (cost )
36+ cost = fluid .layers .square_error_cost (input = y_predict , label = self .y )
37+ avg_cost = fluid .layers .mean (cost )
38+ self .program = main
3639 self .fetch_list = [avg_cost ]
3740
3841 sgd_optimizer = fluid .optimizer .SGD (learning_rate = 0.001 )
@@ -45,7 +48,7 @@ def run_net_on_place(self, place):
4548 exe = fluid .Executor (place )
4649 exe .run (fluid .default_startup_program ())
4750 for data in train_reader ():
48- exe .run (fluid . default_main_program () ,
51+ exe .run (self . program ,
4952 feed = feeder .feed (data ),
5053 fetch_list = self .fetch_list )
5154 # the main program is runable, the datatype is fully supported
@@ -68,7 +71,7 @@ def test_gpu(self):
6871
6972
7073# TODO(dzhwinter): make sure the fp16 is runable
71- # class TestFloat16(SimpleNet ):
74+ # class TestFloat16(TestNetWithDtype ):
7275# def init_dtype(self):
7376# self.dtype = "float16"
7477
0 commit comments