File tree Expand file tree Collapse file tree 2 files changed +53
-1
lines changed Expand file tree Collapse file tree 2 files changed +53
-1
lines changed Original file line number Diff line number Diff line change @@ -67,7 +67,16 @@ def __init__(self, *layers):
67
67
self .add_sublayer (str (idx ), layer )
68
68
69
69
def __getitem__ (self , name ):
70
- return self ._sub_layers [str (name )]
70
+ if isinstance (name , slice ):
71
+ return self .__class__ (* (list (self ._sub_layers .values ())[name ]))
72
+ else :
73
+ if name >= len (self ._sub_layers ):
74
+ raise IndexError ('index {} is out of range' .format (name ))
75
+ elif name < 0 and name >= - len (self ._sub_layers ):
76
+ name += len (self ._sub_layers )
77
+ elif name < - len (self ._sub_layers ):
78
+ raise IndexError ('index {} is out of range' .format (name ))
79
+ return self ._sub_layers [str (name )]
71
80
72
81
def __setitem__ (self , name , layer ):
73
82
assert isinstance (layer , Layer )
Original file line number Diff line number Diff line change
1
+ # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import unittest
16
+ import paddle
17
+
18
+
19
+ class TestDataFeeder (unittest .TestCase ):
20
+ def test_lod_level_1_converter (self ):
21
+ sequential = paddle .nn .Sequential ()
22
+
23
+ for i in range (10 ):
24
+ sequential .add_sublayer (str (i ), paddle .nn .Linear (i + 1 , i + 1 ))
25
+
26
+ for item in sequential :
27
+ tmp = item
28
+
29
+ tmp = sequential [3 :5 ]
30
+ self .assertEqual (len (tmp ), 2 )
31
+
32
+ tmp = sequential [- 1 ]
33
+ self .assertEqual (tmp , sequential [9 ])
34
+
35
+ with self .assertRaises (IndexError ):
36
+ tmp = sequential [10 ]
37
+
38
+ with self .assertRaises (IndexError ):
39
+ tmp = sequential [- 11 ]
40
+
41
+
42
+ if __name__ == '__main__' :
43
+ unittest .main ()
You can’t perform that action at this time.
0 commit comments