Skip to content

Commit 9f8cd35

Browse files
committed
add more info about ModuleList
1 parent aa75893 commit 9f8cd35

File tree

2 files changed

+152
-20
lines changed

2 files changed

+152
-20
lines changed

code/chapter04_DL_computation/4.1_model-construction.ipynb

Lines changed: 99 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
"name": "stdout",
1717
"output_type": "stream",
1818
"text": [
19-
"0.4.1\n"
19+
"1.2.0\n"
2020
]
2121
}
2222
],
@@ -78,10 +78,10 @@
7878
{
7979
"data": {
8080
"text/plain": [
81-
"tensor([[ 0.1351, -0.0034, 0.0948, -0.1652, 0.1512, 0.0887, -0.0032, 0.0692,\n",
82-
" 0.0942, 0.0956],\n",
83-
" [ 0.1624, -0.0383, 0.1557, -0.0735, 0.1931, 0.1699, -0.0067, 0.0353,\n",
84-
" 0.1712, 0.1568]], grad_fn=<ThAddmmBackward>)"
81+
"tensor([[ 0.0234, -0.2646, -0.1168, -0.2127, 0.0884, -0.0456, 0.0811, 0.0297,\n",
82+
" 0.2032, 0.1364],\n",
83+
" [ 0.1479, -0.1545, -0.0265, -0.2119, -0.0543, -0.0086, 0.0902, -0.1017,\n",
84+
" 0.1504, 0.1144]], grad_fn=<AddmmBackward>)"
8585
]
8686
},
8787
"execution_count": 3,
@@ -107,7 +107,9 @@
107107
{
108108
"cell_type": "code",
109109
"execution_count": 4,
110-
"metadata": {},
110+
"metadata": {
111+
"collapsed": true
112+
},
111113
"outputs": [],
112114
"source": [
113115
"class MySequential(nn.Module):\n",
@@ -146,10 +148,10 @@
146148
{
147149
"data": {
148150
"text/plain": [
149-
"tensor([[ 0.1883, -0.1269, -0.1886, 0.0638, -0.1004, -0.0600, 0.0760, -0.1788,\n",
150-
" -0.1844, -0.2131],\n",
151-
" [ 0.1319, -0.0490, -0.1365, 0.0133, -0.0483, -0.0861, 0.0369, -0.0830,\n",
152-
" -0.0462, -0.2066]], grad_fn=<ThAddmmBackward>)"
151+
"tensor([[ 0.1273, 0.1642, -0.1060, 0.1401, 0.0609, -0.0199, -0.0140, -0.0588,\n",
152+
" 0.1765, -0.1296],\n",
153+
" [ 0.0267, 0.1670, -0.0626, 0.0744, 0.0574, 0.0413, 0.1313, -0.1479,\n",
154+
" 0.0932, -0.0615]], grad_fn=<AddmmBackward>)"
153155
]
154156
},
155157
"execution_count": 5,
@@ -199,6 +201,74 @@
199201
"print(net)"
200202
]
201203
},
204+
{
205+
"cell_type": "code",
206+
"execution_count": 7,
207+
"metadata": {},
208+
"outputs": [],
209+
"source": [
210+
"# net(torch.zeros(1, 784)) # 会报NotImplementedError"
211+
]
212+
},
213+
{
214+
"cell_type": "code",
215+
"execution_count": 8,
216+
"metadata": {
217+
"collapsed": true
218+
},
219+
"outputs": [],
220+
"source": [
221+
"class MyModule(nn.Module):\n",
222+
" def __init__(self):\n",
223+
" super(MyModule, self).__init__()\n",
224+
" self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])\n",
225+
"\n",
226+
" def forward(self, x):\n",
227+
" # ModuleList can act as an iterable, or be indexed using ints\n",
228+
" for i, l in enumerate(self.linears):\n",
229+
" x = self.linears[i // 2](x) + l(x)\n",
230+
" return x"
231+
]
232+
},
233+
{
234+
"cell_type": "code",
235+
"execution_count": 9,
236+
"metadata": {},
237+
"outputs": [
238+
{
239+
"name": "stdout",
240+
"output_type": "stream",
241+
"text": [
242+
"net1:\n",
243+
"torch.Size([10, 10])\n",
244+
"torch.Size([10])\n",
245+
"net2:\n"
246+
]
247+
}
248+
],
249+
"source": [
250+
"class Module_ModuleList(nn.Module):\n",
251+
" def __init__(self):\n",
252+
" super(Module_ModuleList, self).__init__()\n",
253+
" self.linears = nn.ModuleList([nn.Linear(10, 10)])\n",
254+
" \n",
255+
"class Module_List(nn.Module):\n",
256+
" def __init__(self):\n",
257+
" super(Module_List, self).__init__()\n",
258+
" self.linears = [nn.Linear(10, 10)]\n",
259+
"\n",
260+
"net1 = Module_ModuleList()\n",
261+
"net2 = Module_List()\n",
262+
"\n",
263+
"print(\"net1:\")\n",
264+
"for p in net1.parameters():\n",
265+
" print(p.size())\n",
266+
"\n",
267+
"print(\"net2:\")\n",
268+
"for p in net2.parameters():\n",
269+
" print(p)"
270+
]
271+
},
202272
{
203273
"cell_type": "markdown",
204274
"metadata": {},
@@ -208,7 +278,7 @@
208278
},
209279
{
210280
"cell_type": "code",
211-
"execution_count": 7,
281+
"execution_count": 10,
212282
"metadata": {},
213283
"outputs": [
214284
{
@@ -236,6 +306,15 @@
236306
"print(net)"
237307
]
238308
},
309+
{
310+
"cell_type": "code",
311+
"execution_count": 11,
312+
"metadata": {},
313+
"outputs": [],
314+
"source": [
315+
"# net(torch.zeros(1, 784)) # 会报NotImplementedError"
316+
]
317+
},
239318
{
240319
"cell_type": "markdown",
241320
"metadata": {},
@@ -245,7 +324,7 @@
245324
},
246325
{
247326
"cell_type": "code",
248-
"execution_count": 8,
327+
"execution_count": 12,
249328
"metadata": {
250329
"collapsed": true
251330
},
@@ -275,7 +354,7 @@
275354
},
276355
{
277356
"cell_type": "code",
278-
"execution_count": 9,
357+
"execution_count": 13,
279358
"metadata": {},
280359
"outputs": [
281360
{
@@ -290,10 +369,10 @@
290369
{
291370
"data": {
292371
"text/plain": [
293-
"tensor(12.1594, grad_fn=<SumBackward0>)"
372+
"tensor(0.8907, grad_fn=<SumBackward0>)"
294373
]
295374
},
296-
"execution_count": 9,
375+
"execution_count": 13,
297376
"metadata": {},
298377
"output_type": "execute_result"
299378
}
@@ -307,7 +386,7 @@
307386
},
308387
{
309388
"cell_type": "code",
310-
"execution_count": 10,
389+
"execution_count": 14,
311390
"metadata": {},
312391
"outputs": [
313392
{
@@ -331,10 +410,10 @@
331410
{
332411
"data": {
333412
"text/plain": [
334-
"tensor(0.1509, grad_fn=<SumBackward0>)"
413+
"tensor(-0.4605, grad_fn=<SumBackward0>)"
335414
]
336415
},
337-
"execution_count": 10,
416+
"execution_count": 14,
338417
"metadata": {},
339418
"output_type": "execute_result"
340419
}
@@ -367,7 +446,7 @@
367446
],
368447
"metadata": {
369448
"kernelspec": {
370-
"display_name": "Python [default]",
449+
"display_name": "Python 3",
371450
"language": "python",
372451
"name": "python3"
373452
},
@@ -381,7 +460,7 @@
381460
"name": "python",
382461
"nbconvert_exporter": "python",
383462
"pygments_lexer": "ipython3",
384-
"version": "3.6.3"
463+
"version": "3.6.2"
385464
}
386465
},
387466
"nbformat": 4,

docs/chapter04_DL_computation/4.1_model-construction.md

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,7 @@ net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
114114
net.append(nn.Linear(256, 10)) # # 类似List的append操作
115115
print(net[-1]) # 类似List的索引访问
116116
print(net)
117+
# net(torch.zeros(1, 784)) # 会报NotImplementedError
117118
```
118119
输出:
119120
```
@@ -125,6 +126,55 @@ ModuleList(
125126
)
126127
```
127128

129+
既然`Sequential``ModuleList`都可以进行列表化构造网络,那二者区别是什么呢。`ModuleList`仅仅是一个储存各种模块的列表,这些模块之间没有联系也没有顺序(所以不用保证相邻层的输入输出维度匹配),而且没有实现`forward`功能需要自己实现,所以上面执行`net(torch.zeros(1, 784))`会报`NotImplementedError`;而`Sequential`内的模块需要按照顺序排列,要保证相邻层的输入输出大小相匹配,内部`forward`功能已经实现。
130+
131+
`ModuleList`的出现只是让网络定义前向传播时更加灵活,见下面官网的例子。
132+
``` python
133+
class MyModule(nn.Module):
134+
def __init__(self):
135+
super(MyModule, self).__init__()
136+
self.linears = nn.ModuleList([nn.Linear(10, 10) for i in range(10)])
137+
138+
def forward(self, x):
139+
# ModuleList can act as an iterable, or be indexed using ints
140+
for i, l in enumerate(self.linears):
141+
x = self.linears[i // 2](x) + l(x)
142+
return x
143+
```
144+
145+
另外,`ModuleList`不同于一般的Python的`list`,加入到`ModuleList`里面的所有模块的参数会被自动添加到整个网络中,下面看一个例子对比一下。
146+
147+
``` python
148+
class Module_ModuleList(nn.Module):
149+
def __init__(self):
150+
super(Module_ModuleList, self).__init__()
151+
self.linears = nn.ModuleList([nn.Linear(10, 10)])
152+
153+
class Module_List(nn.Module):
154+
def __init__(self):
155+
super(Module_List, self).__init__()
156+
self.linears = [nn.Linear(10, 10)]
157+
158+
net1 = Module_ModuleList()
159+
net2 = Module_List()
160+
161+
print("net1:")
162+
for p in net1.parameters():
163+
print(p.size())
164+
165+
print("net2:")
166+
for p in net2.parameters():
167+
print(p)
168+
```
169+
输出:
170+
```
171+
net1:
172+
torch.Size([10, 10])
173+
torch.Size([10])
174+
net2:
175+
```
176+
177+
128178
### 4.1.2.3 `ModuleDict`
129179
`ModuleDict`接收一个子模块的字典作为输入, 然后也可以类似字典那样进行添加访问操作:
130180
``` python
@@ -136,6 +186,7 @@ net['output'] = nn.Linear(256, 10) # 添加
136186
print(net['linear']) # 访问
137187
print(net.output)
138188
print(net)
189+
# net(torch.zeros(1, 784)) # 会报NotImplementedError
139190
```
140191
输出:
141192
```
@@ -148,6 +199,7 @@ ModuleDict(
148199
)
149200
```
150201

202+
`ModuleList`一样,`ModuleDict`实例仅仅是存放了一些模块的字典,并没有定义`forward`函数需要自己定义。同样,`ModuleDict`也与Python的`Dict`有所不同,`ModuleDict`里的所有模块的参数会被自动添加到整个网络中。
151203

152204
## 4.1.3 构造复杂的模型
153205

@@ -230,6 +282,7 @@ tensor(14.4908, grad_fn=<SumBackward0>)
230282

231283
* 可以通过继承`Module`类来构造模型。
232284
* `Sequential``ModuleList``ModuleDict`类都继承自`Module`类。
285+
*`Sequential`不同,`ModuleList``ModuleDict`并没有定义一个完整的网络,它们只是将不同的模块存放在一起,需要自己定义`forward`函数。
233286
* 虽然`Sequential`等类可以使模型构造更加简单,但直接继承`Module`类可以极大地拓展模型构造的灵活性。
234287

235288

0 commit comments

Comments
 (0)