|
16 | 16 | "name": "stdout", |
17 | 17 | "output_type": "stream", |
18 | 18 | "text": [ |
19 | | - "1.0.0 cpu\n" |
| 19 | + "1.2.0 cpu\n" |
20 | 20 | ] |
21 | 21 | } |
22 | 22 | ], |
|
52 | 52 | { |
53 | 53 | "cell_type": "code", |
54 | 54 | "execution_count": 2, |
55 | | - "metadata": { |
56 | | - "collapsed": true |
57 | | - }, |
| 55 | + "metadata": {}, |
58 | 56 | "outputs": [], |
59 | 57 | "source": [ |
60 | 58 | "# Record all the tokens of a sequence in all_tokens so that the vocabulary can be constructed later, then append PAD after the sequence until the sequence\n",
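The truncated comment above continues past this hunk; for context, here is a minimal sketch of the padding helper it likely describes. The function name `process_one_seq` and the globals `PAD` and `EOS` are assumptions based on the surrounding notebook conventions, not confirmed code:

```python
# Hypothetical reconstruction; PAD, EOS, and the helper name are assumptions.
def process_one_seq(seq_tokens, all_tokens, all_seqs, max_seq_len):
    # Record every token so the vocabulary can be built later
    all_tokens.extend(seq_tokens)
    # Terminate with EOS, then pad with PAD until the length is max_seq_len
    seq_tokens += [EOS] + [PAD] * (max_seq_len - len(seq_tokens) - 1)
    all_seqs.append(seq_tokens)
```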
|
75 | 73 | { |
76 | 74 | "cell_type": "code", |
77 | 75 | "execution_count": 3, |
78 | | - "metadata": { |
79 | | - "collapsed": true |
80 | | - }, |
| 76 | + "metadata": {}, |
81 | 77 | "outputs": [], |
82 | 78 | "source": [ |
83 | 79 | "def read_data(max_seq_len):\n", |
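Only the signature of `read_data` is visible in this hunk. A hedged sketch of a typical implementation for this kind of small French-English dataset; the file name `'fr-en-small.txt'` and the helpers `process_one_seq` and `build_data` are assumptions, not the notebook's confirmed code:

```python
import io
import torch.utils.data as Data

def read_data(max_seq_len):
    # Token pools and padded sequences for source and target languages
    in_tokens, out_tokens, in_seqs, out_seqs = [], [], [], []
    with io.open('fr-en-small.txt') as f:   # assumed data file
        lines = f.readlines()
    for line in lines:
        in_seq, out_seq = line.rstrip().split('\t')
        in_seq_tokens, out_seq_tokens = in_seq.split(' '), out_seq.split(' ')
        # Skip pairs that would not fit once EOS is appended
        if max(len(in_seq_tokens), len(out_seq_tokens)) > max_seq_len - 1:
            continue
        process_one_seq(in_seq_tokens, in_tokens, in_seqs, max_seq_len)
        process_one_seq(out_seq_tokens, out_tokens, out_seqs, max_seq_len)
    in_vocab, in_data = build_data(in_tokens, in_seqs)
    out_vocab, out_data = build_data(out_tokens, out_seqs)
    return in_vocab, out_vocab, Data.TensorDataset(in_data, out_data)
```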
|
130 | 126 | { |
131 | 127 | "cell_type": "code", |
132 | 128 | "execution_count": 5, |
133 | | - "metadata": { |
134 | | - "collapsed": true |
135 | | - }, |
| 129 | + "metadata": {}, |
136 | 130 | "outputs": [], |
137 | 131 | "source": [ |
138 | 132 | "class Encoder(nn.Module):\n", |
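The `Encoder` body lies outside this hunk. For context, a minimal GRU encoder consistent with the sizes used by the `Decoder` below might look like this (a sketch, not the notebook's exact code):

```python
import torch
from torch import nn

class Encoder(nn.Module):
    def __init__(self, vocab_size, embed_size, num_hiddens, num_layers, drop_prob=0):
        super(Encoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.rnn = nn.GRU(embed_size, num_hiddens, num_layers, dropout=drop_prob)

    def forward(self, inputs, state):
        # inputs: (batch, seq_len) -> embed, then swap to (seq_len, batch, embed_size)
        embedding = self.embedding(inputs.long()).permute(1, 0, 2)
        return self.rnn(embedding, state)

    def begin_state(self):
        return None  # nn.GRU treats None as an all-zero initial hidden state
```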
|
183 | 177 | { |
184 | 178 | "cell_type": "code", |
185 | 179 | "execution_count": 7, |
186 | | - "metadata": { |
187 | | - "collapsed": true |
188 | | - }, |
| 180 | + "metadata": {}, |
189 | 181 | "outputs": [], |
190 | 182 | "source": [ |
191 | 183 | "def attention_model(input_size, attention_size):\n", |
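Only the signature of `attention_model` appears in this hunk. Given the call `attention_model(2*num_hiddens, attention_size)` in the `Decoder` below, the scorer is a small MLP over a concatenated (encoder state, decoder state) pair; a sketch under that assumption:

```python
from torch import nn

def attention_model(input_size, attention_size):
    # One-hidden-layer MLP that maps a concatenated state pair to a scalar score
    model = nn.Sequential(nn.Linear(input_size, attention_size, bias=False),
                          nn.Tanh(),
                          nn.Linear(attention_size, 1, bias=False))
    return model
```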
|
198 | 190 | { |
199 | 191 | "cell_type": "code", |
200 | 192 | "execution_count": 8, |
201 | | - "metadata": { |
202 | | - "collapsed": true |
203 | | - }, |
| 193 | + "metadata": {}, |
204 | 194 | "outputs": [], |
205 | 195 | "source": [ |
206 | 196 | "def attention_forward(model, enc_states, dec_state):\n", |
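The body of `attention_forward` is also outside the hunk. The usual computation, sketched here: broadcast the decoder state across time steps, score each encoder state, softmax over the time dimension, and return the weighted sum as the context vector `c`:

```python
import torch
import torch.nn.functional as F

def attention_forward(model, enc_states, dec_state):
    # enc_states: (seq_len, batch, num_hiddens); dec_state: (batch, num_hiddens)
    dec_states = dec_state.unsqueeze(dim=0).expand_as(enc_states)
    enc_and_dec_states = torch.cat((enc_states, dec_states), dim=2)
    e = model(enc_and_dec_states)            # (seq_len, batch, 1)
    alpha = F.softmax(e, dim=0)              # attention weights over time steps
    return (alpha * enc_states).sum(dim=0)   # context vector c: (batch, num_hiddens)
```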
|
250 | 240 | { |
251 | 241 | "cell_type": "code", |
252 | 242 | "execution_count": 10, |
253 | | - "metadata": { |
254 | | - "collapsed": true |
255 | | - }, |
| 243 | + "metadata": {}, |
256 | 244 | "outputs": [], |
257 | 245 | "source": [ |
258 | 246 | "class Decoder(nn.Module):\n", |
|
261 | 249 | " super(Decoder, self).__init__()\n", |
262 | 250 | " self.embedding = nn.Embedding(vocab_size, embed_size)\n", |
263 | 251 | " self.attention = attention_model(2*num_hiddens, attention_size)\n", |
264 | | - "        # The GRU input contains the attention output c and the actual input, so its size is 2*embed_size\n",
265 | | - " self.rnn = nn.GRU(2*embed_size, num_hiddens, num_layers, dropout=drop_prob)\n", |
| 252 | + "        # The GRU input contains the attention output c and the actual input, so its size is num_hiddens+embed_size\n",
| 253 | + " self.rnn = nn.GRU(num_hiddens + embed_size, num_hiddens, \n", |
| 254 | + " num_layers, dropout=drop_prob)\n", |
266 | 255 | " self.out = nn.Linear(num_hiddens, vocab_size)\n", |
267 | 256 | "\n", |
268 | 257 | " def forward(self, cur_input, state, enc_states):\n", |
|
272 | 261 | " \"\"\"\n", |
273 | 262 | " # 使用注意力机制计算背景向量\n", |
274 | 263 | " c = attention_forward(self.attention, enc_states, state[-1])\n", |
275 | | - "        # Concatenate the embedded input and the context vector along the feature dimension\n",
276 | | - "        input_and_c = torch.cat((self.embedding(cur_input), c), dim=1) # (batch size, 2*embed_size)\n",
| 264 | + "        # Concatenate the embedded input and the context vector along the feature dimension, (batch size, num_hiddens+embed_size)\n",
| 265 | + " input_and_c = torch.cat((self.embedding(cur_input), c), dim=1) \n", |
277 | 266 | "        # Add a time-step dimension, with 1 time step, to the concatenation of input and context vector\n",
278 | 267 | " output, state = self.rnn(input_and_c.unsqueeze(0), state)\n", |
279 | 268 | "        # Remove the time-step dimension; the output shape is (batch size, output vocab size)\n",
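The size fix above follows directly from the shapes: `attention_forward` returns `c` with `num_hiddens` features while the embedded input has `embed_size` features, so their concatenation has `num_hiddens + embed_size` features, not `2*embed_size`. A quick check with toy sizes (values assumed for illustration):

```python
import torch

batch_size, embed_size, num_hiddens = 2, 8, 16
c = torch.zeros(batch_size, num_hiddens)     # context vector from attention
emb = torch.zeros(batch_size, embed_size)    # embedded decoder input
input_and_c = torch.cat((emb, c), dim=1)
print(input_and_c.shape)                     # torch.Size([2, 24]) == embed_size + num_hiddens
```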
|
295 | 284 | { |
296 | 285 | "cell_type": "code", |
297 | 286 | "execution_count": 11, |
298 | | - "metadata": { |
299 | | - "collapsed": true |
300 | | - }, |
| 287 | + "metadata": {}, |
301 | 288 | "outputs": [], |
302 | 289 | "source": [ |
303 | 290 | "def batch_loss(encoder, decoder, X, Y, loss):\n", |
|
308 | 295 | " dec_state = decoder.begin_state(enc_state)\n", |
309 | 296 | "    # The decoder's input at the initial time step is BOS\n",
310 | 297 | " dec_input = torch.tensor([out_vocab.stoi[BOS]] * batch_size)\n", |
311 | | - "    # We use the mask variable to ignore the loss at positions where the label is the padding token PAD\n",
| 298 | + "    # We use the mask variable to ignore the loss at positions where the label is the padding token PAD; initialized to all ones\n",
312 | 299 | " mask, num_not_pad_tokens = torch.ones(batch_size,), 0\n", |
313 | 300 | " l = torch.tensor([0.0])\n", |
314 | 301 | " for y in Y.permute(1,0): # Y shape: (batch, seq_len)\n", |
315 | 302 | " dec_output, dec_state = decoder(dec_input, dec_state, enc_outputs)\n", |
316 | 303 | " l = l + (mask * loss(dec_output, y)).sum()\n", |
317 | 304 | "        dec_input = y # use teacher forcing\n",
318 | 305 | " num_not_pad_tokens += mask.sum().item()\n", |
319 | | - "        # Set the mask to 0 at PAD positions; the original book uses y != out_vocab.stoi[EOS] here, which seems wrong\n",
320 | | - " mask = mask * (y != out_vocab.stoi[PAD]).float()\n", |
| 306 | + "        # Everything after EOS is PAD. The line below ensures that once EOS is encountered, the mask stays 0 for all remaining iterations\n",
| 307 | + " mask = mask * (y != out_vocab.stoi[EOS]).float()\n", |
321 | 308 | " return l / num_not_pad_tokens" |
322 | 309 | ] |
323 | 310 | }, |
324 | 311 | { |
325 | 312 | "cell_type": "code", |
326 | 313 | "execution_count": 12, |
327 | | - "metadata": { |
328 | | - "collapsed": true |
329 | | - }, |
| 314 | + "metadata": {}, |
330 | 315 | "outputs": [], |
331 | 316 | "source": [ |
332 | 317 | "def train(encoder, decoder, dataset, lr, batch_size, num_epochs):\n", |
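The corrected mask update in the hunk above zeroes the mask one step *after* EOS appears as a label, so the EOS position itself is still scored while every following PAD position is ignored. A toy trace of how the mask evolves (the token ids are assumptions for illustration):

```python
import torch

# Suppose out_vocab maps EOS -> 1, PAD -> 0, and a label row is [5, 7, 1, 0, 0]
EOS_ID = 1
Y = torch.tensor([[5, 7, 1, 0, 0]])   # (batch=1, seq_len=5)
mask = torch.ones(1)
for y in Y.permute(1, 0):             # iterate over time steps
    print(mask.item())                # prints 1.0, 1.0, 1.0, 0.0, 0.0
    mask = mask * (y != EOS_ID).float()
# EOS is scored; the two PAD steps after it contribute nothing to the loss
```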
|
358 | 343 | "name": "stdout", |
359 | 344 | "output_type": "stream", |
360 | 345 | "text": [ |
361 | | - "epoch 10, loss 0.441\n", |
362 | | - "epoch 20, loss 0.183\n", |
363 | | - "epoch 30, loss 0.100\n", |
364 | | - "epoch 40, loss 0.046\n", |
365 | | - "epoch 50, loss 0.025\n" |
| 346 | + "epoch 10, loss 0.475\n", |
| 347 | + "epoch 20, loss 0.245\n", |
| 348 | + "epoch 30, loss 0.157\n", |
| 349 | + "epoch 40, loss 0.052\n", |
| 350 | + "epoch 50, loss 0.039\n" |
366 | 351 | ] |
367 | 352 | } |
368 | 353 | ], |
|
386 | 371 | { |
387 | 372 | "cell_type": "code", |
388 | 373 | "execution_count": 14, |
389 | | - "metadata": { |
390 | | - "collapsed": true |
391 | | - }, |
| 374 | + "metadata": {}, |
392 | 375 | "outputs": [], |
393 | 376 | "source": [ |
394 | 377 | "def translate(encoder, decoder, input_seq, max_seq_len):\n", |
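The body of `translate` is not shown in this hunk. The usual greedy decoder for this model encodes the padded source, starts from BOS, and feeds each argmax prediction back in until EOS or `max_seq_len`; a sketch, assuming `in_vocab`, `out_vocab`, `BOS`, `EOS`, and `PAD` are the notebook's globals:

```python
def translate(encoder, decoder, input_seq, max_seq_len):
    in_tokens = input_seq.split(' ')
    in_tokens += [EOS] + [PAD] * (max_seq_len - len(in_tokens) - 1)
    enc_input = torch.tensor([[in_vocab.stoi[tk] for tk in in_tokens]])  # batch=1
    enc_output, enc_state = encoder(enc_input, encoder.begin_state())
    dec_input = torch.tensor([out_vocab.stoi[BOS]])
    dec_state = decoder.begin_state(enc_state)
    output_tokens = []
    for _ in range(max_seq_len):
        dec_output, dec_state = decoder(dec_input, dec_state, enc_output)
        pred = dec_output.argmax(dim=1)
        pred_token = out_vocab.itos[int(pred.item())]
        if pred_token == EOS:        # stop at end-of-sequence
            break
        output_tokens.append(pred_token)
        dec_input = pred             # feed the prediction back in
    return output_tokens
```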
|
443 | 426 | { |
444 | 427 | "cell_type": "code", |
445 | 428 | "execution_count": 16, |
446 | | - "metadata": { |
447 | | - "collapsed": true |
448 | | - }, |
| 429 | + "metadata": {}, |
449 | 430 | "outputs": [], |
450 | 431 | "source": [ |
451 | 432 | "def bleu(pred_tokens, label_tokens, k):\n", |
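Only the signature of `bleu` appears here. The score this tutorial uses is exp(min(0, 1 - len_label/len_pred)) multiplied by the clipped n-gram precisions p_n raised to 0.5^n; a sketch of that computation:

```python
import math
import collections

def bleu(pred_tokens, label_tokens, k):
    len_pred, len_label = len(pred_tokens), len(label_tokens)
    score = math.exp(min(0, 1 - len_label / len_pred))   # brevity penalty
    for n in range(1, k + 1):
        num_matches, label_subs = 0, collections.defaultdict(int)
        for i in range(len_label - n + 1):
            label_subs[''.join(label_tokens[i: i + n])] += 1
        for i in range(len_pred - n + 1):
            # clipped n-gram matching against the label counts
            if label_subs[''.join(pred_tokens[i: i + n])] > 0:
                num_matches += 1
                label_subs[''.join(pred_tokens[i: i + n])] -= 1
        score *= math.pow(num_matches / (len_pred - n + 1), math.pow(0.5, n))
    return score
```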
|
466 | 447 | { |
467 | 448 | "cell_type": "code", |
468 | 449 | "execution_count": 17, |
469 | | - "metadata": { |
470 | | - "collapsed": true |
471 | | - }, |
| 450 | + "metadata": {}, |
472 | 451 | "outputs": [], |
473 | 452 | "source": [ |
474 | 453 | "def score(input_seq, label_seq, k):\n", |
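The `bleu 0.658` reported in the next hunk can be checked by hand for the prediction "they are exhausted ." against the label "they are canadian .": both are 4 tokens long, so the brevity penalty is 1; unigram precision is 3/4 (they, are, .) and bigram precision is 1/3 ("they are" only), giving (3/4)^0.5 * (1/3)^0.25 ≈ 0.658.

```python
# Hand check of the reported score for k=2
print((3/4) ** 0.5 * (1/3) ** 0.25)   # ≈ 0.6581
```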
|
504 | 483 | "name": "stdout", |
505 | 484 | "output_type": "stream", |
506 | 485 | "text": [ |
507 | | - "bleu 0.658, predict: they are russian .\n" |
| 486 | + "bleu 0.658, predict: they are exhausted .\n" |
508 | 487 | ] |
509 | 488 | } |
510 | 489 | ], |
511 | 490 | "source": [ |
512 | | - "score('ils sont canadiens .', 'they are canadian .', k=2)" |
| 491 | + "score('ils sont canadienne .', 'they are canadian .', k=2)" |
513 | 492 | ] |
514 | 493 | }, |
515 | 494 | { |
516 | 495 | "cell_type": "code", |
517 | 496 | "execution_count": null, |
518 | | - "metadata": { |
519 | | - "collapsed": true |
520 | | - }, |
| 497 | + "metadata": {}, |
521 | 498 | "outputs": [], |
522 | 499 | "source": [] |
523 | 500 | } |
524 | 501 | ], |
525 | 502 | "metadata": { |
526 | 503 | "kernelspec": { |
527 | | - "display_name": "Python [conda env:anaconda3]", |
| 504 | + "display_name": "Python [conda env:py36]", |
528 | 505 | "language": "python", |
529 | | - "name": "conda-env-anaconda3-py" |
| 506 | + "name": "conda-env-py36-py" |
530 | 507 | }, |
531 | 508 | "language_info": { |
532 | 509 | "codemirror_mode": { |
|
538 | 515 | "name": "python", |
539 | 516 | "nbconvert_exporter": "python", |
540 | 517 | "pygments_lexer": "ipython3", |
541 | | - "version": "3.6.8" |
| 518 | + "version": "3.6.2" |
542 | 519 | } |
543 | 520 | }, |
544 | 521 | "nbformat": 4, |
|