Skip to content

Commit 4ffd339

Browse files
authored
[Cherry-pick][Dy2Stat]Support Nest sequtial container (#34246) #34262
* support Nest sequtial container * rename model path
1 parent 8db945a commit 4ffd339

File tree

2 files changed

+32
-8
lines changed

2 files changed

+32
-8
lines changed

python/paddle/fluid/dygraph/dygraph_to_static/convert_call_func.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def is_unsupported(func):
8888
for v in m.__dict__.values():
8989
func_in_dict = func == v
9090
if isinstance(func_in_dict, (list, numpy.ndarray)):
91-
func_in_dict = any(func_in_dict)
91+
func_in_dict = numpy.array(func_in_dict).any()
9292
if func_in_dict:
9393
translator_logger.log(
9494
2,

python/paddle/fluid/tests/unittests/dygraph_to_static/test_container.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,24 +47,43 @@ def forward(self, x):
4747
return out
4848

4949

50+
class NestSequentialNet(paddle.nn.Layer):
51+
def __init__(self):
52+
super().__init__()
53+
group1 = paddle.nn.Sequential(
54+
paddle.nn.Linear(10, 10),
55+
paddle.nn.Sigmoid(), )
56+
group2 = paddle.nn.Sequential(
57+
paddle.nn.Linear(10, 3),
58+
paddle.nn.ReLU(), )
59+
self.layers = paddle.nn.Sequential(group1, group2)
60+
61+
def forward(self, x):
62+
return self.layers(x)
63+
64+
5065
class TestSequential(unittest.TestCase):
5166
def setUp(self):
5267
paddle.set_device('cpu')
5368
self.seed = 2021
69+
self._init_config()
70+
71+
def _init_config(self):
72+
self.net = SequentialNet(BufferLayers, 10, 3)
73+
self.model_path = './sequential_net'
5474

5575
def _init_seed(self):
5676
paddle.seed(self.seed)
5777
np.random.seed(self.seed)
5878

5979
def _run(self, to_static):
6080
self._init_seed()
61-
net = SequentialNet(BufferLayers, 10, 3)
6281
if to_static:
63-
net = paddle.jit.to_static(net)
82+
self.net = paddle.jit.to_static(self.net)
6483
x = paddle.rand([16, 10], 'float32')
65-
out = net(x)
84+
out = self.net(x)
6685
if to_static:
67-
load_out = self._test_load(net, x)
86+
load_out = self._test_load(self.net, x)
6887
self.assertTrue(
6988
np.allclose(load_out, out),
7089
msg='load_out is {}\st_out is {}'.format(load_out, out))
@@ -80,12 +99,17 @@ def test_train(self):
8099
msg='dygraph_res is {}\nstatic_res is {}'.format(dy_out, st_out))
81100

82101
def _test_load(self, net, x):
83-
model_path = './sequential_net'
84-
paddle.jit.save(net, model_path)
85-
load_net = paddle.jit.load(model_path)
102+
paddle.jit.save(net, self.model_path)
103+
load_net = paddle.jit.load(self.model_path)
86104
out = load_net(x)
87105
return out
88106

89107

108+
class TestNestSequential(TestSequential):
109+
def _init_config(self):
110+
self.net = NestSequentialNet()
111+
self.model_path = './nested_sequential_net'
112+
113+
90114
if __name__ == '__main__':
91115
unittest.main()

0 commit comments

Comments
 (0)