diff --git a/scripts/machine_translation/README.md b/scripts/machine_translation/README.md
index f15766535e..e5ff11306d 100644
--- a/scripts/machine_translation/README.md
+++ b/scripts/machine_translation/README.md
@@ -254,6 +254,7 @@ horovodrun -np 4 -H localhost:4 python3 train_transformer.py \
     --warmup_steps 4000 \
     --warmup_init_lr 0.0 \
     --seed 123 \
+    --max_grad_norm 1.0 \
     --fp16
 ```
diff --git a/scripts/processing/clean_tok_mono_corpus.py b/scripts/processing/clean_tok_mono_corpus.py
index 79416b4798..3be34c372d 100644
--- a/scripts/processing/clean_tok_mono_corpus.py
+++ b/scripts/processing/clean_tok_mono_corpus.py
@@ -213,7 +213,7 @@ def get_parser():
     parser.add_argument('--discard-non-latin1', action='store_true',
                         help='Whether to discard the sentence pair if both sentences cannot be '
                              'encoded into latin1.')
-    parser.add_argument('--num-process', type=int, default=8,
+    parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
                         help='number of process')
     parser.add_argument('--overwrite', action='store_true')
diff --git a/scripts/processing/clean_tok_para_corpus.py b/scripts/processing/clean_tok_para_corpus.py
index c95c25e474..68b0b2fddc 100644
--- a/scripts/processing/clean_tok_para_corpus.py
+++ b/scripts/processing/clean_tok_para_corpus.py
@@ -261,7 +261,7 @@ def get_parser():
     parser.add_argument('--discard-non-latin1', action='store_true',
                         help='Whether to discard the sentence pair if both sentences cannot be '
                              'encoded into latin1.')
-    parser.add_argument('--num-process', type=int, default=8,
+    parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
                         help='number of process')
     parser.add_argument('--overwrite', action='store_true')
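The README change adds `--max_grad_norm 1.0` to the documented Horovod training command, which presumably caps the global gradient norm at 1.0 during training. The two `--num-process` changes replace the hard-coded default of 8 workers with the machine's core count; this assumes `multiprocessing` is already imported by both cleaning scripts. A minimal stand-alone sketch of the new default (the parser below is hypothetical, for illustration only):

```python
import argparse
import multiprocessing

# Hypothetical parser illustrating the new default: one worker per CPU core.
parser = argparse.ArgumentParser()
parser.add_argument('--num-process', type=int, default=multiprocessing.cpu_count(),
                    help='number of worker processes')

args = parser.parse_args([])   # no CLI override
print(args.num_process)        # equals multiprocessing.cpu_count() on this machine
```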
diff --git a/src/gluonnlp/data/batchify.py b/src/gluonnlp/data/batchify.py
index e854fe1670..8e1ffafb44 100644
--- a/src/gluonnlp/data/batchify.py
+++ b/src/gluonnlp/data/batchify.py
@@ -187,8 +187,8 @@ class Pad:
     val : float or int, default 0
         The padding value.
     axis : int, default 0
-        The axis to pad the arrays. The arrays will be padded to the largest dimension at
-        `axis`. For example, assume the input arrays have shape
+        The axis to pad the arrays. The arrays will be padded to the largest possible dimension,
+        and then stacked at `axis`. For example, assume the input arrays have shape
         (10, 8, 5), (6, 8, 5), (3, 8, 5) and the `axis` is 0. Each input will be padded into
         (10, 8, 5) and then stacked to form the final output, which has shape(3, 10, 8, 5).
     dtype : str or numpy.dtype, default None
diff --git a/src/gluonnlp/layers.py b/src/gluonnlp/layers.py
index ae94f0f07e..20d6b4a75a 100644
--- a/src/gluonnlp/layers.py
+++ b/src/gluonnlp/layers.py
@@ -489,7 +489,7 @@ class PositionwiseFFN(HybridBlock):
     """The Position-wise FFN layer used in Transformer-like architectures

     If pre_norm is True:
-        norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
+        data -> norm(data) -> fc1 -> act -> act_dropout -> fc2 -> dropout -> res(+data)
     Else:
         data -> fc1 -> act -> act_dropout -> fc2 -> dropout -> norm(res(+data))
     """
@@ -566,7 +566,6 @@ def forward(self, data):

         Parameters
         ----------
-        F
         data :
             Shape (B, seq_length, C_in)

@@ -575,13 +574,14 @@
         out :
             Shape (B, seq_length, C_out)
         """
+        residual = data
         if self._pre_norm:
             data = self.layer_norm(data)
         out = self.activation(self.ffn_1(data))
         out = self.activation_dropout_layer(out)
         out = self.ffn_2(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.layer_norm(out)
         return out
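The reworded `Pad` docstring describes padding every array up to the largest observed shape and then stacking the results along a new batch axis. A NumPy-only sketch of that behaviour; `pad_and_stack` is a hypothetical stand-in, not the `gluonnlp.data.batchify.Pad` implementation:

```python
import numpy as np

def pad_and_stack(arrays, val=0):
    """Pad each array up to the element-wise largest shape, then stack
    along a new leading batch axis (mirrors the docstring example)."""
    target = tuple(max(a.shape[d] for a in arrays) for d in range(arrays[0].ndim))
    padded = [np.pad(a, [(0, t - s) for s, t in zip(a.shape, target)],
                     constant_values=val) for a in arrays]
    return np.stack(padded)

batch = pad_and_stack([np.ones((10, 8, 5)), np.ones((6, 8, 5)), np.ones((3, 8, 5))])
print(batch.shape)  # (3, 10, 8, 5), as in the docstring example
```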
diff --git a/src/gluonnlp/models/transformer.py b/src/gluonnlp/models/transformer.py
index 7f912e0c64..0cd18ba1d8 100644
--- a/src/gluonnlp/models/transformer.py
+++ b/src/gluonnlp/models/transformer.py
@@ -125,6 +125,7 @@ def transformer_wmt_en_de_big_t2t():
     cfg.defrost()
     cfg.MODEL.attention_dropout = 0.1
     cfg.MODEL.activation_dropout = 0.1
+    cfg.MODEL.dropout = 0.1
     cfg.MODEL.ENCODER.pre_norm = True
     cfg.MODEL.DECODER.pre_norm = True
     cfg.freeze()
@@ -255,6 +256,7 @@ def forward(self, data, attn_mask):
         attn_weight
             Shape (batch_size, seq_length, seq_length)
         """
+        residual = data
         if self._pre_norm:
             data = self.layer_norm(data)
         query, key, value = np.split(self.attn_qkv(data), 3, axis=-1)
@@ -264,7 +266,7 @@
         out, [_, attn_weight] = self.attention_cell(query, key, value, attn_mask)
         out = self.attention_proj(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.layer_norm(out)
         out = self.ffn(out)
@@ -565,6 +567,7 @@ def forward(self, data, mem, self_causal_mask, mem_attn_mask):
             Shape (seq_length, batch_size, C_out)
         """
         # 1. Get the causal self-attention value
+        residual = data
         if self._pre_norm:
             data = self.ln_in(data)
         self_query, self_key, self_value = np.split(self.attn_in_qkv(data), 3, axis=-1)
@@ -575,11 +578,12 @@
                                            self_causal_mask)
         out = self.proj_in(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_in(out)
         # 2. Attend to the contextual memory
         data = out
+        residual = data
         if self._pre_norm:
             data = self.ln_inter(data)
         out, [_, context_attn_weight] = self.inter_attention(
@@ -589,7 +593,7 @@
                                         mem_attn_mask)
         out = self.proj_inter(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_inter(out)
         # 3. Encode the output via an FFN layer
@@ -681,13 +685,14 @@ def incremental_decode(self, data, states, mem, mem_valid_length, mem_attn_mask=
             Shape (batch_size, prev_seq_length + 1, num_heads, C_value)
         """
-        if self._pre_norm:
-            data = self.ln_in(data)
         if self.layout == 'NT':
             time_axis = 1
         else:
             time_axis = 0
         data = np.expand_dims(data, axis=time_axis)
+        residual = data
+        if self._pre_norm:
+            data = self.ln_in(data)
         # Shape (B, prev_L, #Head, C_K), (B, prev_L, #Head, C_V)
         # or (prev_L, B, #Head, C_K), (prev_L, B, #Head, C_V)
         prev_key, prev_value = states
@@ -708,11 +713,12 @@
         out, [_, attn_weight] = self.self_attention(step_query, new_key, new_value, None)
         out = self.proj_in(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_in(out)
         # 2. Attend to the contextual memory
         data = out
+        residual = data
         if self._pre_norm:
             data = self.ln_inter(data)
         out, _ = self.inter_attention(npx.reshape(self.attn_inter_q(data),
@@ -724,7 +730,7 @@
                                       mem_attn_mask)
         out = self.proj_inter(out)
         out = self.dropout_layer(out)
-        out = out + data
+        out = out + residual
         if not self._pre_norm:
             out = self.ln_inter(out)
         # 3. Encode the output via an FFN layer
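The residual-connection edits in `layers.py` and `transformer.py` all apply the same fix: in pre-norm blocks the residual must be captured before layer normalization (and, in `incremental_decode`, only after the time-axis expansion), otherwise the skip connection adds the normalized tensor rather than the block's original input. A minimal NumPy sketch of the corrected ordering; `layer_norm`, `toy_sublayer`, and the block functions below are stand-ins, not the gluonnlp modules:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Stand-in for the block's LayerNorm.
    return (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + eps)

def pre_norm_block(data, sublayer):
    """Corrected pre-norm ordering used throughout the diff."""
    residual = data                   # keep the unnormalized input
    out = sublayer(layer_norm(data))  # norm -> attention/FFN sub-layer
    return out + residual             # add the original input back

def buggy_pre_norm_block(data, sublayer):
    """Old ordering: the residual was taken after layer_norm."""
    data = layer_norm(data)
    return sublayer(data) + data

def toy_sublayer(t):
    # Toy stand-in for the attention or FFN sub-layer.
    return 0.5 * t

x = np.random.randn(2, 4, 8)
print(np.allclose(pre_norm_block(x, toy_sublayer),
                  buggy_pre_norm_block(x, toy_sublayer)))  # False: the outputs differ
```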