@@ -1,13 +1,17 @@
 import sys
+
 import paddle.v2 as paddle


-def seqToseq_net(source_dict_dim, target_dict_dim):
+def seqToseq_net(source_dict_dim, target_dict_dim, is_generating=False):
     ### Network Architecture
     word_vector_dim = 512  # dimension of word vector
     decoder_size = 512  # dimension of hidden unit in GRU Decoder network
     encoder_size = 512  # dimension of hidden unit in GRU Encoder network

+    beam_size = 3  # expansion width of each beam search step
+    max_length = 250  # an upper bound on the length of the generated sequence
+
     #### Encoder
     src_word_id = paddle.layer.data(
         name='source_language_word',
@@ -67,30 +71,57 @@ def gru_decoder_with_attention(enc_vec, enc_proj, current_word):
     group_input2 = paddle.layer.StaticInputV2(input=encoded_proj, is_seq=True)
     group_inputs = [group_input1, group_input2]

-    trg_embedding = paddle.layer.embedding(
-        input=paddle.layer.data(
-            name='target_language_word',
-            type=paddle.data_type.integer_value_sequence(target_dict_dim)),
-        size=word_vector_dim,
-        param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
-    group_inputs.append(trg_embedding)
-
-    # For decoder equipped with attention mechanism, in training,
-    # target embeding (the groudtruth) is the data input,
-    # while encoded source sequence is accessed to as an unbounded memory.
-    # Here, the StaticInput defines a read-only memory
-    # for the recurrent_group.
-    decoder = paddle.layer.recurrent_group(
-        name=decoder_group_name,
-        step=gru_decoder_with_attention,
-        input=group_inputs)
-
-    lbl = paddle.layer.data(
-        name='target_language_next_word',
-        type=paddle.data_type.integer_value_sequence(target_dict_dim))
-    cost = paddle.layer.classification_cost(input=decoder, label=lbl)
-
-    return cost
+    if not is_generating:
+        trg_embedding = paddle.layer.embedding(
+            input=paddle.layer.data(
+                name='target_language_word',
+                type=paddle.data_type.integer_value_sequence(target_dict_dim)),
+            size=word_vector_dim,
+            param_attr=paddle.attr.ParamAttr(name='_target_language_embedding'))
+        group_inputs.append(trg_embedding)
+
+        # For a decoder equipped with an attention mechanism, in training
+        # the target embedding (the ground truth) is the data input,
+        # while the encoded source sequence serves as an unbounded memory.
+        # Here, StaticInput defines a read-only memory
+        # for the recurrent_group.
+        decoder = paddle.layer.recurrent_group(
+            name=decoder_group_name,
+            step=gru_decoder_with_attention,
+            input=group_inputs)
+
+        lbl = paddle.layer.data(
+            name='target_language_next_word',
+            type=paddle.data_type.integer_value_sequence(target_dict_dim))
+        cost = paddle.layer.classification_cost(input=decoder, label=lbl)
+
+        return cost
+    else:
+        # In generation, the decoder predicts the next target word based on
+        # the encoded source sequence and the last generated target word.
+
+        # The encoded source sequence (the encoder's output) must be
+        # specified by StaticInput, which acts as a read-only memory.
+        # The embedding of the last generated word is fetched automatically
+        # by GeneratedInput, which is initialized with a start mark such as
+        # <s> and must be included in generation.
+
+        trg_embedding = paddle.layer.GeneratedInputV2(
+            size=target_dict_dim,
+            embedding_name='_target_language_embedding',
+            embedding_size=word_vector_dim)
+        group_inputs.append(trg_embedding)
+
+        beam_gen = paddle.layer.beam_search(
+            name=decoder_group_name,
+            step=gru_decoder_with_attention,
+            input=group_inputs,
+            bos_id=0,
+            eos_id=1,
+            beam_size=beam_size,
+            max_length=max_length)
+
+        return beam_gen


 def main():
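
For context, a minimal sketch of how the new `is_generating=True` path might be driven at inference time with the PaddlePaddle v2 API. The vocabulary sizes, the word ids in `gen_data`, and the randomly created `parameters` are placeholders for illustration (in practice the parameters come from a training run or a saved model), and the snippet assumes it runs in the same module as `seqToseq_net`:

```python
import paddle.v2 as paddle

paddle.init(use_gpu=False, trainer_count=1)

# Assumed vocabulary sizes; the real ones come from the dataset dictionaries.
source_dict_dim = target_dict_dim = 30000

# Build the generation topology. It reuses the parameter names of the
# training topology (e.g. '_target_language_embedding'), so weights
# learned in training can be loaded into it directly.
beam_gen = seqToseq_net(source_dict_dim, target_dict_dim, is_generating=True)

# Randomly initialized here only to keep the sketch self-contained;
# a real run would load the parameters produced by training instead.
parameters = paddle.parameters.create(beam_gen)

# One source sentence as a sequence of hypothetical word ids.
gen_data = [([0, 12, 57, 467, 1],)]

# 'prob' holds the per-candidate scores and 'id' the generated word ids;
# beam search returns up to beam_size candidates per source sentence.
beam_result = paddle.infer(
    output_layer=beam_gen,
    parameters=parameters,
    input=gen_data,
    field=['prob', 'id'])
```

Note that `bos_id=0` and `eos_id=1` in the `beam_search` call assume the start mark `<s>` and end mark `<e>` sit at indices 0 and 1 of the target dictionary, matching the convention of the wmt14 dataset reader in PaddlePaddle v2.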