@@ -47,24 +47,43 @@ def forward(self, x):
47
47
return out
48
48
49
49
50
+ class NestSequentialNet (paddle .nn .Layer ):
51
+ def __init__ (self ):
52
+ super ().__init__ ()
53
+ group1 = paddle .nn .Sequential (
54
+ paddle .nn .Linear (10 , 10 ),
55
+ paddle .nn .Sigmoid (), )
56
+ group2 = paddle .nn .Sequential (
57
+ paddle .nn .Linear (10 , 3 ),
58
+ paddle .nn .ReLU (), )
59
+ self .layers = paddle .nn .Sequential (group1 , group2 )
60
+
61
+ def forward (self , x ):
62
+ return self .layers (x )
63
+
64
+
50
65
class TestSequential (unittest .TestCase ):
51
66
def setUp (self ):
52
67
paddle .set_device ('cpu' )
53
68
self .seed = 2021
69
+ self ._init_config ()
70
+
71
+ def _init_config (self ):
72
+ self .net = SequentialNet (BufferLayers , 10 , 3 )
73
+ self .model_path = './sequential_net'
54
74
55
75
def _init_seed (self ):
56
76
paddle .seed (self .seed )
57
77
np .random .seed (self .seed )
58
78
59
79
def _run (self , to_static ):
60
80
self ._init_seed ()
61
- net = SequentialNet (BufferLayers , 10 , 3 )
62
81
if to_static :
63
- net = paddle .jit .to_static (net )
82
+ self . net = paddle .jit .to_static (self . net )
64
83
x = paddle .rand ([16 , 10 ], 'float32' )
65
- out = net (x )
84
+ out = self . net (x )
66
85
if to_static :
67
- load_out = self ._test_load (net , x )
86
+ load_out = self ._test_load (self . net , x )
68
87
self .assertTrue (
69
88
np .allclose (load_out , out ),
70
89
msg = 'load_out is {}\st_out is {}' .format (load_out , out ))
@@ -80,12 +99,17 @@ def test_train(self):
80
99
msg = 'dygraph_res is {}\n static_res is {}' .format (dy_out , st_out ))
81
100
82
101
def _test_load (self , net , x ):
83
- model_path = './sequential_net'
84
- paddle .jit .save (net , model_path )
85
- load_net = paddle .jit .load (model_path )
102
+ paddle .jit .save (net , self .model_path )
103
+ load_net = paddle .jit .load (self .model_path )
86
104
out = load_net (x )
87
105
return out
88
106
89
107
108
+ class TestNestSequential (TestSequential ):
109
+ def _init_config (self ):
110
+ self .net = NestSequentialNet ()
111
+ self .model_path = './nested_sequential_net'
112
+
113
+
90
114
if __name__ == '__main__' :
91
115
unittest .main ()
0 commit comments