From e5e29af9cba5ab3c85968233a45dbbff598aa0b2 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 17 Jan 2021 16:48:40 -0800
Subject: [PATCH 1/8] fix pre_ln

---
 src/gluonnlp/layers.py             | 3 ++-
 src/gluonnlp/models/transformer.py | 3 ++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/gluonnlp/layers.py b/src/gluonnlp/layers.py
index ae94f0f07e..27f799540e 100644
--- a/src/gluonnlp/layers.py
+++ b/src/gluonnlp/layers.py
@@ -575,13 +575,14 @@ def forward(self, data):
         out :
             Shape (B, seq_length, C_out)
         """
+        residual = data
         if self._pre_norm:
             data = self.layer_norm(data)
         out = self.activation(self.ffn_1(data))
         out = self.activation_dropout_layer(out)
         out = self.ffn_2(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.layer_norm(out)
         return out
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 7f912e0c64..0539e0b3ac 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -255,6 +255,7 @@ def forward(self, data, attn_mask):
         attn_weight
             Shape (batch_size, seq_length, seq_length)
         """
+        residual = data
         if self._pre_norm:
             data = self.layer_norm(data)
         query, key, value = np.split(self.attn_qkv(data), 3, axis=-1)
@@ -264,7 +265,7 @@
         out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.layer_norm(out)
         out = self.ffn(out)

From b77b1e3aeefab3738b8d548e4246f49d99da7270 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 17 Jan 2021 17:02:40 -0800
Subject: [PATCH 2/8] update

---
 scripts/processing/clean_tok_mono_corpus.py | 2 +-
 scripts/processing/clean_tok_para_corpus.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/processing/clean_tok_mono_corpus.py b/scripts/processing/clean_tok_mono_corpus.py
index 79416b4798..3be34c372d 100644
--- a/scripts/processing/clean_tok_mono_corpus.py
+++ b/scripts/processing/clean_tok_mono_corpus.py
@@ -213,7 +213,7 @@ def get_parser():
     parser.add_argument('--discard-non-latin1', action='store_true',
                         help='Whether to discard the sentence pair if both sentences cannot be '
                              'encoded into latin1.')
-    parser.add_argument('--num-process', type=int, default=8,
+    parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
                         help='number of process')
     parser.add_argument('--overwrite', action='store_true')
diff --git a/scripts/processing/clean_tok_para_corpus.py b/scripts/processing/clean_tok_para_corpus.py
index c95c25e474..68b0b2fddc 100644
--- a/scripts/processing/clean_tok_para_corpus.py
+++ b/scripts/processing/clean_tok_para_corpus.py
@@ -261,7 +261,7 @@ def get_parser():
     parser.add_argument('--discard-non-latin1', action='store_true',
                         help='Whether to discard the sentence pair if both sentences cannot be '
                              'encoded into latin1.')
-    parser.add_argument('--num-process', type=int, default=8,
+    parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
                         help='number of process')
     parser.add_argument('--overwrite', action='store_true')
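Note on the `pre_ln` fix (patches 1, 3, and 4): with `pre_norm=True`, the old code overwrote `data` with its normalized value and then used that normalized tensor in the skip connection, so the residual branch effectively became `out + layer_norm(data)` instead of `out + data`. Caching the input in `residual` before the pre-norm restores the intended pre-LN block. Below is a minimal sketch of the pattern; `sublayer` and `layer_norm` are illustrative stand-ins for the real FFN/attention and LayerNorm blocks, not the library API.

```python
def residual_sublayer(data, sublayer, layer_norm, pre_norm=True):
    """Sketch of the residual handling restored by this patch series."""
    residual = data              # keep the un-normalized input for the skip connection
    if pre_norm:
        data = layer_norm(data)  # pre-LN: normalize before the sub-layer
    out = sublayer(data)         # FFN or attention (plus dropout) in the real code
    out = out + residual         # the old code added the normalized `data` here
    if not pre_norm:
        out = layer_norm(out)    # post-LN: normalize after the residual sum
    return out
```

The post-norm path is unaffected, since there `data` is never overwritten before the addition.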
From 8a031a55ef81841e51fcd004d2b0f1da52b40535 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 17 Jan 2021 17:14:30 -0800
Subject: [PATCH 3/8] fix

---
 src/gluonnlp/models/transformer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 0539e0b3ac..9ef49166cb 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -566,6 +566,7 @@ def forward(self, data, mem, self_causal_mask, mem_attn_mask):
             Shape (seq_length, batch_size, C_out)
         """
         # 1. Get the causal self-attention value
+        residual = data
         if self._pre_norm:
             data = self.ln_in(data)
         self_query, self_key, self_value = np.split(self.attn_in_qkv(data), 3, axis=-1)
@@ -576,11 +577,12 @@
                                                      self_causal_mask)
         out = self.proj_in(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_in(out)
         # 2. Attend to the contextual memory
         data = out
+        residual = data
         if self._pre_norm:
             data = self.ln_inter(data)
         out, [_, context_attn_weight] = self.inter_attention(
@@ -590,7 +592,7 @@
             mem_attn_mask)
         out = self.proj_inter(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_inter(out)
         # 3. Encode the output via an FFN layer

From 698c0a7840a8de779ab81ab6ff05ec019766bc62 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 17 Jan 2021 17:20:49 -0800
Subject: [PATCH 4/8] fix

---
 src/gluonnlp/models/transformer.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 9ef49166cb..8779f15aa8 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -684,6 +684,7 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
                 Shape (batch_size, prev_seq_length + 1, num_heads, C_value)
         """
+        residual = data
         if self._pre_norm:
             data = self.ln_in(data)
         if self.layout == 'NT':
@@ -711,11 +712,12 @@
         out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None)
         out = self.proj_in(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_in(out)
         # 2. Attend to the contextual memory
         data = out
+        residual = data
         if self._pre_norm:
             data = self.ln_inter(data)
         out, _ = self.inter_attention(npx.reshape(self.attn_inter_q(data),
@@ -727,7 +729,7 @@
                                       mem_attn_mask)
         out = self.proj_inter(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_inter(out)
         # 3. Encode the output via an FFN layer
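Note on the decoder changes (patches 3 and 4): the same fix has to be applied twice per decoder layer, because `data` is re-bound (and possibly normalized) between sub-layers, so the residual must be re-captured before the causal self-attention and again before the cross-attention over the encoder memory. A compressed sketch of that structure, with placeholder callables rather than the real attention cells, projections, and dropout:

```python
def decoder_sublayers(data, self_attn, cross_attn, ffn, ln_in, ln_inter, pre_norm=True):
    """Sketch of the per-sub-layer residual handling in the decoder layer."""
    # 1. causal self-attention
    residual = data
    if pre_norm:
        data = ln_in(data)
    data = self_attn(data) + residual
    if not pre_norm:
        data = ln_in(data)
    # 2. cross-attention over the encoder memory: re-capture the residual first
    residual = data
    if pre_norm:
        data = ln_inter(data)
    data = cross_attn(data) + residual
    if not pre_norm:
        data = ln_inter(data)
    # 3. position-wise FFN, which handles its own residual internally
    return ffn(data)
```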
From da6f3ff0c7f4cb6c6bcd1db15eccc7e4a6a42e20 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 17 Jan 2021 17:25:43 -0800
Subject: [PATCH 5/8] fix

---
 src/gluonnlp/models/transformer.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 8779f15aa8..26428a9f94 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -684,14 +684,14 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
                 Shape (batch_size, prev_seq_length + 1, num_heads, C_value)
         """
-        residual = data
-        if self._pre_norm:
-            data = self.ln_in(data)
         if self.layout == 'NT':
             time_axis = 1
         else:
             time_axis = 0
         data = np.expand_dims(data, axis=time_axis)
+        residual = data
+        if self._pre_norm:
+            data = self.ln_in(data)
         # Shape (B, prev_L, #Head, C_K), (B, prev_L, #Head, C_V)
         # or (prev_L, B, #Head, C_K), (prev_L, B, #Head, C_V)
         prev_key, prev_value = states

From caf2305ffc564980b124ba73b5aa6a1c4f6f6677 Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Sun, 17 Jan 2021 22:12:50 -0800
Subject: [PATCH 6/8] fix document

---
 src/gluonnlp/data/batchify.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/gluonnlp/data/batchify.py b/src/gluonnlp/data/batchify.py
index e854fe1670..8e1ffafb44 100644
--- a/src/gluonnlp/data/batchify.py
+++ b/src/gluonnlp/data/batchify.py
@@ -187,8 +187,8 @@ class Pad:
     val : float or int, default 0
         The padding value.
     axis : int, default 0
-        The axis to pad the arrays. The arrays will be padded to the largest dimension at
-        `axis`. For example, assume the input arrays have shape
+        The axis to pad the arrays. The arrays will be padded to the largest possible dimension,
+        and then stacked at `axis`. For example, assume the input arrays have shape
         (10, 8, 5), (6, 8, 5), (3, 8, 5) and the `axis` is 0. Each input will be padded
         into (10, 8, 5) and then stacked to form the final output, which has shape(3, 10, 8, 5).
     dtype : str or numpy.dtype, default None
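Note on the corrected `Pad` docstring (patch 6): it now describes pad-then-stack semantics. As a sanity check of the documented shapes, here is a small plain-NumPy sketch (not the batchify API itself) reproducing the `axis=0` example:

```python
import numpy as np

# Inputs from the docstring example: shapes (10, 8, 5), (6, 8, 5), (3, 8, 5).
arrays = [np.zeros((10, 8, 5)), np.zeros((6, 8, 5)), np.zeros((3, 8, 5))]

# Pad axis 0 of every array up to the largest observed length (10) with val=0 ...
max_len = max(a.shape[0] for a in arrays)
padded = [np.pad(a, [(0, max_len - a.shape[0])] + [(0, 0)] * (a.ndim - 1),
                 constant_values=0)
          for a in arrays]

# ... then stack the padded arrays into a single batch.
batch = np.stack(padded, axis=0)
print(batch.shape)  # (3, 10, 8, 5)
```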
From fb630ac14a7855f106cd281d22cca2259c47fc6f Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Mon, 18 Jan 2021 09:53:50 -0800
Subject: [PATCH 7/8] fix doc

---
 scripts/machine_translation/README.md | 1 +
 src/gluonnlp/layers.py                | 3 +--
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md
index f15766535e..e5ff11306d 100644
--- a/scripts/machine_translation/README.md
+++ b/scripts/machine_translation/README.md
@@ -254,6 +254,7 @@ horovodrun -np 4 -H localhost:4 python3 train_transformer.py \
     --warmup_steps 4000 \
     --warmup_init_lr 0.0 \
     --seed 123 \
+    --max_grad_norm 1.0 \
     --fp16
 ```
diff --git a/src/gluonnlp/layers.py b/src/gluonnlp/layers.py
index 27f799540e..20d6b4a75a 100644
--- a/src/gluonnlp/layers.py
+++ b/src/gluonnlp/layers.py
@@ -489,7 +489,7 @@ class PositionwiseFFN(HybridBlock):
     """The Position-wise FFN layer used in Transformer-like architectures

     If pre_norm is True:
-        norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
+        data -> norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
     Else:
         data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(res(+data))
     """
@@ -566,7 +566,6 @@ def forward(self, data):
         Parameters
         ----------
-        F
         data :
             Shape (B, seq_length, C_in)

From 3c7c4c1b34ffc9404e7f457ad129f6a3a7424baa Mon Sep 17 00:00:00 2001
From: Xingjian Shi
Date: Tue, 19 Jan 2021 09:35:08 -0800
Subject: [PATCH 8/8] Update transformer.py

---
 src/gluonnlp/models/transformer.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 26428a9f94..0cd18ba1d8 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -125,6 +125,7 @@ def transformer_wmt_en_de_big_t2t():
     cfg.defrost()
     cfg.MODEL.attention_dropout = 0.1
     cfg.MODEL.activation_dropout = 0.1
+    cfg.MODEL.dropout = 0.1
     cfg.MODEL.ENCODER.pre_norm = True
     cfg.MODEL.DECODER.pre_norm = True
     cfg.freeze()
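Note on patch 8: the big T2T preset now sets all three dropout rates explicitly. A quick sanity-check sketch of the resulting config, assuming the factory is importable from `gluonnlp.models.transformer` as defined in this file and returns the frozen `cfg` built above:

```python
from gluonnlp.models.transformer import transformer_wmt_en_de_big_t2t

cfg = transformer_wmt_en_de_big_t2t()
assert cfg.MODEL.dropout == 0.1              # newly set by this patch
assert cfg.MODEL.attention_dropout == 0.1
assert cfg.MODEL.activation_dropout == 0.1
assert cfg.MODEL.ENCODER.pre_norm and cfg.MODEL.DECODER.pre_norm
```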