Skip to content

Commit dc830b7

Browse files
committed
Add Module,ModuleList,ModuleDict docs
1 parent f536970 commit dc830b7

File tree

4 files changed

+1347
-1
lines changed

4 files changed

+1347
-1
lines changed
Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,246 @@
1+
.. _cn_api_paddle_nn_ModuleDict:
2+
3+
ModuleDict
4+
-------------------------------
5+
6+
.. py:class:: paddle.nn.ModuleDict(modules=None)
7+
8+
9+
10+
11+
ModuleDict 用于保存子层到有序字典中,它包含的子层将被正确地注册和添加。列表中的子层可以像常规 python 有序字典一样被访问。
12+
13+
.. note::
14+
``LayerDict`` 是 ``ModuleDict`` 的别名,两者在使用和功能上完全等价。
15+
16+
参数
17+
::::::::::::
18+
19+
- **modules** (ModuleDict|OrderedDict|list[(key, Module)],可选) - 键值对的可迭代对象,值的类型为 `paddle.nn.Module` 。
20+
21+
22+
代码示例
23+
::::::::::::
24+
25+
.. code-block:: python
26+
27+
>>> import paddle
28+
>>> import numpy as np
29+
>>> from collections import OrderedDict
30+
31+
>>> modules = OrderedDict([
32+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
33+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
34+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
35+
>>> ])
36+
37+
>>> modules_dict = paddle.nn.ModuleDict(modules=modules)
38+
39+
>>> l = modules_dict['conv1d']
40+
41+
>>> for k in modules_dict:
42+
... l = modules_dict[k]
43+
...
44+
>>> print(len(modules_dict))
45+
3
46+
47+
>>> del modules_dict['conv2d']
48+
>>> print(len(modules_dict))
49+
2
50+
51+
>>> conv1d = modules_dict.pop('conv1d')
52+
>>> print(len(modules_dict))
53+
1
54+
55+
>>> modules_dict.clear()
56+
>>> print(len(modules_dict))
57+
0
58+
59+
方法
60+
::::::::::::
61+
clear()
62+
'''''''''
63+
64+
清除 ModuleDict 中所有的子层。
65+
66+
**参数**
67+
68+
无。
69+
70+
**代码示例**
71+
72+
.. code-block:: python
73+
74+
>>> import paddle
75+
>>> from collections import OrderedDict
76+
77+
>>> modules = OrderedDict([
78+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
79+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
80+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
81+
>>> ])
82+
83+
>>> module_dict = paddle.nn.ModuleDict(modules=modules)
84+
>>> len(module_dict)
85+
3
86+
87+
>>> module_dict.clear()
88+
>>> len(module_dict)
89+
0
90+
91+
pop()
92+
'''''''''
93+
94+
移除 ModuleDict 中的键 并且返回该键对应的子层。
95+
96+
**参数**
97+
98+
- **key** (str) - 要移除的 key。
99+
100+
**代码示例**
101+
102+
.. code-block:: python
103+
104+
>>> import paddle
105+
>>> from collections import OrderedDict
106+
107+
>>> modules = OrderedDict([
108+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
109+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
110+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
111+
>>> ])
112+
113+
>>> module_dict = paddle.nn.ModuleDict(modules=modules)
114+
>>> len(module_dict)
115+
3
116+
117+
>>> module_dict.pop('conv2d')
118+
>>> len(module_dict)
119+
2
120+
121+
keys()
122+
'''''''''
123+
124+
返回 ModuleDict 中键的可迭代对象。
125+
126+
**参数**
127+
128+
无。
129+
130+
**代码示例**
131+
132+
.. code-block:: python
133+
134+
>>> import paddle
135+
>>> from collections import OrderedDict
136+
137+
>>> modules = OrderedDict([
138+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
139+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
140+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
141+
>>> ])
142+
143+
>>> module_dict = paddle.nn.ModuleDict(modules=modules)
144+
>>> for k in module_dict.keys():
145+
... print(k)
146+
conv1d
147+
conv2d
148+
conv3d
149+
150+
151+
items()
152+
'''''''''
153+
154+
返回 ModuleDict 中键/值对的可迭代对象。
155+
156+
**参数**
157+
158+
无。
159+
160+
**代码示例**
161+
162+
.. code-block:: python
163+
164+
>>> import paddle
165+
>>> from collections import OrderedDict
166+
167+
>>> modules = OrderedDict([
168+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
169+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
170+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
171+
>>> ])
172+
173+
>>> module_dict = paddle.nn.ModuleDict(modules=modules)
174+
>>> for k, v in module_dict.items():
175+
... print(f"{k}:", v)
176+
conv1d : Conv1D(3, 2, kernel_size=[3], data_format=NCL)
177+
conv2d : Conv2D(3, 2, kernel_size=[3, 3], data_format=NCHW)
178+
conv3d : Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)
179+
180+
181+
values()
182+
'''''''''
183+
184+
返回 ModuleDict 中值的可迭代对象。
185+
186+
**参数**
187+
188+
无。
189+
190+
**代码示例**
191+
192+
.. code-block:: python
193+
194+
>>> import paddle
195+
>>> from collections import OrderedDict
196+
197+
>>> modules = OrderedDict([
198+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
199+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
200+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
201+
>>> ])
202+
203+
>>> module_dict = paddle.nn.ModuleDict(modules=modules)
204+
>>> for v in module_dict.values():
205+
... print(v)
206+
Conv1D(3, 2, kernel_size=[3], data_format=NCL)
207+
Conv2D(3, 2, kernel_size=[3, 3], data_format=NCHW)
208+
Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)
209+
210+
211+
update()
212+
'''''''''
213+
214+
更新子层中的键/值对到 ModuleDict 中,会覆盖已经存在的键。
215+
216+
**参数**
217+
218+
- **sublayers** (ModuleDict|OrderedDict|list[(key, Module)]) - 键值对的可迭代对象,值的类型为 `paddle.nn.Module` 。
219+
220+
**代码示例**
221+
222+
.. code-block:: python
223+
224+
>>> import paddle
225+
>>> from collections import OrderedDict
226+
227+
>>> modules = OrderedDict([
228+
... ('conv1d', paddle.nn.Conv1D(3, 2, 3)),
229+
... ('conv2d', paddle.nn.Conv2D(3, 2, 3)),
230+
... ('conv3d', paddle.nn.Conv3D(4, 6, (3, 3, 3))),
231+
>>> ])
232+
233+
>>> new_submodules = OrderedDict([
234+
... ('relu', paddle.nn.ReLU()),
235+
... ('conv2d', paddle.nn.Conv2D(4, 2, 4)),
236+
>>> ])
237+
>>> module_dict = paddle.nn.ModuleDict(modules=modules)
238+
239+
>>> module_dict.update(new_submodules)
240+
241+
>>> for k, v in module_dict.items():
242+
... print(f"{k}:", v)
243+
conv1d : Conv1D(3, 2, kernel_size=[3], data_format=NCL)
244+
conv2d : Conv2D(4, 2, kernel_size=[4, 4], data_format=NCHW)
245+
conv3d : Conv3D(4, 6, kernel_size=[3, 3, 3], data_format=NCDHW)
246+
relu : ReLU()
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
.. _cn_api_paddle_nn_ModuleList:
2+
3+
ModuleList
4+
-------------------------------
5+
6+
.. py:class:: paddle.nn.ModuleList(modules=None)
7+
8+
9+
10+
11+
ModuleList 用于保存子层列表,它包含的子层将被正确地注册和添加。列表中的子层可以像常规 python 列表一样被索引。
12+
13+
.. note::
14+
``LayerList`` 是 ``ModuleList`` 的别名,两者在使用和功能上完全等价。
15+
16+
参数
17+
::::::::::::
18+
19+
- **modules** (iterable,可选) - 要保存的子层。
20+
21+
22+
代码示例
23+
::::::::::::
24+
25+
.. code-block:: python
26+
27+
>>> import paddle
28+
29+
>>> class MyModule(paddle.nn.Module):
30+
... def __init__(self):
31+
... super().__init__()
32+
... self.linears = paddle.nn.ModuleList(
33+
... [paddle.nn.Linear(10, 10) for i in range(10)])
34+
...
35+
... def forward(self, x):
36+
... for i, l in enumerate(self.linears):
37+
... x = self.linears[i // 2](x) + l(x)
38+
... return x
39+
40+
方法
41+
::::::::::::
42+
append()
43+
'''''''''
44+
45+
添加一个子层到整个 list 的最后。
46+
47+
**参数**
48+
49+
- **sublayer** (Module) - 要添加的子层。
50+
51+
**代码示例**
52+
53+
.. code-block:: python
54+
55+
>>> import paddle
56+
57+
>>> linears = paddle.nn.ModuleList([paddle.nn.Linear(10, 10) for i in range(10)])
58+
>>> another = paddle.nn.Linear(10, 10)
59+
>>> linears.append(another)
60+
>>> print(len(linears))
61+
11
62+
63+
64+
insert()
65+
'''''''''
66+
67+
向 list 中插入一个子层,到给定的 index 前面。
68+
69+
**参数**
70+
71+
- **index** (int) - 要插入的位置。
72+
- **sublayers** (Layer) - 要插入的子层。
73+
74+
**代码示例**
75+
76+
.. code-block:: python
77+
78+
>>> import paddle
79+
80+
>>> linears = paddle.nn.ModuleList([paddle.nn.Linear(10, 10) for i in range(10)])
81+
>>> another = paddle.nn.Linear(10, 10)
82+
>>> linears.insert(3, another)
83+
>>> print(linears[3] is another)
84+
True
85+
>>> another = paddle.nn.Linear(10, 10)
86+
>>> linears.insert(-1, another)
87+
>>> print(linears[-2] is another)
88+
True
89+
90+
extend()
91+
'''''''''
92+
93+
添加多个子层到整个 list 的最后。
94+
95+
**参数**
96+
97+
- **sublayers** (iterable of Module) - 要添加的所有子层。
98+
99+
**代码示例**
100+
101+
.. code-block:: python
102+
103+
>>> import paddle
104+
105+
>>> linears = paddle.nn.ModuleList([paddle.nn.Linear(10, 10) for i in range(10)])
106+
>>> another_list = paddle.nn.ModuleList([paddle.nn.Linear(10, 10) for i in range(5)])
107+
>>> linears.extend(another_list)
108+
>>> print(len(linears))
109+
15
110+
>>> print(another_list[0] is linears[10])
111+
True

0 commit comments

Comments
 (0)