@@ -27,12 +27,15 @@ class TestNetWithDtype(unittest.TestCase):
27
27
def set_network (self ):
28
28
self .dtype = "float64"
29
29
self .init_dtype ()
30
- self .x = fluid .layers .data (name = 'x' , shape = [13 ], dtype = self .dtype )
31
- self .y = fluid .layers .data (name = 'y' , shape = [1 ], dtype = self .dtype )
32
- y_predict = fluid .layers .fc (input = self .x , size = 1 , act = None )
30
+ main = fluid .Program ()
31
+ with fluid .program_guard (main ):
32
+ self .x = fluid .layers .data (name = 'x' , shape = [13 ], dtype = self .dtype )
33
+ self .y = fluid .layers .data (name = 'y' , shape = [1 ], dtype = self .dtype )
34
+ y_predict = fluid .layers .fc (input = self .x , size = 1 , act = None )
33
35
34
- cost = fluid .layers .square_error_cost (input = y_predict , label = self .y )
35
- avg_cost = fluid .layers .mean (cost )
36
+ cost = fluid .layers .square_error_cost (input = y_predict , label = self .y )
37
+ avg_cost = fluid .layers .mean (cost )
38
+ self .program = main
36
39
self .fetch_list = [avg_cost ]
37
40
38
41
sgd_optimizer = fluid .optimizer .SGD (learning_rate = 0.001 )
@@ -45,7 +48,7 @@ def run_net_on_place(self, place):
45
48
exe = fluid .Executor (place )
46
49
exe .run (fluid .default_startup_program ())
47
50
for data in train_reader ():
48
- exe .run (fluid . default_main_program () ,
51
+ exe .run (self . program ,
49
52
feed = feeder .feed (data ),
50
53
fetch_list = self .fetch_list )
51
54
# the main program is runable, the datatype is fully supported
@@ -68,7 +71,7 @@ def test_gpu(self):
68
71
69
72
70
73
# TODO(dzhwinter): make sure the fp16 is runable
71
- # class TestFloat16(SimpleNet ):
74
+ # class TestFloat16(TestNetWithDtype ):
72
75
# def init_dtype(self):
73
76
# self.dtype = "float16"
74
77
0 commit comments