Skip to content

Commit 158fc99

Browse files
[cherry pick]fix a bug of Sequential::__getitem__ (#30899) (#31192)
* fix a bug of Sequential::__getitem__, test=develop
1 parent c7b32fe commit 158fc99

File tree

2 files changed

+53
-1
lines changed

2 files changed

+53
-1
lines changed

python/paddle/fluid/dygraph/container.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,16 @@ def __init__(self, *layers):
6767
self.add_sublayer(str(idx), layer)
6868

6969
def __getitem__(self, name):
70-
return self._sub_layers[str(name)]
70+
if isinstance(name, slice):
71+
return self.__class__(*(list(self._sub_layers.values())[name]))
72+
else:
73+
if name >= len(self._sub_layers):
74+
raise IndexError('index {} is out of range'.format(name))
75+
elif name < 0 and name >= -len(self._sub_layers):
76+
name += len(self._sub_layers)
77+
elif name < -len(self._sub_layers):
78+
raise IndexError('index {} is out of range'.format(name))
79+
return self._sub_layers[str(name)]
7180

7281
def __setitem__(self, name, layer):
7382
assert isinstance(layer, Layer)
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import paddle
17+
18+
19+
class TestDataFeeder(unittest.TestCase):
20+
def test_lod_level_1_converter(self):
21+
sequential = paddle.nn.Sequential()
22+
23+
for i in range(10):
24+
sequential.add_sublayer(str(i), paddle.nn.Linear(i + 1, i + 1))
25+
26+
for item in sequential:
27+
tmp = item
28+
29+
tmp = sequential[3:5]
30+
self.assertEqual(len(tmp), 2)
31+
32+
tmp = sequential[-1]
33+
self.assertEqual(tmp, sequential[9])
34+
35+
with self.assertRaises(IndexError):
36+
tmp = sequential[10]
37+
38+
with self.assertRaises(IndexError):
39+
tmp = sequential[-11]
40+
41+
42+
if __name__ == '__main__':
43+
unittest.main()

0 commit comments

Comments
 (0)