From dce05e114a75b2128959258e9dc7313dd0e5afc7 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Tue, 31 Dec 2019 07:46:52 +0000
Subject: [PATCH 01/26] Add machine translation estimator

---
 src/gluonnlp/estimator/__init__.py               | 25 ++++++++
 .../machine_translation_batch_processor.py       | 60 +++++++++++++++++++
 .../machine_translation_estimator.py             | 50 ++++++++++++++++
 .../machine_translation_event_handler.py         | 59 ++++++++++++++++++
 4 files changed, 194 insertions(+)
 create mode 100644 src/gluonnlp/estimator/__init__.py
 create mode 100644 src/gluonnlp/estimator/machine_translation_batch_processor.py
 create mode 100644 src/gluonnlp/estimator/machine_translation_estimator.py
 create mode 100644 src/gluonnlp/estimator/machine_translation_event_handler.py

diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py
new file mode 100644
index 0000000000..3dca823d2c
--- /dev/null
+++ b/src/gluonnlp/estimator/__init__.py
@@ -0,0 +1,25 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=eval-used, redefined-outer-name
+
+""" Gluon NLP Estimator Module """
+from .machine_translation_estimator import *
+from .machine_translation_event_handler import *
+from .machine_translation_batch_processor import *
+
+__all__ = (machine_translation_estimator.__all__ + machine_translation_event_handler.__all__
+           + machine_translation_batch_processor.__all__)
diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py
new file mode 100644
index 0000000000..ac69c3d190
--- /dev/null
+++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
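+#
+# A minimal usage sketch for the processor defined below (the surrounding
+# names -- parallel_model, loss_fn, ctx and train_loader -- are placeholders
+# from a training script, not part of this module):
+#
+#     processor = ParallelMachineTranslationBatchProcessor(rescale_loss=100,
+#                                                          batch_size=1024)
+#     estimator = MachineTranslationEstimator(net=parallel_model, loss=loss_fn,
+#                                             context=ctx,
+#                                             batch_processor=processor)
+#     estimator.fit(train_data=train_loader, epochs=10)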
+# pylint: disable=eval-used, redefined-outer-name
+""" Gluon Machine Translation Batch Processor """
+
+import numpy as np
+import mxnet as mx
+from mxnet.gluon.contrib.estimator import BatchProcessor
+from mxnet.gluon.utils import split_and_load
+
+__all__ = ['ParallelMachineTranslationBatchProcessor']
+
+class ParallelMachineTranslationBatchProcessor(BatchProcessor):
+    def __init__(self, rescale_loss=100, batch_size=1024):
+        self.rescale_loss = rescale_loss
+        self.batch_size = batch_size
+
+    def fit_batch(self, estimator, train_batch, batch_axis=0):
+        data = [shard[0] for shard in train_batch]
+        target = [shard[1] for shard in train_batch]
+        src_word_count, tgt_word_count, bs = np.sum([(shard[2].sum(),
+            shard[3].sum(), shard[0].shape[0]) for shard in train_batch],
+            axis=0)
+        estimator.tgt_valid_length = tgt_word_count.asscalar() - bs
+        seqs = [[seq.as_in_context(context) for seq in shard]
+                for context, shard in zip(estimator.context, train_batch)]
+        Ls = []
+        for seq in seqs:
+            estimator.net.put((seq, self.batch_size))
+        Ls = [self.estimator.get() for _ in range(len(estimator.context))]
+        return data, target, None, Ls
+
+    def evaluate_batch(self, estimator, val_batch, batch_axis=0):
+        ctx = estimator.context[0]
+        src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = val_batch
+        src_seq = src_seq.as_in_context(ctx)
+        tgt_seq = tgt_seq.as_in_context(ctx)
+        src_valid_length = src_valid_length.as_in_context(ctx)
+        tgt_valid_length = tgt_valid_length.as_in_context(ctx)
+
+        out, _ = self.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
+        loss = self.evaluation_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
+        inst_ids = inst_ids.asnumpy().astype(np.int32).tolist()
+        loss *= (tgt_seq.shape[1] - 1)
+        estimator.val_tgt_valid_length = tgt_seq.shape[1] - 1
+        return src_seq, tgt_seq, out, loss
diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py
new file mode 100644
index 0000000000..fa6a12dada
--- /dev/null
+++ b/src/gluonnlp/estimator/machine_translation_estimator.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
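+#
+# The estimator defined below only layers machine-translation state on top of
+# mxnet.gluon.contrib.estimator.Estimator; the event handlers in this package
+# read and write that state, along the lines of (sketch, attribute names as
+# defined below):
+#
+#     def batch_end(self, estimator, *args, **kwargs):
+#         # accumulate target tokens seen since the last trainer.step()
+#         self.loss_denom += estimator.tgt_valid_length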
+# pylint: disable=eval-used, redefined-outer-name
+""" Gluon Machine Translation Estimator """
+
+import copy
+import warnings
+
+import numpy as np
+import mxnet as mx
+from mxnet.gluon.contrib.estimator import Estimator
+from .machine_translation_batch_processor import MachineTranslationBatchProcessor
+
+__all__ = ['MachineTranslationEstimator']
+
+class MachineTranslationEstimator(Estimator):
+    def __init__(self, net, loss,
+                 train_metrics=None,
+                 initializer=None,
+                 trainer=None,
+                 context=None,
+                 evaluation_loss=None,
+                 eval_net=None,
+                 batch_processor=MachineTranslationBatchProcessor()):
+        super().__init__(net=net, loss=loss,
+                         train_metrics=train_metrics,
+                         val_metrics=val_metrics,
+                         initializer=initializer,
+                         trainer=trainer,
+                         context=context,
+                         evaluation_loss=evaluation_loss,
+                         eval_net=eval_net,
+                         batch_processor=batch_processor)
+        self.tgt_valid_length = 0
+        self.val_tgt_valid_length = 0
+        self.avg_param = None
diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py
new file mode 100644
index 0000000000..13e0021303
--- /dev/null
+++ b/src/gluonnlp/estimator/machine_translation_event_handler.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=eval-used, redefined-outer-name
+""" Gluon Machine Translation Event Handler """
+
+import copy
+import warnings
+
+import mxnet as mx
+from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin
+from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnc
+from mxnet.gluon.contrib.estimator import GradientUpdateHandler
+from mxnet.gluon.contrib.estimator import MetricHandler
+
+__all__ = ['AvgParamUpdateHandler']
+
+class AvgParamUpdateHandler(BatchEnd, EpochEnd):
+    def __init__(self, avg_start, grad_interval=1):
+        self.batch_id = 0
+        self.grad_interval = grad_interval
+        self.step_num = 0
+        self.avg_start = avg_start
+
+    def _update_avg_param(self, estimator):
+        if estimator.avg_param is None:
+            # estimator.net is parallel model estimator.net._model is the model
+            # to be investigated on
+            estimator.avg_param = {k:v.data(estimator.context[0]).copy() for k, v in
+                                   estimator.net._model.collect_params().items()}
+        if self.step_num > self.avg_start:
+            params = estimator.net._model.collect_params()
+            alpha = 1. / max(1, self.step_num - self.avg_start)
+            for key, val in estimator.avg_param.items():
+                estimator.avg_param[:] += alpha *
+                                          (params[key].data(estimator.context[0]) -
+                                          val)
+
+    def batch_end(self, estimator, *args, **kwargs):
+        if self.batch_id % self.grad_interval == 0:
+            self.step_num += 1
+        if self.batch_id % self.grad_interval == self.grad_interval - 1:
+            _update_avg_param(estimator)
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        _update_avg_param(estimator)

From 47889852047f2a9e4d6af0e645101370425f6cd0 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Thu, 9 Jan 2020 09:48:10 +0000
Subject: [PATCH 02/26] Add some files to machine translation estimator

---
 .../train_transformer_estimator.py               | 215 ++++++++++++++++++
 src/gluonnlp/__init__.py                         |   4 +-
 src/gluonnlp/estimator/__init__.py               |   4 +-
 .../estimator/length_normalized_loss.py          |  71 ++++++
 .../machine_translation_event_handler.py         | 131 ++++++++++-
 5 files changed, 420 insertions(+), 5 deletions(-)
 create mode 100644 scripts/machine_translation/train_transformer_estimator.py
 create mode 100644 src/gluonnlp/estimator/length_normalized_loss.py

diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
new file mode 100644
index 0000000000..65c099e3d2
--- /dev/null
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -0,0 +1,215 @@
+"""
+Transformer
+=================================
+
+This example shows how to implement the Transformer model with Gluon NLP Toolkit.
+
+@inproceedings{vaswani2017attention,
+  title={Attention is all you need},
+  author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones,
+          Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia},
+  booktitle={Advances in Neural Information Processing Systems},
+  pages={6000--6010},
+  year={2017}
+}
+"""
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
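+#
+# A typical invocation (the flags match the argparse options declared below;
+# the dataset choice and GPU id are illustrative):
+#
+#     python train_transformer_estimator.py --dataset WMT2016BPE \
+#         --src_lang en --tgt_lang de --batch_size 2700 \
+#         --num_accumulated 16 --warmup_steps 4000 --gpus 0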
+# pylint:disable=redefined-outer-name,logging-format-interpolation
+
+import argparse
+import logging
+import math
+import os
+import random
+import time
+
+import numpy as np
+import mxnet as mx
+from mxnet import gluon
+
+import gluonnlp as nlp
+from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss
+from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder
+from gluonnlp.model.translation import NMTModel
+from gluonnlp.utils.parallel import Parallel
+import dataprocessor
+from bleu import _bpe_to_words, compute_bleu
+from translation import BeamSearchTranslator
+from utils import logging_config
+from gluonnlp.estimator import MachineTranslationEstimator
+from gluonnlp.estimator import LengthNormalizedLoss
+
+np.random.seed(100)
+random.seed(100)
+mx.random.seed(10000)
+
+nlp.utils.check_version('0.9.0')
+
+parser = argparse.ArgumentParser(
+    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    description='Neural Machine Translation Example with the Transformer Model.')
+parser.add_argument('--dataset', type=str.upper, default='WMT2016BPE', help='Dataset to use.',
+                    choices=['IWSLT2015', 'WMT2016BPE', 'WMT2014BPE', 'TOY'])
+parser.add_argument('--src_lang', type=str, default='en', help='Source language')
+parser.add_argument('--tgt_lang', type=str, default='de', help='Target language')
+parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit')
+parser.add_argument('--num_units', type=int, default=512, help='Dimension of the embedding '
+                                                               'vectors and states.')
+parser.add_argument('--hidden_size', type=int, default=2048,
+                    help='Dimension of the hidden state in position-wise feed-forward networks.')
+parser.add_argument('--dropout', type=float, default=0.1,
+                    help='dropout applied to layers (0 = no dropout)')
+parser.add_argument('--epsilon', type=float, default=0.1,
+                    help='epsilon parameter for label smoothing')
+parser.add_argument('--num_layers', type=int, default=6,
+                    help='number of layers in the encoder and decoder')
+parser.add_argument('--num_heads', type=int, default=8,
+                    help='number of heads in multi-head attention')
+parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention')
+parser.add_argument('--batch_size', type=int, default=1024,
+                    help='Batch size. Number of tokens per gpu in a minibatch')
+parser.add_argument('--beam_size', type=int, default=4, help='Beam size')
+parser.add_argument('--lp_alpha', type=float, default=0.6,
+                    help='Alpha used in calculating the length penalty')
+parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty')
+parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size')
+parser.add_argument('--num_buckets', type=int, default=10, help='Bucket number')
+parser.add_argument('--bucket_scheme', type=str, default='constant',
+                    help='Strategy for generating bucket keys. It supports: '
+                         '"constant": all the buckets have the same width; '
+                         '"linear": the width of bucket increases linearly; '
+                         '"exp": the width of bucket increases exponentially')
+parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the '
+                                                                    'throughput of the bucketing')
+parser.add_argument('--src_max_len', type=int, default=-1, help='Maximum length of the source '
+                                                                'sentence, -1 means no clipping')
+parser.add_argument('--tgt_max_len', type=int, default=-1, help='Maximum length of the target '
+                                                                'sentence, -1 means no clipping')
+parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm')
+parser.add_argument('--lr', type=float, default=1.0, help='Initial learning rate')
+parser.add_argument('--warmup_steps', type=float, default=4000,
+                    help='number of warmup steps used in NOAM\'s stepsize schedule')
+parser.add_argument('--num_accumulated', type=int, default=1,
+                    help='Number of steps to accumulate the gradients. '
+                         'This is useful to mimic large batch training with limited gpu memory')
+parser.add_argument('--magnitude', type=float, default=3.0,
+                    help='Magnitude of Xavier initialization')
+parser.add_argument('--average_checkpoint', action='store_true',
+                    help='Turn on to perform final testing based on '
+                         'the average of last few checkpoints')
+parser.add_argument('--num_averages', type=int, default=5,
+                    help='Perform final testing based on the '
+                         'average of last num_averages checkpoints. '
+                         'This is only used if average_checkpoint is True')
+parser.add_argument('--average_start', type=int, default=5,
+                    help='Perform average SGD on last average_start epochs')
+parser.add_argument('--full', action='store_true',
+                    help='By default, we use the test dataset in'
+                         ' http://statmt.org/wmt14/test-filtered.tgz.'
+                         ' When the option full is turned on, we use the test dataset in'
+                         ' http://statmt.org/wmt14/test-full.tgz')
+parser.add_argument('--bleu', type=str, default='tweaked',
+                    help='Schemes for computing bleu score. It can be: '
+                         '"tweaked": it uses similar steps in get_ende_bleu.sh in tensor2tensor '
+                         'repository, where compound words are put in ATAT format; '
+                         '"13a": This uses official WMT tokenization and produces the same results'
+                         ' as official script (mteval-v13a.pl) used by WMT; '
+                         '"intl": This uses international tokenization in mteval-v14a.pl')
+parser.add_argument('--log_interval', type=int, default=100, metavar='N',
+                    help='report interval')
+parser.add_argument('--save_dir', type=str, default='transformer_out',
+                    help='directory path to save the final model and training log')
+parser.add_argument('--gpus', type=str,
+                    help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.'
+                         '(using single gpu is suggested)')
+args = parser.parse_args()
+logging_config(args.save_dir)
+logging.info(args)
+
+
+data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab \
+    = dataprocessor.load_translation_data(dataset=args.dataset, bleu=args.bleu, args=args)
+
+dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt'))
+dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt'))
+
+data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False)
+data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
+                                     for i, ele in enumerate(data_val)])
+data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i)
+                                      for i, ele in enumerate(data_test)])
+
+ctx = [mx.cpu()] if args.gpus is None or args.gpus == '' else \
+    [mx.gpu(int(x)) for x in args.gpus.split(',')]
+num_ctxs = len(ctx)
+
+data_train_lengths, data_val_lengths, data_test_lengths = [dataprocessor.get_data_lengths(x)
+                                                           for x in
+                                                           [data_train, data_val, data_test]]
+
+if args.src_max_len <= 0 or args.tgt_max_len <= 0:
+    max_len = np.max(
+        [np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0),
+         np.max(data_test_lengths, axis=0)],
+        axis=0)
+if args.src_max_len > 0:
+    src_max_len = args.src_max_len
+else:
+    src_max_len = max_len[0]
+if args.tgt_max_len > 0:
+    tgt_max_len = args.tgt_max_len
+else:
+    tgt_max_len = max_len[1]
+encoder, decoder, one_step_ahead_decoder = get_transformer_encoder_decoder(
+    units=args.num_units, hidden_size=args.hidden_size, dropout=args.dropout,
+    num_layers=args.num_layers, num_heads=args.num_heads, max_src_length=max(src_max_len, 500),
+    max_tgt_length=max(tgt_max_len, 500), scaled=args.scaled)
+model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder,
+                 one_step_ahead_decoder=one_step_ahead_decoder,
+                 share_embed=args.dataset not in ('TOY', 'IWSLT2015'), embed_size=args.num_units,
+                 tie_weights=args.dataset not in ('TOY', 'IWSLT2015'), embed_initializer=None,
+                 prefix='transformer_')
+model.initialize(init=mx.init.Xavier(magnitude=args.magnitude), ctx=ctx)
+static_alloc = True
+model.hybridize(static_alloc=static_alloc)
+logging.info(model)
+
+translator = BeamSearchTranslator(model=model, beam_size=args.beam_size,
+                                  scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha,
+                                                                    K=args.lp_k),
+                                  max_length=200)
+logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k))
+
+label_smoothing = LabelSmoothing(epsilon=args.epsilon, units=len(tgt_vocab))
+label_smoothing.hybridize(static_alloc=static_alloc)
+
+loss_function = MaskedSoftmaxCELoss(sparse_label=False)
+loss_function.hybridize(static_alloc=static_alloc)
+
+test_loss_function = MaskedSoftmaxCELoss()
+test_loss_function.hybridize(static_alloc=static_alloc)
+
+rescale_loss = 100.
+parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss)
+detokenizer = nlp.data.SacreMosesDetokenizer()
+
+loss = LengthNormalizedLoss()
+train_metric = mx.metric.Loss(loss)
+
+mt_estimator = MachineTranslationEstimator(net=parallel_model, loss=loss_function)
diff --git a/src/gluonnlp/__init__.py b/src/gluonnlp/__init__.py
index 7a588e8233..f9772b95fc 100644
--- a/src/gluonnlp/__init__.py
+++ b/src/gluonnlp/__init__.py
@@ -30,6 +30,7 @@
 from . import vocab
 from . import optimizer
 from . import initializer
+from . import estimator
 from .vocab import Vocab
 
 __version__ = '0.10.0.dev'
@@ -43,7 +44,8 @@
            'initializer',
            'optimizer',
            'utils',
-           'metric']
+           'metric',
+           'estimator']
 
 warnings.filterwarnings(module='gluonnlp', action='default', category=DeprecationWarning)
 
 utils.version.check_version('1.6.0', warning_only=True, library=mxnet)
diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py
index 3dca823d2c..12c5b14769 100644
--- a/src/gluonnlp/estimator/__init__.py
+++ b/src/gluonnlp/estimator/__init__.py
@@ -20,6 +20,8 @@
 from .machine_translation_estimator import *
 from .machine_translation_event_handler import *
 from .machine_translation_batch_processor import *
+from .length_normalized_loss import *
 
 __all__ = (machine_translation_estimator.__all__ + machine_translation_event_handler.__all__
-           + machine_translation_batch_processor.__all__)
+           + machine_translation_batch_processor.__all__
+           + length_normalized_loss.__all__)
diff --git a/src/gluonnlp/estimator/length_normalized_loss.py b/src/gluonnlp/estimator/length_normalized_loss.py
new file mode 100644
index 0000000000..5e7282c4cd
--- /dev/null
+++ b/src/gluonnlp/estimator/length_normalized_loss.py
@@ -0,0 +1,71 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+""" Length Normalized Loss """
+
+from mxnet import ndarray
+from ..metric.masked_accuracy import EvalMetric
+
+__all__ = ['LengthNormalizedLoss']
+
+class LengthNormalizedLoss(EvalMetric):
+    """Compute length normalized loss metrics
+
+    Parameters
+    ----------
+    axis : int, default=0
+        The axis that represents classes
+    name : str
+        Name of this metric instance for display.
+    output_names : list of str, or None
+        Name of predictions that should be used when updating with update_dict.
+        By default include all predictions.
+    label_names : list of str, or None
+        Name of labels that should be used when updating with update_dict.
+        By default include all labels.
+    """
+    def __init__(self, axis=0, name='length-normalized-loss',
+                 output_names=None, label_names=None):
+        super(LengthNormalizedLoss, self).__init__(
+            name, axis=axis,
+            output_names=output_names, label_names=label_names,
+            has_global_stats=True)
+
+    # Parameter labels should be a list in the form of [target_sequence,
+    # target_sequence_valid_length]
+    def update(self, labels, preds):
+        if not isinstance(labels, list) or len(labels) != 2:
+            raise ValueError('labels must be a list. Its first element should be'
+                             ' the target sequence and the second element should be'
+                             ' the valid length of the sequence.')
+
+        _, seq_valid_length = labels
+
+        if not isinstance(seq_valid_length, list):
+            seq_valid_length = [seq_valid_length]
+
+        if not isinstance(preds, list):
+            preds = [preds]
+
+        for length in seq_valid_length:
+            total_length = ndarray.sum(length).asscalar()
+            self.num_inst += total_length
+            self.global_num_inst += total_length
+
+        for pred in preds:
+            loss = ndarray.sum(pred).asscalar()
+            self.sum_metric += loss
+            self.global_sum_metric += loss
diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py
index 13e0021303..675387984c 100644
--- a/src/gluonnlp/estimator/machine_translation_event_handler.py
+++ b/src/gluonnlp/estimator/machine_translation_event_handler.py
@@ -19,6 +19,7 @@
 
 import copy
 import warnings
+import math
 
 import mxnet as mx
 from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin
@@ -26,9 +27,11 @@
 from mxnet.gluon.contrib.estimator import GradientUpdateHandler
 from mxnet.gluon.contrib.estimator import MetricHandler
 
-__all__ = ['AvgParamUpdateHandler']
+__all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler',
+           'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler',
+           'ComputeBleuHandler']
 
-class AvgParamUpdateHandler(BatchEnd, EpochEnd):
+class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd):
     def __init__(self, avg_start, grad_interval=1):
         self.batch_id = 0
         self.grad_interval = grad_interval
@@ -38,7 +41,7 @@ def __init__(self, avg_start, grad_interval=1):
     def _update_avg_param(self, estimator):
         if estimator.avg_param is None:
             # estimator.net is parallel model estimator.net._model is the model
-            # to be investigated on
+            # embedded in the parallel model
             estimator.avg_param = {k:v.data(estimator.context[0]).copy() for k, v in
                                    estimator.net._model.collect_params().items()}
         if self.step_num > self.avg_start:
@@ -49,11 +52,133 @@ def __init__(self, avg_start, grad_interval=1):
                                           (params[key].data(estimator.context[0]) -
                                           val)
 
+    def epoch_begin(self, estimator, *args, **kwargs):
+        self.batch_id = 0
+
     def batch_end(self, estimator, *args, **kwargs):
         if self.batch_id % self.grad_interval == 0:
             self.step_num += 1
         if self.batch_id % self.grad_interval == self.grad_interval - 1:
             _update_avg_param(estimator)
+        self.batch_id += 1
 
     def epoch_end(self, estimator, *args, **kwargs):
         _update_avg_param(estimator)
+
+
+class TransformerLearningRateHandler(EpochBegin, BatchBegin):
+    def __init__(self, lr,
+                 num_units=512,
+                 warmup_steps=4000,
+                 grad_interval=1):
+        self.lr = lr
+        self.num_units = num_units
+        self.warmup_steps = warmup_steps
+        self.grad_interval = grad_interval
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        self.batch_id = 0
+
+    def batch_begin(self, estimator, *args, **kwargs):
+        if self.batch_id % self.grad_interval == 0:
+            self.step_num += 1
+            new_lr = self.lr / math.sqrt(self.num_units) * \
+                     min(1. / math.sqrt(self.step_num), self.step_num *
+                         self.warmup_steps ** (-1.5))
+            estimator.trainer.set_learning_rate(new_lr)
+        self.batch_id += 1
+
+class TransformerGradientAccumulationHandler(TrainBegin, EpochBegin, BatchEnd):
+    def __init__(self, grad_interval=1,
+                 batch_size=1024,
+                 rescale_loss=100):
+        self.grad_interval = grad_interval
+        self.batch_size = batch_size
+        self.rescale_loss = rescale_loss
+
+    def _update_gradient(self, estimator):
+        estimator.trainer.step(float(self.loss_denom) /
+                               self.batch_size /self.rescale_loss)
+        params = estimator.net._model.collect_params()
+        params.zero_grad()
+        self.loss_denom = 0
+
+    def train_begin(self, estimator, *args, **kwargs):
+        params = estimator.net._model.collect_params()
+        params.setattr('grad_req', 'add')
+        params.zero_grad()
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        self.batch_id = 0
+        self.loss_denom = 0
+
+    def batch_end(self, estimator, *args, **kwargs):
+        self.loss_denom += estimator.tgt_valid_length
+        if self.batch_id % self.grad_interval == self.grad_interval - 1:
+            _update_gradient(estimator)
+        self.batch_id += 1
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        if self.loss_denom > 0:
+            _update_gradient(estimator)
+
+class MTTransformerMetricHandler(MetricHandler, BatchBegin):
+    def __init__(self, grad_interval, *args, **kwargs):
+        super(MTTransformerMetricHandler, self).__init__(*args, **kwargs)
+        self.grad_interval = grad_interval
+
+    def epoch_begin(self, estimator, *args, **kwargs):
+        self.batch_id = 0
+        for metric in self.metrics:
+            metric.reset()
+
+    def batch_begin(self, estimator, *args, **kwargs):
+        if self.batch_id % self.grad_interval == 0:
+            for metric in self.metrics:
+                metric.reset_local()
+        self.batch_id += 1
+
+# A temporary workaround for computing the bleu function. After bleu is in the metric
+# api, this event handler could be removed.
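+# The handler accumulates per-batch beam-search translations and scores them
+# once per epoch; the intended wiring is roughly (constructor arguments as
+# declared below, compute_bleu supplied by the training script):
+#
+#     bleu_handler = ComputeBleuHandler(tgt_vocab, val_tgt_sentences,
+#                                       translator, compute_bleu,
+#                                       tokenized=True, tokenizer='tweaked',
+#                                       split_compound_word=True, bpe=True)
+#     estimator.fit(..., event_handlers=[bleu_handler])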
+class ComputeBleuHandler(EpochEnd):
+    def __init__(self,
+                 tgt_vocab,
+                 tgt_sentence,
+                 translator,
+                 compute_bleu_fn,
+                 tokenized,
+                 tokenizer,
+                 split_compound_word,
+                 bpe):
+        self.tgt_vocab = tgt_vocab
+        self.tgt_sentence = tgt_sentence
+        self.translator = translator
+        self.compute_bleu_fn = compute_bleu_fn
+        self.tokenized = tokenized
+        self.tokenizer = tokenizer
+        self.split_compound_word = split_compound_word
+        self.bpe = bpe
+
+        self.all_inst_ids = []
+        self.translation_out = []
+
+    def batch_end(self, estimator, *args, **kwargs):
+        batch = kwargs['batch']
+        label = kwargs['label']
+        src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = batch
+        self.all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist())
+        samples, _, sample_valid_length = self.translator.translate(
+            src_seq=src_seq, src_valid_length=src_valid_length)
+        max_score_sample = samples[:, 0, :].asnumpy()
+        sample_valid_length = sample_valid_length[:, 0].asnumpy()
+        for i in range(max_score_sample.shape[0]):
+            self.translation_out.append(
+                [self.tgt_vocab.idx_to_token[ele] for ele in
+                 max_score_sample[i][1:(sample_valid_length[i] - 1)]])
+
+    def epoch_end(self, estimator, *args, **kwargs):
+        self.real_translation_out = [None for _ in range(len(all_inst_ids))]
+        for ind, sentence in zip(self.all_inst_ids, self.translation_out):
+            self.real_translation_out[ind] = sentence
+        self.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.tgt_sentence],
+                                                           self.real_translation_out)

From db30a5f1d9456bdf38d7480a2d79e51686e7dcc7 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Thu, 9 Jan 2020 10:52:13 +0000
Subject: [PATCH 03/26] modify machine translation estimator

---
 .../train_transformer_estimator.py               | 62 +++++++++++++++++--
 .../machine_translation_batch_processor.py       |  4 +-
 .../machine_translation_estimator.py             |  1 +
 .../machine_translation_event_handler.py         |  6 +-
 4 files changed, 65 insertions(+), 8 deletions(-)

diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
index 65c099e3d2..ff4bddd6ba 100644
--- a/scripts/machine_translation/train_transformer_estimator.py
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -52,8 +52,10 @@
 from bleu import _bpe_to_words, compute_bleu
 from translation import BeamSearchTranslator
 from utils import logging_config
-from gluonnlp.estimator import MachineTranslationEstimator
-from gluonnlp.estimator import LengthNormalizedLoss
+from gluonnlp.estimator import MachineTranslationEstimator, LengthNormalizedLoss
+from gluonnlp.estimator import MTTransformerBatchProcessor, MTTransformerParamUpdateHandler
+from gluonnlp.estimator import TransformerLearningRateHandler, MTTransformerMetricHandler
+from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler
 
 np.random.seed(100)
 random.seed(100)
@@ -209,7 +211,57 @@
 parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss)
 detokenizer = nlp.data.SacreMosesDetokenizer()
 
-loss = LengthNormalizedLoss()
-train_metric = mx.metric.Loss(loss)
+trainer = gluon.Trainer(model.collect_params(), args.optimizer,
+                        {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9})
 
-mt_estimator = MachineTranslationEstimator(net=parallel_model, loss=loss_function)
+train_data_loader, val_data_loader, test_data_loader \
+    = dataprocessor.make_dataloader(data_train, data_val, data_test, args,
+                                    use_average_length=True, num_shards=len(ctx))
+
+if args.bleu == 'tweaked':
+    bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY')
+    split_compound_word = bpe
+    tokenized = True
+elif args.bleu == '13a' or args.bleu == 'intl':
+    bpe = False
+    split_compound_word = False
+    tokenized = False
+else:
+    raise NotImplementedError
+
+grad_interval = args.num_accumulated
+average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start)
+
+train_metric = LengthNormalizedLoss(loss_function)
+val_metric = LengthNormalizedLoss(test_loss_function)
+
+mt_estimator = MachineTranslationEstimator(net=parallel_model, loss=loss_function,
+                                           train_metrics=train_metric,
+                                           val_metrics=val_metric,
+                                           trainer=trainer,
+                                           context=ctx,
+                                           evaluation_loss=test_loss_function,
+                                           eval_net=model,
+                                           batch_processor=MTTransformerBatchProcessor())
+
+param_update_handler = MTTransformerParamUpdateHandler(avg_start=average_start,
+                                                       grad_interval=grad_interval)
+learning_rate_handler = TransformerLearningRateHandler(lr=args.lr, num_units=args.num_unit,
+                                                       warmup_steps=args.warmup_steps,
+                                                       grad_interval=grad_interval)
+gradient_acc_handler = TransformerGradientAccumulationHandler(grad_interval=grad_interval,
+                                                              batch_size=args.batch_size,
+                                                              rescale_loss=rescale_loss)
+metric_handler = MTTransformerMetricHandler(grad_interval=grad_interval)
+bleu_handler = ComputeBleuHandler(tgt_vocab, tgt_sentence=val_tgt_sentences,
+                                  translator=translator, compute_bleu_fn=compute_bleu,
+                                  tokenized=tokenized, tokenizer=args.bleu,
+                                  split_compound_word=split_compound_word,
+                                  bpe=bpe)
+
+event_handlers = [param_update_handler, learning_rate_handler, gradient_acc_handler,
+                  metric_handler, bleu_handler]
+
+mt_estimator.fit(train_data=train_dta_loader, val_data=val_data_loader,
+                 epochs=args.epochs, event_handlers=event_handlers,
+                 batch_axis=0)
diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py
index ac69c3d190..036fbc55f5 100644
--- a/src/gluonnlp/estimator/machine_translation_batch_processor.py
+++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py
@@ -22,9 +22,9 @@
 from mxnet.gluon.contrib.estimator import BatchProcessor
 from mxnet.gluon.utils import split_and_load
 
-__all__ = ['ParallelMachineTranslationBatchProcessor']
+__all__ = ['MTTransformerBatchProcessor']
 
-class ParallelMachineTranslationBatchProcessor(BatchProcessor):
+class MTTransformerBatchProcessor(BatchProcessor):
     def __init__(self, rescale_loss=100, batch_size=1024):
         self.rescale_loss = rescale_loss
         self.batch_size = batch_size
diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py
index fa6a12dada..e36e8444e2 100644
--- a/src/gluonnlp/estimator/machine_translation_estimator.py
+++ b/src/gluonnlp/estimator/machine_translation_estimator.py
@@ -30,6 +30,7 @@
 class MachineTranslationEstimator(Estimator):
     def __init__(self, net, loss,
                  train_metrics=None,
+                 val_metrics=None,
                  initializer=None,
                  trainer=None,
                  context=None,
diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py
index 675387984c..0b2fd660b7 100644
--- a/src/gluonnlp/estimator/machine_translation_event_handler.py
+++ b/src/gluonnlp/estimator/machine_translation_event_handler.py
@@ -181,4 +181,8 @@ def epoch_end(self, estimator, *args, **kwargs):
         for ind, sentence in zip(self.all_inst_ids, self.translation_out):
             self.real_translation_out[ind] = sentence
         self.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.tgt_sentence],
-                                                           self.real_translation_out)
+                                                           self.real_translation_out,
+                                                           tokenized=self.tokenized,
+                                                           tokenizer=self.tokenizer,
+                                                           split_compound_word+self.split_compound_word,
+                                                           bpe=self.bpe)

From f3f00a6309b0ee11d153dc9738229e837db1a09a Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Fri, 10 Jan 2020 10:38:06 +0000
Subject: [PATCH 04/26] bug fix for transformer translator

---
 .../train_transformer_estimator.py               | 16 ++++---
 .../estimator/length_normalized_loss.py          |  2 +-
 .../machine_translation_batch_processor.py       | 27 +++++++++--
 .../machine_translation_estimator.py             |  4 +-
 .../machine_translation_event_handler.py         | 47 ++++++++++++-------
 5 files changed, 66 insertions(+), 30 deletions(-)

diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
index ff4bddd6ba..cc00533141 100644
--- a/scripts/machine_translation/train_transformer_estimator.py
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -234,25 +234,29 @@
 train_metric = LengthNormalizedLoss(loss_function)
 val_metric = LengthNormalizedLoss(test_loss_function)
+batch_processor = MTTransformerBatchProcessor(rescale_loss=rescale_loss,
+                                              batch_size=args.batch_size,
+                                              label_smoothing=label_smoothing,
+                                              loss_function=loss_function)
 
-mt_estimator = MachineTranslationEstimator(net=parallel_model, loss=loss_function,
+mt_estimator = MachineTranslationEstimator(net=model, loss=loss_function,
                                            train_metrics=train_metric,
                                            val_metrics=val_metric,
                                            trainer=trainer,
                                            context=ctx,
                                            evaluation_loss=test_loss_function,
-                                           eval_net=model,
-                                           batch_processor=MTTransformerBatchProcessor())
+                                           batch_processor=batch_processor)
 
 param_update_handler = MTTransformerParamUpdateHandler(avg_start=average_start,
                                                        grad_interval=grad_interval)
-learning_rate_handler = TransformerLearningRateHandler(lr=args.lr, num_units=args.num_unit,
+learning_rate_handler = TransformerLearningRateHandler(lr=args.lr, num_units=args.num_units,
                                                        warmup_steps=args.warmup_steps,
                                                        grad_interval=grad_interval)
 gradient_acc_handler = TransformerGradientAccumulationHandler(grad_interval=grad_interval,
                                                               batch_size=args.batch_size,
                                                               rescale_loss=rescale_loss)
-metric_handler = MTTransformerMetricHandler(grad_interval=grad_interval)
+metric_handler = MTTransformerMetricHandler(metrics=mt_estimator.train_metrics,
+                                            grad_interval=grad_interval)
 bleu_handler = ComputeBleuHandler(tgt_vocab, tgt_sentence=val_tgt_sentences,
                                   translator=translator, compute_bleu_fn=compute_bleu,
                                   tokenized=tokenized, tokenizer=args.bleu,
@@ -262,6 +266,6 @@
 event_handlers = [param_update_handler, learning_rate_handler, gradient_acc_handler,
                   metric_handler, bleu_handler]
 
-mt_estimator.fit(train_data=train_dta_loader, val_data=val_data_loader,
+mt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader,
                  epochs=args.epochs, event_handlers=event_handlers,
                  batch_axis=0)
diff --git a/src/gluonnlp/estimator/length_normalized_loss.py b/src/gluonnlp/estimator/length_normalized_loss.py
index 5e7282c4cd..1b0a087f39 100644
--- a/src/gluonnlp/estimator/length_normalized_loss.py
+++ b/src/gluonnlp/estimator/length_normalized_loss.py
@@ -17,7 +17,7 @@
 """ Length Normalized Loss """
 
 from mxnet import ndarray
-from ..metric.masked_accuracy import EvalMetric
+from mxnet.metric import EvalMetric
 
 __all__ = ['LengthNormalizedLoss']
 
 class LengthNormalizedLoss(EvalMetric):
diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py
index 036fbc55f5..6be212b8da 100644
--- a/src/gluonnlp/estimator/machine_translation_batch_processor.py
+++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py
@@ -21,15 +21,32 @@
 import mxnet as mx
 from mxnet.gluon.contrib.estimator import BatchProcessor
 from mxnet.gluon.utils import split_and_load
+from ..model.transformer import ParallelTransformer
+from ..utils.parallel import Parallel
 
 __all__ = ['MTTransformerBatchProcessor']
 
 class MTTransformerBatchProcessor(BatchProcessor):
-    def __init__(self, rescale_loss=100, batch_size=1024):
+    def __init__(self, rescale_loss=100,
+                 batch_size=1024,
+                 label_smoothing=None,
+                 loss_function=None):
         self.rescale_loss = rescale_loss
         self.batch_size = batch_size
+        self.label_smoothing = label_smoothing
+        self.loss_function = loss_function
+        self.parallel_model = None
+
+    def _get_parallel_model(self, estimator):
+        if self.label_smoothing is None or self.loss_function is None:
+            raise ValueError('label smoothing or loss function cannot be none.')
+        if self.parallel_model is None:
+            self.parallel_model = ParallelTransformer(estimator.net, self.label_smoothing,
+                                                      self.loss_function, self.rescale_loss)
+            self.parallel_model = Parallel(len(estimator.context), self.parallel_model)
 
     def fit_batch(self, estimator, train_batch, batch_axis=0):
+        self._get_parallel_model(estimator)
         data = [shard[0] for shard in train_batch]
         target = [shard[1] for shard in train_batch]
         src_word_count, tgt_word_count, bs = np.sum([(shard[2].sum(),
@@ -40,9 +57,9 @@
                 for context, shard in zip(estimator.context, train_batch)]
         Ls = []
         for seq in seqs:
-            estimator.net.put((seq, self.batch_size))
-        Ls = [self.estimator.get() for _ in range(len(estimator.context))]
-        return data, target, None, Ls
+            self.parallel_model.put((seq, self.batch_size))
+        Ls = [self.parallel_model.get() for _ in range(len(estimator.context))]
+        return data, [target, tgt_word_count - bs], None, Ls
 
     def evaluate_batch(self, estimator, val_batch, batch_axis=0):
         ctx = estimator.context[0]
@@ -55,6 +72,6 @@
         out, _ = self.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
         loss = self.evaluation_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
         inst_ids = inst_ids.asnumpy().astype(np.int32).tolist()
-        loss *= (tgt_seq.shape[1] - 1)
+        loss = loss * (tgt_seq.shape[1] - 1)
         estimator.val_tgt_valid_length = tgt_seq.shape[1] - 1
         return src_seq, tgt_seq, out, loss
diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py
index e36e8444e2..3e87a357a0 100644
--- a/src/gluonnlp/estimator/machine_translation_estimator.py
+++ b/src/gluonnlp/estimator/machine_translation_estimator.py
@@ -23,7 +23,7 @@
 import numpy as np
 import mxnet as mx
 from mxnet.gluon.contrib.estimator import Estimator
-from .machine_translation_batch_processor import MachineTranslationBatchProcessor
+from .machine_translation_batch_processor import MTTransformerBatchProcessor
 
 __all__ = ['MachineTranslationEstimator']
 
@@ -36,7 +36,7 @@
                  context=None,
                  evaluation_loss=None,
                  eval_net=None,
-                 batch_processor=MachineTranslationBatchProcessor()):
+                 batch_processor=MTTransformerBatchProcessor()):
         super().__init__(net=net, loss=loss,
                          train_metrics=train_metrics,
                          val_metrics=val_metrics,
diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py
index 0b2fd660b7..93c4f5e75d 100644
--- a/src/gluonnlp/estimator/machine_translation_event_handler.py
+++ b/src/gluonnlp/estimator/machine_translation_event_handler.py
@@ -23,9 +23,11 @@
 
 import mxnet as mx
 from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin
-from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnc
+from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnd
 from mxnet.gluon.contrib.estimator import GradientUpdateHandler
 from mxnet.gluon.contrib.estimator import MetricHandler
+from mxnet.metric import Loss as MetricLoss
+from .length_normalized_loss import LengthNormalizedLoss
 
 __all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler',
            'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler',
@@ -40,17 +42,15 @@ def __init__(self, avg_start, grad_interval=1):
 
     def _update_avg_param(self, estimator):
         if estimator.avg_param is None:
-            # estimator.net is parallel model estimator.net._model is the model
-            # embedded in the parallel model
             estimator.avg_param = {k:v.data(estimator.context[0]).copy() for k, v in
-                                   estimator.net._model.collect_params().items()}
+                                   estimator.net.collect_params().items()}
         if self.step_num > self.avg_start:
-            params = estimator.net._model.collect_params()
+            params = estimator.net.collect_params()
             alpha = 1. / max(1, self.step_num - self.avg_start)
             for key, val in estimator.avg_param.items():
-                estimator.avg_param[:] += alpha *
-                                          (params[key].data(estimator.context[0]) -
-                                          val)
+                val[:] += alpha * \
+                          (params[key].data(estimator.context[0]) -
+                           val)
 
     def epoch_begin(self, estimator, *args, **kwargs):
         self.batch_id = 0
@@ -59,11 +59,11 @@
     def batch_end(self, estimator, *args, **kwargs):
         if self.batch_id % self.grad_interval == 0:
             self.step_num += 1
         if self.batch_id % self.grad_interval == self.grad_interval - 1:
-            _update_avg_param(estimator)
+            self._update_avg_param(estimator)
         self.batch_id += 1
 
     def epoch_end(self, estimator, *args, **kwargs):
-        _update_avg_param(estimator)
+        self._update_avg_param(estimator)
 
 
 class TransformerLearningRateHandler(EpochBegin, BatchBegin):
@@ -75,6 +75,7 @@
         self.num_units = num_units
         self.warmup_steps = warmup_steps
         self.grad_interval = grad_interval
+        self.step_num = 0
 
     def epoch_begin(self, estimator, *args, **kwargs):
         self.batch_id = 0
@@ -88,7 +89,9 @@
             estimator.trainer.set_learning_rate(new_lr)
         self.batch_id += 1
 
-class TransformerGradientAccumulationHandler(TrainBegin, EpochBegin, BatchEnd):
+class TransformerGradientAccumulationHandler(GradientUpdateHandler,
+                                             TrainBegin,
+                                             EpochBegin):
     def __init__(self, grad_interval=1,
                  batch_size=1024,
                  rescale_loss=100):
@@ -99,12 +102,12 @@
     def _update_gradient(self, estimator):
         estimator.trainer.step(float(self.loss_denom) /
                                self.batch_size /self.rescale_loss)
-        params = estimator.net._model.collect_params()
+        params = estimator.net.collect_params()
         params.zero_grad()
         self.loss_denom = 0
 
     def train_begin(self, estimator, *args, **kwargs):
-        params = estimator.net._model.collect_params()
+        params = estimator.net.collect_params()
         params.setattr('grad_req', 'add')
         params.zero_grad()
 
@@ -115,12 +118,12 @@
     def batch_end(self, estimator, *args, **kwargs):
         self.loss_denom += estimator.tgt_valid_length
         if self.batch_id % self.grad_interval == self.grad_interval - 1:
-            _update_gradient(estimator)
+            self._update_gradient(estimator)
         self.batch_id += 1
 
     def epoch_end(self, estimator, *args, **kwargs):
         if self.loss_denom > 0:
-            _update_gradient(estimator)
+            self._update_gradient(estimator)
 
 class MTTransformerMetricHandler(MetricHandler, BatchBegin):
     def __init__(self, grad_interval, *args, **kwargs):
         super(MTTransformerMetricHandler, self).__init__(*args, **kwargs)
         self.grad_interval = grad_interval
@@ -138,6 +141,18 @@ def batch_begin(self, estimator, *args, **kwargs):
             metric.reset_local()
         self.batch_id += 1
 
+    def batch_end(self, estimator, *args, **kwargs):
+        pred = kwargs['pred']
+        label = kwargs['label']
+        loss = kwargs['loss']
+        for metric in self.metrics:
+            if isinstance(metric, MetricLoss):
+                metric.update(0, loss)
+            elif isinstance(metric, LengthNormalizedLoss):
+                metric.update(label, loss)
+            else:
+                metric.update(label, pred)
+
 # A temporary workaround for computing the bleu function. After bleu is in the metric
 # api, this event handler could be removed.
 class ComputeBleuHandler(EpochEnd):
@@ -184,5 +199,5 @@ def epoch_end(self, estimator, *args, **kwargs):
                                                            self.real_translation_out,
                                                            tokenized=self.tokenized,
                                                            tokenizer=self.tokenizer,
-                                                           split_compound_word+self.split_compound_word,
+                                                           split_compound_word=self.split_compound_word,
                                                            bpe=self.bpe)

From ab94b381ddf1231b8fc70cf7fec2ced2d9c0034e Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Mon, 13 Jan 2020 10:26:00 +0000
Subject: [PATCH 05/26] fix bugs and add gnmt batch processor

---
 .../train_transformer_estimator.py               |  23 +++-
 .../estimator/length_normalized_loss.py          |  10 +-
 .../machine_translation_batch_processor.py       |  43 +++++++-
 .../machine_translation_estimator.py             |   1 +
 .../machine_translation_event_handler.py         | 102 ++++++++++++++++--
 5 files changed, 159 insertions(+), 20 deletions(-)

diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
index cc00533141..c51a3fa51b 100644
--- a/scripts/machine_translation/train_transformer_estimator.py
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -56,6 +56,7 @@
 from gluonnlp.estimator import MTTransformerBatchProcessor, MTTransformerParamUpdateHandler
 from gluonnlp.estimator import TransformerLearningRateHandler, MTTransformerMetricHandler
 from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler
+from gluonnlp.estimator import ValBleuHandler
 
 np.random.seed(100)
 random.seed(100)
@@ -257,15 +258,27 @@
                                                               rescale_loss=rescale_loss)
 metric_handler = MTTransformerMetricHandler(metrics=mt_estimator.train_metrics,
                                             grad_interval=grad_interval)
-bleu_handler = ComputeBleuHandler(tgt_vocab, tgt_sentence=val_tgt_sentences,
+bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentences,
                                   translator=translator, compute_bleu_fn=compute_bleu,
                                   tokenized=tokenized, tokenizer=args.bleu,
                                   split_compound_word=split_compound_word,
-                                  bpe=bpe)
+                                  bpe=bpe, bleu=args.bleu, detokenizer=detokenizer,
+                                  _bpe_to_words=_bpe_to_words)
+
+val_bleu_handler = ValBleuHandler(val_data=val_data_loader, val_tgt_vocab=tgt_vocab,
+                                  val_tgt_sentences=val_tgt_sentences, translator=translator,
+                                  tokenized=tokenized, tokenizer=args.bleu,
+                                  split_compound_word=split_compound_word, bpe=bpe,
+                                  compute_bleu_fn=compute_bleu,
+                                  bleu=args.bleu, detokenizer=detokenizer,
+                                  _bpe_to_words=_bpe_to_words)
 
 event_handlers = [param_update_handler, learning_rate_handler, gradient_acc_handler,
-                  metric_handler, bleu_handler]
+                  metric_handler, val_bleu_handler]
 
-mt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader,
-                 epochs=args.epochs, event_handlers=event_handlers,
-                 batch_axis=0)
+mt_estimator.fit(train_data=train_data_loader,
+                 val_data=val_data_loader,
+                 #epochs=args.epochs,
+                 batches=2,
+                 event_handlers=event_handlers,
+                 batch_axis=0)
diff --git a/src/gluonnlp/estimator/length_normalized_loss.py b/src/gluonnlp/estimator/length_normalized_loss.py
index 1b0a087f39..e4558c6fb1 100644
--- a/src/gluonnlp/estimator/length_normalized_loss.py
+++ b/src/gluonnlp/estimator/length_normalized_loss.py
@@ -61,11 +61,17 @@
             preds = [preds]
 
         for length in seq_valid_length:
-            total_length = ndarray.sum(length).asscalar()
+            if isinstance(length, ndarray.ndarray.NDArray):
+                total_length = ndarray.sum(length).asscalar()
+            else:
+                total_length = length
             self.num_inst += total_length
             self.global_num_inst += total_length
 
         for pred in preds:
-            loss = ndarray.sum(pred).asscalar()
+            if isinstance(pred, ndarray.ndarray.NDArray):
+                loss = ndarray.sum(pred).asscalar()
+            else:
+                loss = pred
             self.sum_metric += loss
             self.global_sum_metric += loss
diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py
index 6be212b8da..2d667565a1 100644
--- a/src/gluonnlp/estimator/machine_translation_batch_processor.py
+++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py
@@ -24,7 +24,7 @@
 from ..model.transformer import ParallelTransformer
 from ..utils.parallel import Parallel
 
-__all__ = ['MTTransformerBatchProcessor']
+__all__ = ['MTTransformerBatchProcessor', 'MTGNMTBatchProcessor']
 
 class MTTransformerBatchProcessor(BatchProcessor):
     def __init__(self, rescale_loss=100,
@@ -69,9 +69,42 @@
         src_valid_length = src_valid_length.as_in_context(ctx)
         tgt_valid_length = tgt_valid_length.as_in_context(ctx)
 
-        out, _ = self.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
-        loss = self.evaluation_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar()
+        out, _ = estimator.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1)
+        loss = estimator.evaluation_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).sum().asscalar()
         inst_ids = inst_ids.asnumpy().astype(np.int32).tolist()
         loss = loss * (tgt_seq.shape[1] - 1)
-        estimator.val_tgt_valid_length = tgt_seq.shape[1] - 1
-        return src_seq, tgt_seq, out, loss
+        val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar()
+        return src_seq, [tgt_seq, val_tgt_valid_length], out, loss
+
+class MTGNMTBatchProcessor(BatchProcessor):
+    def __init__(self):
+        pass
+
+    def fit_batch(self, estimator, train_batch, batch_axis=0):
+        src_seq, tgt_seq, src_valid_length, tgt_valid_length = train_batch
+        src_seq = src_seq.as_in_context(estimator.context[0])
+        tgt_seq = tgt_seq.as_in_context(estimator.context[0])
+        src_valid_length = src_valid_length.as_in_context(estimator.context[0])
+        tgt_valid_length = tgt_valid_length.as_in_context(estimator.context[0])
+        with mx.autograd.record():
+            out, _ = estimator.net(src_seq, tgt_seq[:, :-1], src_valid_length,
+                                   tgt_valid_length - 1)
+            loss = estimator.loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean()
+            loss = loss * (tgt_seq.shape[1] - 1)
+            loss = loss / (tgt_valid_length - 1).mean()
+        loss.backward()
+        return src_seq, [tgt_seq, (tgt_valid_length - 1).sum()], out, loss * tgt_seq.shape[0]
+
+    def evaluate_batch(self, estimator, val_batch, batch_axis=0):
+        src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = val_batch
+        src_seq = src_seq.as_in_context(estimator.context[0])
+        tgt_seq = tgt_seq.as_in_context(estimator.context[0])
+        src_valid_length = src_valid_length.as_in_context(estimator.context[0])
+        tgt_valid_length = tgt_valid_length.as_in_context(estimator.context[0])
+        out, _ = estimator.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length,
+                                    tgt_valid_length - 1)
+        loss = estimator.evaluation_loss(out, tgt_seq[:, 1:],
+                                         tgt_valid_length - 1).sum().asscalar()
+        loss = loss * (tgt_seq.shape[1] - 1)
+        val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar()
+        return src_seq, [tgt_seq, val_tgt_valid_length], out, loss
diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py
index 3e87a357a0..8c6c63698f 100644
--- a/src/gluonnlp/estimator/machine_translation_estimator.py
+++ b/src/gluonnlp/estimator/machine_translation_estimator.py
@@ -49,3 +49,4 @@
         self.tgt_valid_length = 0
         self.val_tgt_valid_length = 0
         self.avg_param = None
+        self.bleu_score = 0.0
diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py
index 93c4f5e75d..3037296bd4 100644
--- a/src/gluonnlp/estimator/machine_translation_event_handler.py
+++ b/src/gluonnlp/estimator/machine_translation_event_handler.py
@@ -21,17 +21,19 @@
 import warnings
 import math
 
+import numpy as np
 import mxnet as mx
 from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin
 from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnd
 from mxnet.gluon.contrib.estimator import GradientUpdateHandler
 from mxnet.gluon.contrib.estimator import MetricHandler
+from mxnet import gluon
 from mxnet.metric import Loss as MetricLoss
 from .length_normalized_loss import LengthNormalizedLoss
 
 __all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler',
-           'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler',
-           'ComputeBleuHandler']
+           'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler',
+           'ComputeBleuHandler', 'ValBleuHandler', 'MTGNMTGradientUpdateHandler']
 
 class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd):
     def __init__(self, avg_start, grad_interval=1):
@@ -89,12 +91,24 @@
             estimator.trainer.set_learning_rate(new_lr)
         self.batch_id += 1
 
+class MTGNMTGradientUpdateHandler(GradientUpdateHandler):
+    def __init__(self, clip):
+        super(MTGNMTGradientUpdateHandler, self).__init__()
+        self.clip = clip
+
+    def batch_end(self, estimator, *args, **kwargs):
+        ctx = estimator.context[0]
+        grads = [p.grad(ctx) for p in estimator.net.collect_params().values()]
+        gnorm = gluon.utils.clip_global_norm(grads, self.clip)
+        estimator.trainer.step(1)
+
 class TransformerGradientAccumulationHandler(GradientUpdateHandler,
                                              TrainBegin,
-                                             EpochBegin):
+                                             EpochBegin,
+                                             EpochEnd):
     def __init__(self, grad_interval=1,
                  batch_size=1024,
                  rescale_loss=100):
+        super(TransformerGradientAccumulationHandler, self).__init__()
         self.grad_interval = grad_interval
         self.batch_size = batch_size
         self.rescale_loss = rescale_loss
@@ -155,7 +169,7 @@ def batch_end(self, estimator, *args, **kwargs):
 
 # A temporary workaround for computing the bleu function. After bleu is in the metric
 # api, this event handler could be removed.
-class ComputeBleuHandler(EpochEnd): +class ComputeBleuHandler(BatchEnd, EpochEnd): def __init__(self, tgt_vocab, tgt_sentence, @@ -164,7 +178,10 @@ def __init__(self, tokenized, tokenizer, split_compound_word, - bpe): + bpe, + bleu, + detokenizer, + _bpe_to_words): self.tgt_vocab = tgt_vocab self.tgt_sentence = tgt_sentence self.translator = translator @@ -173,6 +190,9 @@ def __init__(self, self.tokenizer = tokenizer self.split_compound_word = split_compound_word self.bpe = bpe + self.bleu = bleu + self.detokenizer = detokenizer + self._bpe_to_words = _bpe_to_words self.all_inst_ids = [] self.translation_out = [] @@ -192,12 +212,78 @@ def batch_end(self, estimator, *args, **kwargs): max_score_sample[i][1:(sample_valid_length[i] - 1)]]) def epoch_end(self, estimator, *args, **kwargs): - self.real_translation_out = [None for _ in range(len(all_inst_ids))] + real_translation_out = [None for _ in range(len(all_inst_ids))] for ind, sentence in zip(self.all_inst_ids, self.translation_out): - self.real_translation_out[ind] = sentence - self.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.tgt_sentence], - self.real_translation_out, + if self.bleu == 'tweaked': + real_translation_out[ind] = sentence + elif self.bleu == '13a' or self.bleu == 'intl': + real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence)) + else: + raise NotImplementedError + estimator.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.tgt_sentence], + real_translation_out, tokenized=self.tokenized, tokenizer=self.tokenizer, split_compound_word=self.split_compound_word, bpe=self.bpe) + + +# temporary validation bleu metric hack, it can be removed once bleu metric api is available +class ValBleuHandler(EpochEnd): + def __init__(self, val_data, + val_tgt_vocab, + val_tgt_sentences, + translator, + tokenized, + tokenizer, + split_compound_word, + bpe, + compute_bleu_fn, + bleu, + detokenizer, + _bpe_to_words): + self.val_data = val_data + self.val_tgt_vocab = val_tgt_vocab + self.val_tgt_sentences = val_tgt_sentences + self.translator = translator + self.tokenized = tokenized + self.tokenizer = tokenizer + self.split_compound_word = split_compound_word + self.bpe = bpe + self.compute_bleu_fn = compute_bleu_fn + self.bleu = bleu + self.detokenizer = detokenizer + self._bpe_to_words = _bpe_to_words + + def epoch_end(self, estimator, *args, **kwargs): + translation_out = [] + all_inst_ids = [] + for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \ + in enumerate(self.val_data): + src_seq = src_seq.as_in_context(estimator.context[0]) + tgt_seq = tgt_seq.as_in_context(estimator.context[0]) + src_valid_length = src_valid_length.as_in_context(estimator.context[0]) + tgt_valid_length = tgt_valid_length.as_in_context(estimator.context[0]) + all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist()) + samples, _, sample_valid_length = self.translator.translate( + src_seq=src_seq, src_valid_length=src_valid_length) + max_score_sample = samples[:, 0, :].asnumpy() + sample_valid_length = sample_valid_length[:, 0].asnumpy() + for i in range(max_score_sample.shape[0]): + translation_out.append( + [self.val_tgt_vocab.idx_to_token[ele] for ele in + max_score_sample[i][1:(sample_valid_length[i] - 1)]]) + real_translation_out = [None for _ in range(len(all_inst_ids))] + for ind, sentence in zip(all_inst_ids, translation_out): + if self.bleu == 'tweaked': + real_translation_out[ind] = sentence + elif self.bleu == '13a' or self.beu == 'intl': + real_translation_out[ind] = 
self.detokenizer(self._bpe_to_words(sentence)) + else: + raise NotImplementedError + estimator.bleu, _, _, _, _ = self.compute_bleu_fn([self.val_tgt_sentences], + real_translation_out, + tokenized=self.tokenized, + tokenizer=self.tokenizer, + split_compound_word=self.split_compound_word, + bpe=self.bpe) From 83ef62ddb1f36adc32e712048b1d3cdb854df0d4 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 14 Jan 2020 10:48:51 +0000 Subject: [PATCH 06/26] add gnmt event handler and script --- .../train_gnmt_estimator.py | 189 ++++++++++++++++++ .../train_transformer_estimator.py | 11 +- .../machine_translation_batch_processor.py | 1 + .../machine_translation_event_handler.py | 94 +++++++-- 4 files changed, 275 insertions(+), 20 deletions(-) create mode 100644 scripts/machine_translation/train_gnmt_estimator.py diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py new file mode 100644 index 0000000000..acff3c918b --- /dev/null +++ b/scripts/machine_translation/train_gnmt_estimator.py @@ -0,0 +1,189 @@ +""" +Google Neural Machine Translation +================================= + +This example shows how to implement the GNMT model with Gluon NLP Toolkit. + +@article{wu2016google, + title={Google's neural machine translation system: + Bridging the gap between human and machine translation}, + author={Wu, Yonghui and Schuster, Mike and Chen, Zhifeng and Le, Quoc V and + Norouzi, Mohammad and Macherey, Wolfgang and Krikun, Maxim and Cao, Yuan and Gao, Qin and + Macherey, Klaus and others}, + journal={arXiv preprint arXiv:1609.08144}, + year={2016} +} +""" + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name,logging-format-interpolation + +import argparse +import time +import random +import os +import logging +import numpy as np +import mxnet as mx +from mxnet import gluon +import gluonnlp as nlp + +from gluonnlp.model.translation import NMTModel +from gluonnlp.loss import MaskedSoftmaxCELoss +from gnmt import get_gnmt_encoder_decoder +from translation import BeamSearchTranslator +from utils import logging_config +from bleu import compute_bleu +import dataprocessor +from gluonnlp.estimator import MachineTranslationEstimator, LengthNormalizedLoss +from gluonnlp.estimator import MTGNMTBatchProcessor, MTGNMTGradientUpdateHandler +from gluonnlp.estimator import ComputeBleuHandler, ValBleuHandler +from gluonnlp.estimator import MTTransformerMetricHandler, MTGNMTLearningRateHandler +from gluonnlp.estimator import MTCheckpointHandler + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + +nlp.utils.check_version('0.9.0') + +parser = argparse.ArgumentParser(description='Neural Machine Translation Example.' 
+ 'We train the Google NMT model') +parser.add_argument('--dataset', type=str, default='IWSLT2015', help='Dataset to use.') +parser.add_argument('--src_lang', type=str, default='en', help='Source language') +parser.add_argument('--tgt_lang', type=str, default='vi', help='Target language') +parser.add_argument('--epochs', type=int, default=40, help='upper epoch limit') +parser.add_argument('--num_hidden', type=int, default=128, help='Dimension of the embedding ' + 'vectors and states.') +parser.add_argument('--dropout', type=float, default=0.2, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--num_layers', type=int, default=2, help='number of layers in the encoder' + ' and decoder') +parser.add_argument('--num_bi_layers', type=int, default=1, + help='number of bidirectional layers in the encoder and decoder') +parser.add_argument('--batch_size', type=int, default=128, help='Batch size') +parser.add_argument('--beam_size', type=int, default=4, help='Beam size') +parser.add_argument('--lp_alpha', type=float, default=1.0, + help='Alpha used in calculating the length penalty') +parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty') +parser.add_argument('--test_batch_size', type=int, default=32, help='Test batch size') +parser.add_argument('--num_buckets', type=int, default=5, help='Bucket number') +parser.add_argument('--bucket_scheme', type=str, default='constant', + help='Strategy for generating bucket keys. It supports: ' + '"constant": all the buckets have the same width; ' + '"linear": the width of bucket increases linearly; ' + '"exp": the width of bucket increases exponentially') +parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the ' + 'throughput of the bucketing') +parser.add_argument('--src_max_len', type=int, default=50, help='Maximum length of the source ' + 'sentence') +parser.add_argument('--tgt_max_len', type=int, default=50, help='Maximum length of the target ' + 'sentence') +parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') +parser.add_argument('--lr', type=float, default=1E-3, help='Initial learning rate') +parser.add_argument('--lr_update_factor', type=float, default=0.5, + help='Learning rate decay factor') +parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping') +parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') +parser.add_argument('--save_dir', type=str, default='out_dir', + help='directory path to save the final model and training log') +parser.add_argument('--gpu', type=int, default=None, + help='id of the gpu to use. 
Set it to empty means to use cpu.') +args = parser.parse_args() +print(args) +logging_config(args.save_dir) + + +data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab\ + = dataprocessor.load_translation_data(dataset=args.dataset, bleu='tweaked', args=args) + +dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt')) +dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt')) + +data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False) +data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_val)]) +data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_test)]) +if args.gpu is None: + ctx = mx.cpu() + print('Use CPU') +else: + ctx = mx.gpu(args.gpu) + +encoder, decoder, one_step_ahead_decoder = get_gnmt_encoder_decoder( + hidden_size=args.num_hidden, dropout=args.dropout, num_layers=args.num_layers, + num_bi_layers=args.num_bi_layers) +model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + one_step_ahead_decoder=one_step_ahead_decoder, embed_size=args.num_hidden, + prefix='gnmt_') +model.initialize(init=mx.init.Uniform(0.1), ctx=ctx) +static_alloc = True +model.hybridize(static_alloc=static_alloc) +logging.info(model) + +translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, + scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, + K=args.lp_k), + max_length=args.tgt_max_len + 100) +logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) + + +loss_function = MaskedSoftmaxCELoss() +loss_function.hybridize(static_alloc=static_alloc) +trainer = gluon.Trainer(model.collect_params(), args.optimizer, {'learning_rate': args.lr}) + +train_data_loader, val_data_loader, test_data_loader \ + = dataprocessor.make_dataloader(data_train, data_val, data_test, args) + +train_metric = LengthNormalizedLoss(loss_function) +val_metric = LengthNormalziedLoss(loss_function) +batchprocessor = MTGNMTBatchProcessor() +gnmt_estimator = MachineTranslationEstimator(net=model, loss=loss_function, + train_metrics=train_metric, + val_metrics=val_metric, + trainer=trainer, + context=ctx, + batch_processor=batch_processor) + +learning_rate_handler = MTGNMTLearningRateHandler(epochs=args.epochs, + lr_update_factor=args.lr_update_factor) + +gradient_update_handler = MTGNMTGradientUpdateHandler(clip=args.clip) + +metric_handler = MTTransformerMetrichandler(metrics=gnmt_estimator.train_metrics, + grad_interval=1) + +bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu) + +val_bleu_handler = ValBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentence, + translator=translator, compute_bleu_fn=compute_bleu) + +checkpoint_handler = MTCheckpointHandler(model_dir=args.save_dir) + +event_handlers = [learning_rate_handler, gradient_update_handler, metric_handler, + val_bleu_handler, checkpoint_handler] + +gnmt_estimator.fit(train_data=train_data_loader, + val_data=val_data_loader, + epochs=args.epochs, + #batches=5, + event_handlers=event_handlers, + batch_axis=0) diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py index c51a3fa51b..c50d667504 100644 --- 
a/scripts/machine_translation/train_transformer_estimator.py +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -56,7 +56,7 @@ from gluonnlp.estimator import MTTransformerBatchProcessor, MTTransformerParamUpdateHandler from gluonnlp.estimator import TransformerLearningRateHandler, MTTransformerMetricHandler from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler -from gluonnlp.estimator import ValBleuHandler +from gluonnlp.estimator import ValBleuHandler, MTCheckpointHandler np.random.seed(100) random.seed(100) @@ -273,12 +273,17 @@ bleu=args.bleu, detokenizer=detokenizer, _bpe_to_words=_bpe_to_words) +checkpoint_handler = MTCheckpointHandler(model_dir=args.save_dir, + average_checkpoint=args.average_checkpoint, + num_averages=args.num_averages, + average_start=args.average_start) + event_handlers = [param_update_handler, learning_rate_handler, gradient_acc_handler, - metric_handler, val_bleu_handler] + metric_handler, val_bleu_handler, checkpoint_handler] mt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader, #epochs=args.epochs, - batches=2, + batches=5, event_handlers=event_handlers, batch_axis=0) diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py index 2d667565a1..c3265dcfe6 100644 --- a/src/gluonnlp/estimator/machine_translation_batch_processor.py +++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py @@ -59,6 +59,7 @@ def fit_batch(self, estimator, train_batch, batch_axis=0): for seq in seqs: self.parallel_model.put((seq, self.batch_size)) Ls = [self.parallel_model.get() for _ in range(len(estimator.context))] + Ls = Ls * self.batch_size * self.rescale_loss return data, [target, tgt_word_count - bs], None, Ls def evaluate_batch(self, estimator, val_batch, batch_axis=0): diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index 3037296bd4..e9b5a738cd 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -20,12 +20,13 @@ import copy import warnings import math +import os import numpy as np import mxnet as mx from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnd -from mxnet.gluon.contrib.estimator import GradientUpdateHandler +from mxnet.gluon.contrib.estimator import GradientUpdateHandler, CheckpointHandler from mxnet.gluon.contrib.estimator import MetricHandler from mxnet import gluon from mxnet.metric import Loss as MetricLoss @@ -33,7 +34,8 @@ __all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler', 'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler', - 'ComputeBleuHandler', 'ValBleuHandler', 'MTGNMTGradientUpdateHandler'] + 'ComputeBleuHandler', 'ValBleuHandler', 'MTGNMTGradientUpdateHandler', + 'MTGNMTLearningRateHandler', 'MTCheckpointHandler'] class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd): def __init__(self, avg_start, grad_interval=1): @@ -67,6 +69,17 @@ def batch_end(self, estimator, *args, **kwargs): def epoch_end(self, estimator, *args, **kwargs): self._update_avg_param(estimator) +class MTGNMTLearningRateHandler(EpochEnd): + def __init__(self, epochs, lr_update_factor): + self.epoch_id = 0 + self.epochs = epochs + self.lr_update_factor = lr_update_factor + + def epoch_end(self, 
estimator, *args, **kwargs): + if self.epoch_id + 1 >= (self.epochs * 2) // 3: + new_lr = estimator.trainer.learning_rate * self.lr_update_factor + estimator.trainer.set_learning_rate(new_lr) + self.epoch_id += 1 class TransformerLearningRateHandler(EpochBegin, BatchBegin): def __init__(self, lr, @@ -167,6 +180,53 @@ def batch_end(self, estimator, *args, **kwargs): else: metric.update(label, pred) +class MTCheckpointHandler(CheckpointHandler, TrainEnd): + def __init__(self, *args, + average_checkpoint=None, + num_averages=None, + average_start=0, + epochs=0, + **kwargs): + super(MTCheckpointHandler, self).__init__(*args, **kwargs) + self.bleu_score = 0. + self.average_checkpoint = average_checkpoint + self.num_averages = num_averages + self.average_start = average_start + self.epochs = epochs + + def epoch_end(self, estimator, *args, **kwargs): + if estimator.bleu_score > self.bleu_score: + self.bleu_score = estimator.bleu_score + save_path = os.path.join(self.model_dir, 'valid_best.params') + estimator.net.save_parameters(save_path) + save_path = os.path.join(self.model_dir, 'epoch{:d}.params'.format(self.current_epoch)) + estimator.net.save(save_path) + self.current_epoch += 1 + + def train_end(self, estimator, *args, **kwargs): + ctx = estimator.context + save_path = os.path.join(self.model_dir, 'average.params') + mx.nd.save(save_path, estimator.avg_param) + if self.average_checkpoint: + for j in range(self.num_averages): + params = mx.nd.load(os.path.join(self.model_dir, + 'epoch{:d}.params'.format(self.epochs - j - 1))) + alpha = 1. / (j + 1) + for k, v in estimator.net._collect_params_with_prefix().items(): + for c in ctx: + v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) + save_path = os.path.join(self.model_dir, + 'average_checkpoint_{}.params'.format(self.num_averages)) + estimator.net.save_parameters(save_path) + elif self.average_start: + for k, v in estimator.net.collect_params().items(): + v.set_data(average_param_dict[k]) + save_path = os.path.join(self.model_dir, 'average.params') + estimator.net.save_parameters(save_path) + else: + estimator.net.load_parameters(os.path.join(self.model_dir, + 'valid_best.params'), ctx) + # A temporary workaround for computing the bleu function. After bleu is in the metric # api, this event handler could be removed.
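The averaging loop in train_end is a streaming mean: with alpha = 1. / (j + 1), folding in checkpoint j keeps the parameters equal to the average of every checkpoint processed so far, which is why the update has to be the in-place `+=` shown above. A self-contained sketch of that incremental rule on plain NDArray dicts (toy values; in the handler the dicts come from mx.nd.load on the saved 'epoch{:d}.params' files):

    import mxnet as mx

    # Hypothetical stand-ins for the last three saved epoch checkpoints.
    checkpoints = [{'w': mx.nd.full((2, 2), float(i))} for i in range(3)]

    # avg <- avg + (x_j - avg) / (j + 1) keeps avg equal to the mean of
    # all checkpoints folded in so far.
    avg = {k: v.copy() for k, v in checkpoints[0].items()}
    for j, params in enumerate(checkpoints[1:], start=1):
        alpha = 1. / (j + 1)
        for k in avg:
            avg[k][:] += alpha * (params[k] - avg[k])

    assert avg['w'][0, 0].asscalar() == 1.0  # mean of 0.0, 1.0 and 2.0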
class ComputeBleuHandler(BatchEnd, EpochEnd): @@ -175,13 +235,13 @@ def __init__(self, tgt_sentence, translator, compute_bleu_fn, - tokenized, - tokenizer, - split_compound_word, - bpe, - bleu, - detokenizer, - _bpe_to_words): + tokenized=True, + tokenizer='13a', + split_compound_word=False, + bpe=False, + bleu='tweaked', + detokenizer=None, + _bpe_to_words=None): self.tgt_vocab = tgt_vocab self.tgt_sentence = tgt_sentence self.translator = translator @@ -234,14 +294,14 @@ def __init__(self, val_data, val_tgt_vocab, val_tgt_sentences, translator, - tokenized, - tokenizer, - split_compound_word, - bpe, compute_bleu_fn, - bleu, - detokenizer, - _bpe_to_words): + tokenized=True, + tokenizer='13a', + split_compound_word=False, + bpe=False, + bleu='tweaked', + detokenizer=None, + _bpe_to_words=None): self.val_data = val_data self.val_tgt_vocab = val_tgt_vocab self.val_tgt_sentences = val_tgt_sentences @@ -281,7 +341,7 @@ def epoch_end(self, estimator, *args, **kwargs): real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence)) else: raise NotImplementedError - estimator.bleu, _, _, _, _ = self.compute_bleu_fn([self.val_tgt_sentences], + estimator.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.val_tgt_sentences], real_translation_out, tokenized=self.tokenized, tokenizer=self.tokenizer, From f40e4dc8df425925ac66a5b2359b5f5d21e0b23d Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Tue, 14 Jan 2020 10:53:52 +0000 Subject: [PATCH 07/26] bug fix --- src/gluonnlp/estimator/machine_translation_batch_processor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py index c3265dcfe6..26bc36a6be 100644 --- a/src/gluonnlp/estimator/machine_translation_batch_processor.py +++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py @@ -59,7 +59,7 @@ def fit_batch(self, estimator, train_batch, batch_axis=0): for seq in seqs: self.parallel_model.put((seq, self.batch_size)) Ls = [self.parallel_model.get() for _ in range(len(estimator.context))] - Ls = Ls * self.batch_size * self.rescale_loss + Ls = [l * self.batch_size * self.rescale_loss for l in Ls] return data, [target, tgt_word_count - bs], None, Ls def evaluate_batch(self, estimator, val_batch, batch_axis=0): From 6dabf23e1cabb846ef98609f2320ad483714b36c Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 16 Jan 2020 08:05:07 +0000 Subject: [PATCH 08/26] fix various errors --- .../train_transformer_estimator.py | 32 +++++++- .../machine_translation_event_handler.py | 76 +++++++++++++++---- 2 files changed, 89 insertions(+), 19 deletions(-) diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py index c50d667504..395f22eb9d 100644 --- a/scripts/machine_translation/train_transformer_estimator.py +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -57,6 +57,7 @@ from gluonnlp.estimator import TransformerLearningRateHandler, MTTransformerMetricHandler from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler from gluonnlp.estimator import ValBleuHandler, MTCheckpointHandler +from gluonnlp.estimator import MTTransformerLoggingHandler, MTValidationHandler np.random.seed(100) random.seed(100) @@ -278,12 +279,35 @@ num_averages=args.num_averages, average_start=args.average_start) -event_handlers = [param_update_handler, learning_rate_handler, gradient_acc_handler, - 
metric_handler, val_bleu_handler, checkpoint_handler] +val_metric_handler = MTTransformerMetricHandler(metrics=mt_estimator.val_metrics) + +val_validation_handler = MTValidationHandler(val_data=val_data_loader, + eval_fn=mt_estimator.evaluate, + event_handlers=val_metric_handler) + +log_interval = args.log_interval * grad_interval +logging_handler = MTTransformerLoggingHandler(log_interval=log_interval, + metrics=mt_estimator.train_metrics) + +event_handlers = [param_update_handler, + learning_rate_handler, + gradient_acc_handler, + metric_handler, + val_validation_handler, + val_bleu_handler, + checkpoint_handler, + logging_handler] mt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader, - #epochs=args.epochs, - batches=5, + epochs=args.epochs, + #batches=200, event_handlers=event_handlers, batch_axis=0) + +val_event_handlers = [val_metric_handler, + bleu_handler] + +mt_estimator.evaluate(val_data=val_data_loader, event_handlers=val_event_handlers) + +mt_estimator.evaluate(val_data=test_data_loader, event_handlers=val_event_handlers) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index e9b5a738cd..a217a009c7 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -21,13 +21,15 @@ import warnings import math import os +import time import numpy as np import mxnet as mx from mxnet.gluon.contrib.estimator import TrainBegin, TrainEnd, EpochBegin from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnd from mxnet.gluon.contrib.estimator import GradientUpdateHandler, CheckpointHandler -from mxnet.gluon.contrib.estimator import MetricHandler +from mxnet.gluon.contrib.estimator import MetricHandler, LoggingHandler +from mxnet.gluon.contrib.estimator import ValidationHandler from mxnet import gluon from mxnet.metric import Loss as MetricLoss from .length_normalized_loss import LengthNormalizedLoss @@ -35,7 +37,8 @@ __all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler', 'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler', 'ComputeBleuHandler', 'ValBleuHandler', 'MTGNMTGradientUpdateHandler', - 'MTGNMTLearningRateHandler', 'MTCheckpointHandler'] + 'MTGNMTLearningRateHandler', 'MTCheckpointHandler', + 'MTTransformerLoggingHandler', 'MTValidationHandler'] class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd): def __init__(self, avg_start, grad_interval=1): @@ -153,7 +156,7 @@ def epoch_end(self, estimator, *args, **kwargs): self._update_gradient(estimator) class MTTransformerMetricHandler(MetricHandler, BatchBegin): - def __init__(self, grad_interval, *args, **kwargs): + def __init__(self, *args, grad_interval=None, **kwargs): super(MTTransformerMetricHandler, self).__init__(*args, **kwargs) self.grad_interval = grad_interval @@ -163,7 +166,7 @@ def epoch_begin(self, estimator, *args, **kwargs): metric.reset() def batch_begin(self, estimator, *args, **kwargs): - if self.batch_id % self.grad_interval == 0: + if self.grad_interval is not None and self.batch_id % self.grad_interval == 0: for metric in self.metrics: metric.reset_local() self.batch_id += 1 @@ -200,7 +203,7 @@ def epoch_end(self, estimator, *args, **kwargs): save_path = os.path.join(self.model_dir, 'valid_best.params') estimator.net.save_parameters(save_path) save_path = os.path.join(self.model_dir, 'epoch{:d}.params'.format(self.current_epoch)) - estimator.net.save(save_path) + 
estimator.net.save_parameters(save_path) self.current_epoch += 1 def train_end(self, estimator, *args, **kwargs): @@ -215,17 +218,17 @@ def train_end(self, estimator, *args, **kwargs): for k, v in estimator.net._collect_params_with_prefix().items(): for c in ctx: v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) - save_path = os.path.join(self.model_dir, - 'average_checkpoint_{}.params'.format(self.num_averages)) - estimator.net.save_parameters(save_path) + save_path = os.path.join(self.model_dir, + 'average_checkpoint_{}.params'.format(self.num_averages)) + estimator.net.save_parameters(save_path) elif self.average_start: for k, v in estimator.net.collect_params().items(): - v.set_data(average_param_dict[k]) - save_path = os.path.join(self.model_dir, 'average.params') - estimator.net.save_parameters(save_path) - else: - estimator.net.load_parameters(os.path.join(self.model_dir, - 'valid_best.params'), ctx) + v.set_data(estimator.avg_param[k]) + save_path = os.path.join(self.model_dir, 'average.params') + estimator.net.save_parameters(save_path) + else: + estimator.net.load_parameters(os.path.join(self.model_dir, + 'valid_best.params'), ctx) # A temporary workaround for computing the bleu function. After bleu is in the metric # api, this event handler could be removed. @@ -258,9 +261,14 @@ def __init__(self, self.translation_out = [] def batch_end(self, estimator, *args, **kwargs): + ctx = estimator.context[0] batch = kwargs['batch'] label = kwargs['label'] src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = batch + src_seq = src_seq.as_in_context(ctx) + tgt_seq = tgt_seq.as_in_context(ctx) + src_valid_length = src_valid_length.as_in_context(ctx) + tgt_valid_length = tgt_valid_length.as_in_context(ctx) self.all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist()) samples, _, sample_valid_length = self.translator.translate( src_seq=src_seq, src_valid_length=src_valid_length) @@ -272,7 +280,7 @@ def batch_end(self, estimator, *args, **kwargs): max_score_sample[i][1:(sample_valid_length[i] - 1)]]) def epoch_end(self, estimator, *args, **kwargs): - real_translation_out = [None for _ in range(len(all_inst_ids))] + real_translation_out = [None for _ in range(len(self.all_inst_ids))] for ind, sentence in zip(self.all_inst_ids, self.translation_out): if self.bleu == 'tweaked': real_translation_out[ind] = sentence @@ -347,3 +355,41 @@ def epoch_end(self, estimator, *args, **kwargs): tokenizer=self.tokenizer, split_compound_word=self.split_compound_word, bpe=self.bpe) + +class MTTransformerLoggingHandler(LoggingHandler): + def __init__(self, *args, **kwargs): + super(MTTransformerLoggingHandler, self).__init__(*args, **kwargs) + + def batch_end(self, estimator, *args, **kwargs): + if isinstance(self.log_interval, int): + batch_time = time.time() - self.batch_start + msg = '[Epoch %d][Batch %d]' % (self.current_epoch, self.batch_index) + cur_batches = kwargs['batch'] + for batch in cur_batches: + self.processed_samples += batch[0].shape[0] + msg += '[Samples %s]' % (self.processed_samples) + self.log_interval_time += batch_time + if self.batch_index % self.log_interval == 0: + msg += 'time/interval: %.3fs ' % self.log_interval_time + self.log_interval_time = 0 + for metric in self.metrics: + name, val = metric.get() + msg += '%s: %.4f, ' % (name, val) + estimator.logger.info(msg.rstrip(', ')) + self.batch_index += 1 + +# TODO: change the mxnet validation_handler to include event_handlers +class MTValidationHandler(ValidationHandler): + def __init__(self,
event_handlers, *args, **kwargs): + super(MTValidationHandler, self).__init__(*args, **kwargs) + self.event_handlers = event_handlers + + def batch_end(self, estimator, *args, **kwargs): + self.current_batch += 1 + if self.batch_period and self.current_batch % self.batch_period == 0: + self.eval_fn(val_data=self.val_data, event_handlers=self.event_handlers) + + def epoch_end(self, estimator, *args, **kwargs): + self.current_epoch += 1 + if self.epoch_period and self.current_epoch % self.epoch_period == 0: + self.eval_fn(val_data=self.val_data, event_handlers=self.event_handlers) From f1e26d1828230c2a7c9431a05f5549ec509fa6de Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 17 Jan 2020 06:21:53 +0000 Subject: [PATCH 09/26] bug fix --- src/gluonnlp/estimator/machine_translation_event_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index a217a009c7..fc9f29b0fb 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -55,7 +55,7 @@ def _update_avg_param(self, estimator): params = estimator.net.collect_params() alpha = 1. / max(1, self.step_num - self.avg_start) for key, val in estimator.avg_param.items(): - estimator.avg_param[:] += alpha * \ + val[:] += alpha * \ (params[key].data(estimator.context[0]) - val) From 81530e1c5bb616e1a6bcacff9ea186e6711bcf5a Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 17 Jan 2020 06:23:24 +0000 Subject: [PATCH 10/26] bug fix --- src/gluonnlp/estimator/machine_translation_event_handler.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index fc9f29b0fb..33e822543f 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -55,9 +55,7 @@ def _update_avg_param(self, estimator): params = estimator.net.collect_params() alpha = 1. 
/ max(1, self.step_num - self.avg_start) for key, val in estimator.avg_param.items(): - val[:] += alpha * \ - (params[key].data(estimator.context[0]) - - val) + val[:] += alpha * (params[key].data(estimator.context[0]) - val) def epoch_begin(self, estimator, *args, **kwargs): self.batch_id = 0 From 0165b80ebc9e6d34cc86f0a082d3876b9605d26f Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 17 Jan 2020 08:54:57 +0000 Subject: [PATCH 11/26] fix gnmt estimator bugs --- .../train_gnmt_estimator.py | 18 +++++++++---- .../machine_translation_batch_processor.py | 25 +++++++++++-------- .../machine_translation_event_handler.py | 3 ++- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py index acff3c918b..55b8867e3b 100644 --- a/scripts/machine_translation/train_gnmt_estimator.py +++ b/scripts/machine_translation/train_gnmt_estimator.py @@ -54,7 +54,8 @@ from gluonnlp.estimator import MTGNMTBatchProcessor, MTGNMTGradientUpdateHandler from gluonnlp.estimator import ComputeBleuHandler, ValBleuHandler from gluonnlp.estimator import MTTransformerMetricHandler, MTGNMTLearningRateHandler -from gluonnlp.estimator import MTCheckpointHandler +from gluonnlp.estimator import MTCheckpointHandler, MTTransformerMetricHandler +from gluonnlp.estimator import MTValidationHandler np.random.seed(100) random.seed(100) @@ -153,8 +154,8 @@ = dataprocessor.make_dataloader(data_train, data_val, data_test, args) train_metric = LengthNormalizedLoss(loss_function) -val_metric = LengthNormalziedLoss(loss_function) -batchprocessor = MTGNMTBatchProcessor() +val_metric = LengthNormalizedLoss(loss_function) +batch_processor = MTGNMTBatchProcessor() gnmt_estimator = MachineTranslationEstimator(net=model, loss=loss_function, train_metrics=train_metric, val_metrics=val_metric, @@ -167,17 +168,24 @@ gradient_update_handler = MTGNMTGradientUpdateHandler(clip=args.clip) -metric_handler = MTTransformerMetrichandler(metrics=gnmt_estimator.train_metrics, +metric_handler = MTTransformerMetricHandler(metrics=gnmt_estimator.train_metrics, grad_interval=1) bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentences, translator=translator, compute_bleu_fn=compute_bleu) -val_bleu_handler = ValBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentence, +val_bleu_handler = ValBleuHandler(val_data=val_data_loader, + val_tgt_vocab=tgt_vocab, val_tgt_sentences=val_tgt_sentences, translator=translator, compute_bleu_fn=compute_bleu) checkpoint_handler = MTCheckpointHandler(model_dir=args.save_dir) +val_metric_handler = MTTransformerMetricHandler(metrics=gnmt_estimator.val_metrics) + +val_validation_handler = MTValidationHandler(val_data=val_data_loader, + eval_fn=gnmt_estimator.evaluate, + event_handlers=val_metric_handler) + event_handlers = [learning_rate_handler, gradient_update_handler, metric_handler, val_bleu_handler, checkpoint_handler] diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py index 26bc36a6be..f5456a8948 100644 --- a/src/gluonnlp/estimator/machine_translation_batch_processor.py +++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py @@ -78,30 +78,33 @@ def evaluate_batch(self, estimator, val_batch, batch_axis=0): return src_seq, [tgt_seq, val_tgt_valid_length], out, loss class MTGNMTBatchProcessor(BatchProcessor): - def __init__(): + def __init__(self): pass def fit_batch(self, 
estimator, train_batch, batch_axis=0): - src_seq, tgt_seql, src_valid_length, tgt_valid_length = train_batch - src_seq = src_seq.as_in_context(estimator.context) - tgt_seq = tgt_seq.as_in_context(estimator.context) - src_valid_length = src_valid_length.as_in_context(estimator.context) - tgt_valid_lenght = tgt_valid_length.as_in_context(estimator.context) + ctx = estimator.context[0] + src_seq, tgt_seq, src_valid_length, tgt_valid_length = train_batch + src_seq = src_seq.as_in_context(ctx) + tgt_seq = tgt_seq.as_in_context(ctx) + src_valid_length = src_valid_length.as_in_context(ctx) + tgt_valid_lenght = tgt_valid_length.as_in_context(ctx) with mx.autograd.record(): out, _ = estimator.net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) loss = estimator.loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean() loss = loss * (tgt_seq.shape[1] - 1) + log_loss = loss * tgt_seq.shape[0] loss = loss / (tgt_valid_length - 1).mean() loss.backward() - return src_seq, [tgt_seq, (tgt_valid_length - 1).sum()], out, loss * tgt_seq.shape[0] + return src_seq, [tgt_seq, (tgt_valid_length - 1).sum()], out, log_loss def evaluate_batch(self, estimator, val_batch, batch_axis=0): + ctx = estimator.context[0] src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = val_batch - src_seq = src_seq.as_in_context(estimator.context) - tgt_seq = tgt_seq.as_in_context(estimator.context) - src_valid_length = src_valid_length.as_in_context(estimator.context) - tgt_valid_length = tgt_valid_length.as_in_context(estimator.context) + src_seq = src_seq.as_in_context(ctx) + tgt_seq = tgt_seq.as_in_context(ctx) + src_valid_length = src_valid_length.as_in_context(ctx) + tgt_valid_length = tgt_valid_length.as_in_context(ctx) out, _ = estimator.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) loss = estimator.evaluation_loss(out, tgt_seq[:, 1:], diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index 33e822543f..e93a7e4a85 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -111,7 +111,8 @@ def __init__(self, clip): self.clip = clip def batch_end(self, estimator, *args, **kwargs): - grads = [p.grad(ctx) for p in estimator.net.collect_params().values()] + grads = [p.grad(estimator.context[0]) + for p in estimator.net.collect_params().values()] gnorm = gluon.utils.clip_global_norm(grads, self.clip) estimator.trainer.step(1) From b4a695c2076865cd6618ab2c4050979605171138 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 17 Jan 2020 10:03:26 +0000 Subject: [PATCH 12/26] fix test data bugs --- .../machine_translation/train_gnmt_estimator.py | 15 ++++++++++++--- .../train_transformer_estimator.py | 12 +++++++++++- .../machine_translation_event_handler.py | 5 +++-- 3 files changed, 26 insertions(+), 6 deletions(-) diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py index 55b8867e3b..e1b15ea878 100644 --- a/scripts/machine_translation/train_gnmt_estimator.py +++ b/scripts/machine_translation/train_gnmt_estimator.py @@ -174,6 +174,9 @@ bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentences, translator=translator, compute_bleu_fn=compute_bleu) +test_bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=test_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu) + val_bleu_handler = 
ValBleuHandler(val_data=val_data_loader, + val_tgt_vocab=tgt_vocab, val_tgt_sentences=val_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu) @@ -187,11 +190,17 @@ event_handlers=val_metric_handler) event_handlers = [learning_rate_handler, gradient_update_handler, metric_handler, - val_bleu_handler, checkpoint_handler] + val_bleu_handler, checkpoint_handler, val_validation_handler] gnmt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader, - epochs=args.epochs, - #batches=5, + #epochs=args.epochs, + batches=5, event_handlers=event_handlers, batch_axis=0) + +val_event_handlers = [val_metric_handler, bleu_handler] +test_event_handlers = [val_metric_handler, test_bleu_handler] + +gnmt_estimator.evaluate(val_data=val_data_loader, event_handlers=val_event_handlers) +gnmt_estimator.evaluate(val_data=test_data_loader, event_handlers=test_event_handlers) diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py index 395f22eb9d..01d74d518b 100644 --- a/scripts/machine_translation/train_transformer_estimator.py +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -266,6 +266,13 @@ bpe=bpe, bleu=args.bleu, detokenizer=detokenizer, _bpe_to_words=_bpe_to_words) +test_bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=test_tgt_sentences, + translator=translator, compute_bleu_fn=compute_bleu, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe, bleu=args.bleu, detokenizer=detokenizer, + _bpe_to_words=_bpe_to_words) + val_bleu_handler = ValBleuHandler(val_data=val_data_loader, val_tgt_vocab=tgt_vocab, val_tgt_sentences=val_tgt_sentences, translator=translator, tokenized=tokenized, tokenizer=args.bleu, @@ -308,6 +315,9 @@ val_event_handlers = [val_metric_handler, bleu_handler] +test_event_handlers = [val_metric_handler, + test_bleu_handler] + mt_estimator.evaluate(val_data=val_data_loader, event_handlers=val_event_handlers) -mt_estimator.evaluate(val_data=test_data_loader, event_handlers=val_event_handlers) +mt_estimator.evaluate(val_data=test_data_loader, event_handlers=test_event_handlers) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index e93a7e4a85..7f87e6e753 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -207,8 +207,9 @@ def epoch_end(self, estimator, *args, **kwargs): def train_end(self, estimator, *args, **kwargs): ctx = estimator.context - save_path = os.path.join(self.model_dir, 'average.params') - mx.nd.save(save_path, estimator.avg_param) + if estimator.avg_param is not None: + save_path = os.path.join(self.model_dir, 'average.params') + mx.nd.save(save_path, estimator.avg_param) if self.average_checkpoint: for j in range(self.num_averages): params = mx.nd.load(os.path.join(self.model_dir, From 458b6f78bc7efb650eca21182bc71a675903619a Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Sun, 19 Jan 2020 05:23:09 +0000 Subject: [PATCH 13/26] fix typo --- src/gluonnlp/estimator/machine_translation_event_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index 7f87e6e753..93da5202a8 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++
b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -345,7 +345,7 @@ def epoch_end(self, estimator, *args, **kwargs): for ind, sentence in zip(all_inst_ids, translation_out): if self.bleu == 'tweaked': real_translation_out[ind] = sentence - elif self.bleu == '13a' or self.beu == 'intl': + elif self.bleu == '13a' or self.bleu == 'intl': real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence)) else: raise NotImplementedError From 406acbe6bc1484125745b44f01e9f04bdc5313a0 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Wed, 22 Jan 2020 06:06:49 +0000 Subject: [PATCH 14/26] fix gnmt estimator bugs --- .../train_gnmt_estimator.py | 19 +++++++++++++------ .../train_transformer_estimator.py | 1 - .../machine_translation_batch_processor.py | 2 +- .../machine_translation_event_handler.py | 6 ++++-- 4 files changed, 18 insertions(+), 10 deletions(-) diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py index e1b15ea878..01fb69f8dd 100644 --- a/scripts/machine_translation/train_gnmt_estimator.py +++ b/scripts/machine_translation/train_gnmt_estimator.py @@ -56,6 +56,7 @@ from gluonnlp.estimator import MTTransformerMetricHandler, MTGNMTLearningRateHandler from gluonnlp.estimator import MTCheckpointHandler, MTTransformerMetricHandler from gluonnlp.estimator import MTValidationHandler +from mxnet.gluon.contrib.estimator import LoggingHandler np.random.seed(100) random.seed(100) @@ -172,14 +173,17 @@ grad_interval=1) bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=val_tgt_sentences, - translator=translator, compute_bleu_fn=compute_bleu) + translator=translator, compute_bleu_fn=compute_bleu, + bleu='tweaked') test_bleu_handler = ComputeBleuHandler(tgt_vocab=tgt_vocab, tgt_sentence=test_tgt_sentences, - translator=translator, compute_bleu_fn=compute_bleu) + translator=translator, compute_bleu_fn=compute_bleu, + bleu='tweaked') val_bleu_handler = ValBleuHandler(val_data=val_data_loader, val_tgt_vocab=tgt_vocab, val_tgt_sentences=val_tgt_sentences, - translator=translator, compute_bleu_fn=compute_bleu) + translator=translator, compute_bleu_fn=compute_bleu, + bleu='tweaked') checkpoint_handler = MTCheckpointHandler(model_dir=args.save_dir) @@ -189,13 +193,16 @@ eval_fn=gnmt_estimator.evaluate, event_handlers=val_metric_handler) +logging_handler = LoggingHandler(log_interval=args.log_interval, + metrics=gnmt_estimator.train_metrics) + event_handlers = [learning_rate_handler, gradient_update_handler, metric_handler, - val_bleu_handler, checkpoint_handler, val_validation_handler] + val_bleu_handler, checkpoint_handler, val_validation_handler, logging_handler] gnmt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader, - #epochs=args.epochs, - batches=5, + epochs=args.epochs, + #batches=5, event_handlers=event_handlers, batch_axis=0) diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py index 01d74d518b..78ff187881 100644 --- a/scripts/machine_translation/train_transformer_estimator.py +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -308,7 +308,6 @@ mt_estimator.fit(train_data=train_data_loader, val_data=val_data_loader, epochs=args.epochs, - #batches=200, event_handlers=event_handlers, batch_axis=0) diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py index f5456a8948..818f71c67f 100644 --- 
a/src/gluonnlp/estimator/machine_translation_batch_processor.py +++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py @@ -87,7 +87,7 @@ def fit_batch(self, estimator, train_batch, batch_axis=0): src_seq = src_seq.as_in_context(ctx) tgt_seq = tgt_seq.as_in_context(ctx) src_valid_length = src_valid_length.as_in_context(ctx) - tgt_valid_lenght = tgt_valid_length.as_in_context(ctx) + tgt_valid_length = tgt_valid_length.as_in_context(ctx) with mx.autograd.record(): out, _ = estimator.net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index 93da5202a8..6fc89a7c88 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -242,7 +242,7 @@ def __init__(self, tokenizer='13a', split_compound_word=False, bpe=False, - bleu='tweaked', + bleu='13a', detokenizer=None, _bpe_to_words=None): self.tgt_vocab = tgt_vocab @@ -294,6 +294,7 @@ def epoch_end(self, estimator, *args, **kwargs): tokenizer=self.tokenizer, split_compound_word=self.split_compound_word, bpe=self.bpe) + print(estimator.bleu_score) # temporary validation bleu metric hack, it can be removed once bleu metric api is available @@ -307,7 +308,7 @@ def __init__(self, val_data, tokenizer='13a', split_compound_word=False, bpe=False, - bleu='tweaked', + bleu='13a', detokenizer=None, _bpe_to_words=None): self.val_data = val_data @@ -355,6 +356,7 @@ def epoch_end(self, estimator, *args, **kwargs): tokenizer=self.tokenizer, split_compound_word=self.split_compound_word, bpe=self.bpe) + print(estimator.bleu_score) From 02f78197190fc17756c846f47e08dfe08672b704 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Wed, 22 Jan 2020 07:52:20 +0000 Subject: [PATCH 15/26] change variable names for the latest mxnet build --- .../machine_translation/train_transformer_estimator.py | 2 +- .../estimator/machine_translation_batch_processor.py | 8 ++++---- src/gluonnlp/estimator/machine_translation_estimator.py | 8 ++++---- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py index 78ff187881..4349ae2c69 100644 --- a/scripts/machine_translation/train_transformer_estimator.py +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -246,7 +246,7 @@ val_metrics=val_metric, trainer=trainer, context=ctx, - evaluation_loss=test_loss_function, + val_loss=test_loss_function, batch_processor=batch_processor) param_update_handler = MTTransformerParamUpdateHandler(avg_start=average_start, diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py index 818f71c67f..fcc3476d1e 100644 --- a/src/gluonnlp/estimator/machine_translation_batch_processor.py +++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py @@ -70,8 +70,8 @@ def evaluate_batch(self, estimator, val_batch, batch_axis=0): src_valid_length = src_valid_length.as_in_context(ctx) tgt_valid_length = tgt_valid_length.as_in_context(ctx) - out, _ = estimator.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) - loss = estimator.evaluation_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).sum().asscalar() + out, _ = estimator.val_net(src_seq,
tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) + loss = estimator.val_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).sum().asscalar() inst_ids = inst_ids.asnumpy().astype(np.int32).tolist() loss = loss * (tgt_seq.shape[1] - 1) val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar() @@ -105,9 +105,9 @@ def evaluate_batch(self, estimator, val_batch, batch_axis=0): tgt_seq = tgt_seq.as_in_context(ctx) src_valid_length = src_valid_length.as_in_context(ctx) tgt_valid_length = tgt_valid_length.as_in_context(ctx) - out, _ = estimator.eval_net(src_seq, tgt_seq[:, :-1], src_valid_length, + out, _ = estimator.val_net(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) - loss = estimator.evaluation_loss(out, tgt_seq[:, 1:], + loss = estimator.val_loss(out, tgt_seq[:, 1:], tgt_valid_length - 1).sum().asscalar() loss = loss * (tgt_seq.shape[1] - 1) val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar() diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py index 8c6c63698f..4b0a9fcde4 100644 --- a/src/gluonnlp/estimator/machine_translation_estimator.py +++ b/src/gluonnlp/estimator/machine_translation_estimator.py @@ -34,8 +34,8 @@ def __init__(self, net, loss, initializer=None, trainer=None, context=None, - evaluation_loss=None, - eval_net=None, + val_loss=None, + val_net=None, batch_processor=MTTransformerBatchProcessor()): super().__init__(net=net, loss=loss, train_metrics=train_metrics, @@ -43,8 +43,8 @@ def __init__(self, net, loss, initializer=initializer, trainer=trainer, context=context, - evaluation_loss=evaluation_loss, - eval_net=eval_net, + val_loss=val_loss, + val_net=val_net, batch_processor=batch_processor) self.tgt_valid_length = 0 self.val_tgt_valid_length = 0 From 740c71242058f9bef62ffdeefe809d55f1e68cec Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 13 Feb 2020 12:46:55 +0000 Subject: [PATCH 16/26] remove temporary length normalized loss --- .../train_transformer_estimator.py | 14 ++-- .../estimator/length_normalized_loss.py | 77 ------------------- .../machine_translation_event_handler.py | 22 +----- 3 files changed, 12 insertions(+), 101 deletions(-) delete mode 100644 src/gluonnlp/estimator/length_normalized_loss.py diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py index 4349ae2c69..16216f315f 100644 --- a/scripts/machine_translation/train_transformer_estimator.py +++ b/scripts/machine_translation/train_transformer_estimator.py @@ -42,8 +42,8 @@ import numpy as np import mxnet as mx from mxnet import gluon - import gluonnlp as nlp + from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder from gluonnlp.model.translation import NMTModel @@ -52,12 +52,14 @@ from bleu import _bpe_to_words, compute_bleu from translation import BeamSearchTranslator from utils import logging_config -from gluonnlp.estimator import MachineTranslationEstimator, LengthNormalizedLoss +from gluonnlp.metric import LengthNormalizedLoss +from gluonnlp.estimator import MachineTranslationEstimator from gluonnlp.estimator import MTTransformerBatchProcessor, MTTransformerParamUpdateHandler from gluonnlp.estimator import TransformerLearningRateHandler, MTTransformerMetricHandler from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler from gluonnlp.estimator import ValBleuHandler, 
MTCheckpointHandler -from gluonnlp.estimator import MTTransformerLoggingHandler, MTValidationHandler +from gluonnlp.estimator import MTTransformerLoggingHandler +from mxnet.gluon.contrib.estimator import ValidationHandler np.random.seed(100) random.seed(100) @@ -288,9 +290,9 @@ val_metric_handler = MTTransformerMetricHandler(metrics=mt_estimator.val_metrics) -val_validation_handler = MTValidationHandler(val_data=val_data_loader, - eval_fn=mt_estimator.evaluate, - event_handlers=val_metric_handler) +val_validation_handler = ValidationHandler(val_data=val_data_loader, + eval_fn=mt_estimator.evaluate, + event_handlers=val_metric_handler) log_interval = args.log_interval * grad_interval logging_handler = MTTransformerLoggingHandler(log_interval=log_interval, diff --git a/src/gluonnlp/estimator/length_normalized_loss.py b/src/gluonnlp/estimator/length_normalized_loss.py deleted file mode 100644 index e4558c6fb1..0000000000 --- a/src/gluonnlp/estimator/length_normalized_loss.py +++ /dev/null @@ -1,77 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -""" Length Normalized Loss """ - -from mxnet import ndarray -from mxnet.metric import EvalMetric - -__all__ = ['LengthNormalizedLoss'] - -class LengthNormalizedLoss(EvalMetric): - """Compute length normalized loss metrics - - Parameters - ---------- - axis : int, default=1 - The axis that represents classes - name : str - Name of this metric instance for display. - output_names : list of str, or None - Name of predictions that should be used when updating with update_dict. - By default include all predictions. - label_names : list of str, or None - Name of labels that should be used when updating with update_dict. - By default include all labels. - """ - def __init__(self, axis=0, name='length-normalized-loss', - output_names=None, label_names=None): - super(LengthNormalizedLoss, self).__init__( - name, axis=axis, - output_names=output_names, label_names=label_names, - has_global_stats=True) - - # Parameter labels should be a list in the form of [target_sequence, - # target_seqauence_valid_length] - def update(self, labels, preds): - if not isinstance(labels, list) or len(labels) != 2: - raise ValueError('labels must be a list. 
Its first element should be' - ' target sequence and the second element should be' - 'the valid length of sequence.') - - _, seq_valid_length = labels - - if not isinstance(seq_valid_length, list): - seq_valid_length = [seq_valid_length] - - if not isinstance(preds, list): - preds = [preds] - - for length in seq_valid_length: - if isinstance(length, ndarray.ndarray.NDArray): - total_length = ndarray.sum(length).asscalar() - else: - total_length = length - self.num_inst += total_length - self.global_num_inst += total_length - - for pred in preds: - if isinstance(pred, ndarray.ndarray.NDArray): - loss = ndarray.sum(pred).asscalar() - else: - loss = pred - self.sum_metric += loss - self.global_sum_metric += loss diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index 6fc89a7c88..6e70814b3d 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -32,13 +32,13 @@ from mxnet.gluon.contrib.estimator import ValidationHandler from mxnet import gluon from mxnet.metric import Loss as MetricLoss -from .length_normalized_loss import LengthNormalizedLoss +from ..metric.length_normalized_loss import LengthNormalizedLoss __all__ = ['MTTransformerParamUpdateHandler', 'TransformerLearningRateHandler', 'MTTransformerMetricHandler', 'TransformerGradientAccumulationHandler', 'ComputeBleuHandler', 'ValBleuHandler', 'MTGNMTGradientUpdateHandler', 'MTGNMTLearningRateHandler', 'MTCheckpointHandler', - 'MTTransformerLoggingHandler', 'MTValidationHandler'] + 'MTTransformerLoggingHandler'] class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd): def __init__(self, avg_start, grad_interval=1): @@ -154,6 +154,8 @@ def epoch_end(self, estimator, *args, **kwargs): if self.loss_denom > 0: self._update_gradient(estimator) +"""TODO: merge this event handler with metricresethandler for language model +""" class MTTransformerMetricHandler(MetricHandler, BatchBegin): def __init__(self, *args, grad_interval=None, **kwargs): super(MTTransformerMetricHandler, self).__init__(*args, **kwargs) @@ -379,19 +381,3 @@ def batch_end(self, estimator, *args, **kwargs): msg += '%s: %.4f, ' % (name, val) estimator.logger.info(msg.rstrip(', ')) self.batch_index += 1 - -# TODO: change the mxnet validation_handler to include event_handlers -class MTValidationHandler(ValidationHandler): - def __init__(self, event_handlers, *args, **kwargs): - super(MTValidationHandler, self).__init__(*args, **kwargs) - self.event_handlers = event_handlers - - def batch_end(self, estimator, *args, **kwargs): - self.current_batch += 1 - if self.batch_period and self.current_batch % self.batch_period == 0: - self.eval_fn(val_data=self.val_data, event_handlers=self.event_handlers) - - def epoch_end(self, estimator, *args, **kwargs): - self.current_epoch += 1 - if self.epoch_period and self.current_epoch % self.epoch_period == 0: - self.eval_fn(val_data=self.val_data, event_handlers=self.event_handlers) From feef52e223580ae0987fe2e10e3f21718b879056 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Thu, 13 Feb 2020 13:05:14 +0000 Subject: [PATCH 17/26] update index.rst --- scripts/machine_translation/index.rst | 4 ++-- src/gluonnlp/estimator/__init__.py | 4 +--- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/scripts/machine_translation/index.rst b/scripts/machine_translation/index.rst index a228ee24ed..9824909677 100644 --- a/scripts/machine_translation/index.rst +++ 
From feef52e223580ae0987fe2e10e3f21718b879056 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Thu, 13 Feb 2020 13:05:14 +0000
Subject: [PATCH 17/26] update index.rst

---
 scripts/machine_translation/index.rst | 4 ++--
 src/gluonnlp/estimator/__init__.py    | 4 +---
 2 files changed, 3 insertions(+), 5 deletions(-)

diff --git a/scripts/machine_translation/index.rst b/scripts/machine_translation/index.rst
index a228ee24ed..9824909677 100644
--- a/scripts/machine_translation/index.rst
+++ b/scripts/machine_translation/index.rst
@@ -10,7 +10,7 @@ Use the following command to train the GNMT model on the IWSLT2015 dataset.

 .. code-block:: console

-   $ MXNET_GPU_MEM_POOL_TYPE=Round python train_gnmt.py --src_lang en --tgt_lang vi --batch_size 128 \
+   $ MXNET_GPU_MEM_POOL_TYPE=Round python train_gnmt_estimator.py --src_lang en --tgt_lang vi --batch_size 128 \
       --optimizer adam --lr 0.001 --lr_update_factor 0.5 --beam_size 10 --bucket_scheme exp \
       --num_hidden 512 --save_dir gnmt_en_vi_l2_h512_beam10 --epochs 12 --gpu 0

@@ -23,7 +23,7 @@ Use the following commands to train the Transformer model on the WMT14 dataset f

 .. code-block:: console

-   $ MXNET_GPU_MEM_POOL_TYPE=Round python train_transformer.py --dataset WMT2014BPE \
+   $ MXNET_GPU_MEM_POOL_TYPE=Round python train_transformer_estimator.py --dataset WMT2014BPE \
       --src_lang en --tgt_lang de --batch_size 2700 \
       --optimizer adam --num_accumulated 16 --lr 2.0 --warmup_steps 4000 \
       --save_dir transformer_en_de_u512 --epochs 30 --gpus 0,1,2,3,4,5,6,7 --scaled \
diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py
index 12c5b14769..3dca823d2c 100644
--- a/src/gluonnlp/estimator/__init__.py
+++ b/src/gluonnlp/estimator/__init__.py
@@ -20,8 +20,6 @@
 from .machine_translation_estimator import *
 from .machine_translation_event_handler import *
 from .machine_translation_batch_processor import *
-from .length_normalized_loss import *

 __all__ = (machine_translation_estimator.__all__ + machine_translation_event_handler.__all__
-           + machine_translation_batch_processor.__all__
-           + length_normalized_loss.__all__)
+           + machine_translation_batch_processor.__all__)

From 4e13742c5a9cead5008c0e16e3beb86e8fddb35f Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Thu, 13 Feb 2020 13:14:18 +0000
Subject: [PATCH 18/26] fix import in gnmt estimator

---
 scripts/machine_translation/train_gnmt_estimator.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py
index 01fb69f8dd..f7c9629ca6 100644
--- a/scripts/machine_translation/train_gnmt_estimator.py
+++ b/scripts/machine_translation/train_gnmt_estimator.py
@@ -50,13 +50,13 @@
 from utils import logging_config
 from bleu import compute_bleu
 import dataprocessor
-from gluonnlp.estimator import MachineTranslationEstimator, LengthNormalizedLoss
+from gluonnlp.metric import LengthNormalizedLoss
+from gluonnlp.estimator import MachineTranslationEstimator
 from gluonnlp.estimator import MTGNMTBatchProcessor, MTGNMTGradientUpdateHandler
 from gluonnlp.estimator import ComputeBleuHandler, ValBleuHandler
 from gluonnlp.estimator import MTTransformerMetricHandler, MTGNMTLearningRateHandler
 from gluonnlp.estimator import MTCheckpointHandler, MTTransformerMetricHandler
-from gluonnlp.estimator import MTValidationHandler
-from mxnet.gluon.contrib.estimator import LoggingHandler
+from mxnet.gluon.contrib.estimator import LoggingHandler, ValidationHandler

 np.random.seed(100)
 random.seed(100)
@@ -189,7 +189,7 @@

 val_metric_handler = MTTransformerMetricHandler(metrics=gnmt_estimator.val_metrics)

-val_validation_handler = MTValidationHandler(val_data=val_data_loader,
+val_validation_handler = ValidationHandler(val_data=val_data_loader,
                                            eval_fn=gnmt_estimator.evaluate,
                                            event_handlers=val_metric_handler)
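With the metric now living in `gluonnlp.metric`, a minimal sketch of its contract may help; it is inferred directly from the implementation deleted above: `update` takes `labels=[target_sequence, target_valid_length]` plus per-token losses, and `get` reports total loss divided by total valid length. The toy tensors are illustrative only.

.. code-block:: python

    import mxnet as mx
    from gluonnlp.metric import LengthNormalizedLoss

    metric = LengthNormalizedLoss()
    tgt_seq = mx.nd.ones((2, 5))              # dummy target batch
    tgt_valid_length = mx.nd.array([4., 5.])  # valid tokens per sentence
    token_losses = mx.nd.ones((2, 5)) * 0.5   # dummy losses, 5.0 in total

    # sum(losses) / sum(valid_length) = 5.0 / 9, roughly 0.56
    metric.update([tgt_seq, tgt_valid_length], token_losses)
    print(metric.get())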
From bfa8425ce1ec6341854847e67eb4f8a58902e679 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Fri, 14 Feb 2020 15:21:18 +0000
Subject: [PATCH 19/26] fix pylint errors and update docstrings

---
 .../machine_translation_batch_processor.py |  38 ++++-
 .../machine_translation_estimator.py       |  33 +++-
 .../machine_translation_event_handler.py   | 154 ++++++++++++++----
 3 files changed, 184 insertions(+), 41 deletions(-)

diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py
index fcc3476d1e..d02b07f3ed 100644
--- a/src/gluonnlp/estimator/machine_translation_batch_processor.py
+++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py
@@ -20,17 +20,32 @@

 import numpy as np
 import mxnet as mx
 from mxnet.gluon.contrib.estimator import BatchProcessor
-from mxnet.gluon.utils import split_and_load
 from ..model.transformer import ParallelTransformer
 from ..utils.parallel import Parallel

 __all__ = ['MTTransformerBatchProcessor', 'MTGNMTBatchProcessor']

 class MTTransformerBatchProcessor(BatchProcessor):
+    '''Batch processor for transformer training on machine translation
+
+    The batch training and validation procedure on the transformer network
+
+    Parameters
+    ----------
+    rescale_loss : int
+        normalization constant for loss computation
+    batch_size : int
+        number of tokens per gpu in a minibatch
+    label_smoothing : HybridBlock
+        Apply label smoothing on the given network
+    loss_function : mxnet.gluon.loss
+        training loss function
+    '''
    def __init__(self, rescale_loss=100,
                 batch_size=1024,
                 label_smoothing=None,
                 loss_function=None):
+        super(MTTransformerBatchProcessor, self).__init__()
        self.rescale_loss = rescale_loss
        self.batch_size = batch_size
        self.label_smoothing = label_smoothing
@@ -49,9 +64,12 @@ def fit_batch(self, estimator, train_batch, batch_axis=0):
         self._get_parallel_model(estimator)
         data = [shard[0] for shard in train_batch]
         target = [shard[1] for shard in train_batch]
-        src_word_count, tgt_word_count, bs = np.sum([(shard[2].sum(),
-            shard[3].sum(), shard[0].shape[0]) for shard in train_batch],
-            axis=0)
+        _, tgt_word_count, bs = np.sum([(shard[2].sum(),
+                                         shard[3].sum(),
+                                         shard[0].shape[0])
+                                        for shard in
+                                        train_batch],
+                                       axis=0)
         estimator.tgt_valid_length = tgt_word_count.asscalar() - bs
         seqs = [[seq.as_in_context(context) for seq in shard]
                 for context, shard in zip(estimator.context, train_batch)]
@@ -78,8 +96,12 @@ def evaluate_batch(self, estimator, val_batch, batch_axis=0):
         return src_seq, [tgt_seq, val_tgt_valid_length], out, loss

 class MTGNMTBatchProcessor(BatchProcessor):
+    '''Batch processor for GNMT training
+
+    Batch training and validation on the GNMT network for the machine translation task.
+    '''
     def __init__(self):
-        pass
+        super(MTGNMTBatchProcess, self).__init__()

     def fit_batch(self, estimator, train_batch, batch_axis=0):
         ctx = estimator.context[0]
@@ -100,15 +122,15 @@ def fit_batch(self, estimator, train_batch, batch_axis=0):

     def evaluate_batch(self, estimator, val_batch, batch_axis=0):
         ctx = estimator.context[0]
-        src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = val_batch
+        src_seq, tgt_seq, src_valid_length, tgt_valid_length, _ = val_batch
         src_seq = src_seq.as_in_context(ctx)
         tgt_seq = tgt_seq.as_in_context(ctx)
         src_valid_length = src_valid_length.as_in_context(ctx)
         tgt_valid_length = tgt_valid_length.as_in_context(ctx)
         out, _ = estimator.val_net(src_seq, tgt_seq[:, :-1], src_valid_length,
-                            tgt_valid_length - 1)
+                                   tgt_valid_length - 1)
         loss = estimator.val_loss(out, tgt_seq[:, 1:],
-                            tgt_valid_length - 1).sum().asscalar()
+                                  tgt_valid_length - 1).sum().asscalar()
         loss = loss * (tgt_seq.shape[1] - 1)
         val_tgt_valid_length = (tgt_valid_length - 1).sum().asscalar()
         return src_seq, [tgt_seq, val_tgt_valid_length], out, loss
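Both classes above implement the `BatchProcessor` contract from `mxnet.gluon.contrib.estimator`: `fit_batch` and `evaluate_batch` receive the estimator plus one batch and return `(data, label, pred, loss)` for the event handlers to consume. Below is a minimal sketch of that contract, assuming a plain `(data, label)` batch layout rather than the sharded machine-translation batches used here.

.. code-block:: python

    import mxnet as mx
    from mxnet import autograd
    from mxnet.gluon.contrib.estimator import BatchProcessor

    class MinimalBatchProcessor(BatchProcessor):
        """Illustrative only; not part of this patch series."""

        def fit_batch(self, estimator, train_batch, batch_axis=0):
            data, label = train_batch
            with autograd.record():
                pred = estimator.net(data)
                loss = estimator.loss(pred, label)
            loss.backward()
            # A GradientUpdateHandler calls trainer.step(); the processor only
            # computes gradients and reports what the metric handlers need.
            return data, label, pred, loss

        def evaluate_batch(self, estimator, val_batch, batch_axis=0):
            data, label = val_batch
            pred = estimator.val_net(data)
            loss = estimator.val_loss(pred, label)
            return data, label, pred, loss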
diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py
index 4b0a9fcde4..eb824ad91a 100644
--- a/src/gluonnlp/estimator/machine_translation_estimator.py
+++ b/src/gluonnlp/estimator/machine_translation_estimator.py
@@ -17,17 +17,40 @@
 # pylint: disable=eval-used, redefined-outer-name
 """ Gluon Machine Translation Estimator """

-import copy
-import warnings
-
-import numpy as np
-import mxnet as mx
 from mxnet.gluon.contrib.estimator import Estimator

 from .machine_translation_batch_processor import MTTransformerBatchProcessor

 __all__ = ['MachineTranslationEstimator']

 class MachineTranslationEstimator(Estimator):
+    '''Estimator class for machine translation tasks
+
+    Facilitates training and validation on machine translation tasks
+
+    Parameters
+    ----------
+    net : gluon.Block
+        The model used for training.
+    loss : gluon.loss.Loss
+        Loss (objective) function to calculate during training.
+    train_metrics : EvalMetric or list of EvalMetric
+        Training metrics for evaluating models on training dataset.
+    val_metrics : EvalMetric or list of EvalMetric
+        Validation metrics for evaluating models on validation dataset.
+    initializer : Initializer
+        Initializer to initialize the network.
+    trainer : Trainer
+        Trainer to apply optimizer on network parameters.
+    context : Context or list of Context
+        Device(s) to run the training on.
+    val_net : gluon.Block
+        The model used for validation. The validation model does not necessarily belong to
+        the same model class as the training model.
+    val_loss : gluon.loss.loss
+        Loss (objective) function to calculate during validation. If set to
+        None, it will use the same loss function as self.loss
+    batch_processor : BatchProcessor
+        BatchProcessor provides customized fit_batch() and evaluate_batch() methods
+    '''
     def __init__(self, net, loss,
                  train_metrics=None,
                  val_metrics=None,
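A hedged construction sketch for the estimator documented above; the dense stand-in network is only there to make the snippet self-contained, and a real run would pass an NMT model plus the matching batch processor and metrics.

.. code-block:: python

    import mxnet as mx
    from mxnet import gluon
    from gluonnlp.loss import MaskedSoftmaxCELoss
    from gluonnlp.estimator import MachineTranslationEstimator, MTGNMTBatchProcessor

    net = gluon.nn.Dense(8)   # stand-in for an NMTModel, illustration only
    net.initialize()
    loss = MaskedSoftmaxCELoss()

    estimator = MachineTranslationEstimator(
        net=net,
        loss=loss,
        trainer=gluon.Trainer(net.collect_params(), 'adam'),
        context=[mx.cpu()],
        val_net=net,                 # may be a different network than `net`
        val_loss=loss,               # defaults to `loss` when left as None
        batch_processor=MTGNMTBatchProcessor())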
diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py
index 6e70814b3d..8d6f95e1e4 100644
--- a/src/gluonnlp/estimator/machine_translation_event_handler.py
+++ b/src/gluonnlp/estimator/machine_translation_event_handler.py
@@ -17,8 +17,6 @@
 # pylint: disable=eval-used, redefined-outer-name
 """ Gluon Machine Translation Event Handler """

-import copy
-import warnings
 import math
 import os
 import time
@@ -29,7 +27,6 @@
 from mxnet.gluon.contrib.estimator import EpochEnd, BatchBegin, BatchEnd
 from mxnet.gluon.contrib.estimator import GradientUpdateHandler, CheckpointHandler
 from mxnet.gluon.contrib.estimator import MetricHandler, LoggingHandler
-from mxnet.gluon.contrib.estimator import ValidationHandler
 from mxnet import gluon
 from mxnet.metric import Loss as MetricLoss
 from ..metric.length_normalized_loss import LengthNormalizedLoss
@@ -41,6 +38,17 @@
            'MTTransformerLoggingHandler']

 class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd):
+    '''Transformer average parameter update handler
+
+    Update weighted average parameters of the transformer during training
+
+    Parameters
+    ----------
+    avg_start : int
+        the starting epoch of performing average sgd update
+    grad_interval : int
+        The interval of updating average model parameters
+    '''
     def __init__(self, avg_start, grad_interval=1):
         self.batch_id = 0
         self.grad_interval = grad_interval
@@ -59,7 +67,7 @@ def _update_avg_param(self, estimator):

     def epoch_begin(self, estimator, *args, **kwargs):
         self.batch_id = 0
-
+
     def batch_end(self, estimator, *args, **kwargs):
         if self.batch_id % self.grad_interval == 0:
             self.step_num += 1
@@ -71,6 +79,17 @@ def epoch_end(self, estimator, *args, **kwargs):
         self._update_avg_param(estimator)

 class MTGNMTLearningRateHandler(EpochEnd):
+    '''GNMT learning rate update handler
+
+    dynamically adjust the learning rate during GNMT training
+
+    Parameters
+    ----------
+    epochs : int
+        total number of epochs for GNMT training
+    lr_update_factor : float
+        the decaying factor of learning rate
+    '''
     def __init__(self, epochs, lr_update_factor):
         self.epoch_id = 0
         self.epochs = epochs
@@ -83,6 +102,21 @@ def epoch_end(self, estimator, *args, **kwargs):
         self.epoch_id += 1

 class TransformerLearningRateHandler(EpochBegin, BatchBegin):
+    '''Transformer learning rate update handler
+
+    dynamically adjust the learning rate during transformer training
+
+    Parameters
+    ----------
+    lr : float
+        initial learning rate for transformer training
+    num_units : int
+        dimension of the embedding vector
+    warmup_steps : int
+        number of warmup steps used in training schedule
+    grad_interval : int
+        the interval of updating learning rate
+    '''
     def __init__(self, lr,
                  num_units=512,
                  warmup_steps=4000,
@@ -106,6 +140,16 @@ def batch_begin(self, estimator, *args, **kwargs):
         self.batch_id += 1

 class MTGNMTGradientUpdateHandler(GradientUpdateHandler):
+    '''Gradient update handler of GNMT training
+
+    clip gradient if gradient norm exceeds some threshold during GNMT training
+
+    Parameters
+    ----------
+    clip : float
+        gradient norm threshold. If gradient norm exceeds this value, it should be
+        scaled down to the valid range.
+    '''
     def __init__(self, clip):
         super(MTGNMTGradientUpdateHandler, self).__init__()
         self.clip = clip
@@ -113,13 +157,27 @@
     def batch_end(self, estimator, *args, **kwargs):
         grads = [p.grad(estimator.context[0])
                  for p in estimator.net.collect_params().values()]
-        gnorm = gluon.utils.clip_global_norm(grads, self.clip)
+        gluon.utils.clip_global_norm(grads, self.clip)
         estimator.trainer.step(1)

 class TransformerGradientAccumulationHandler(GradientUpdateHandler,
                                              TrainBegin,
                                              EpochBegin,
                                              EpochEnd):
+    '''Gradient accumulation handler for transformer training
+
+    Accumulates gradients of the network for a few iterations, and updates
+    network parameters with the accumulated gradients
+
+    Parameters
+    ----------
+    grad_interval : int
+        the interval of updating gradients
+    batch_size : int
+        number of tokens per gpu in a minibatch
+    rescale_loss : float
+        normalization constant
+    '''
     def __init__(self, grad_interval=1,
                  batch_size=1024,
                  rescale_loss=100):
@@ -154,10 +212,19 @@ def epoch_end(self, estimator, *args, **kwargs):
         if self.loss_denom > 0:
             self._update_gradient(estimator)

-"""TODO: merge this event handler with metricresethandler for language model
-"""
 class MTTransformerMetricHandler(MetricHandler, BatchBegin):
-    def __init__(self, *args, grad_interval=None, **kwargs):
+    '''Metric update handler for transformer training
+
+    Reset the local metric stats for every few iterations and include the LengthNormalizedLoss
+    for metrics update
+    TODO : Refactor this event handler and share it with other estimators
+
+    Parameters
+    ----------
+    grad_interval : int
+        interval of resetting local metrics during transformer training
+    '''
+    def __init__(self, *args, grad_interval=None, **kwargs):
         super(MTTransformerMetricHandler, self).__init__(*args, **kwargs)
         self.grad_interval = grad_interval
@@ -185,6 +252,22 @@ def batch_end(self, estimator, *args, **kwargs):
             metric.update(label, pred)

 class MTCheckpointHandler(CheckpointHandler, TrainEnd):
+    '''Checkpoint handler for machine translation tasks training
+
+    save model parameter checkpoint and average parameter checkpoint during transformer
+    or GNMT training
+
+    Parameters
+    ----------
+    average_checkpoint : bool
+        whether store the average parameters of last few iterations
+    num_averages : int
+        number of last few model checkpoints to be averaged
+    average_start : int
+        performing average sgd on last average_start epochs
+    epochs : int
+        total epochs of machine translation model training
+    '''
     def __init__(self, *args,
                  average_checkpoint=None,
                  num_averages=None,
@@ -219,7 +302,7 @@ def train_end(self, estimator, *args, **kwargs):
                 alpha = 1. / (j + 1)
                 for k, v in estimator.net._collect_params_with_prefix().items():
                     for c in ctx:
-                        v.data(c)[:] == alpha * (params[k].as_in_context(c) - v.data(c))
+                        v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c))
             save_path = os.path.join(self.model_dir,
                                      'average_checkpoint_{}.params'.format(self.num_averages))
             estimator.net.save_parameters(save_path)
@@ -232,9 +315,14 @@ def train_end(self, estimator, *args, **kwargs):
             estimator.net.load_parameters(os.path.join(self.model_dir,
                                                        'valid_best.params'), ctx)

-# A temporary workaround for computing the bleu function. After bleu is in the metric
-# api, this event handler could be removed.
+
 class ComputeBleuHandler(BatchEnd, EpochEnd):
+    '''Bleu score computation handler
+
+    This event handler serves as a temporary workaround for computing Bleu score for
+    estimator training.
+    TODO: please remove this event handler after bleu metrics is merged to api
+    '''
     def __init__(self,
                  tgt_vocab,
                  tgt_sentence,
@@ -265,7 +353,6 @@
     def batch_end(self, estimator, *args, **kwargs):
         ctx = estimator.context[0]
         batch = kwargs['batch']
-        label = kwargs['label']
         src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids = batch
         src_seq = src_seq.as_in_context(ctx)
         tgt_seq = tgt_seq.as_in_context(ctx)
@@ -280,7 +367,7 @@
         self.translation_out.append(
             [self.tgt_vocab.idx_to_token[ele] for ele in
              max_score_sample[i][1:(sample_valid_length[i] - 1)]])
-
+
     def epoch_end(self, estimator, *args, **kwargs):
         real_translation_out = [None for _ in range(len(self.all_inst_ids))]
         for ind, sentence in zip(self.all_inst_ids, self.translation_out):
             real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence))
         else:
             raise NotImplementedError
-        estimator.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.tgt_sentence],
-                                                                real_translation_out,
-                                                                tokenized=self.tokenized,
-                                                                tokenizer=self.tokenizer,
-                                                                split_compound_word=self.split_compound_word,
-                                                                bpe=self.bpe)
+        estimator.bleu_score, _, _, _, _ = \
+            self.compute_bleu_fn([self.tgt_sentence],
+                                 real_translation_out,
+                                 tokenized=self.tokenized,
+                                 tokenizer=self.tokenizer,
+                                 split_compound_word=self.split_compound_word,
+                                 bpe=self.bpe)
         print(estimator.bleu_score)
-
-# temporary validation bleu metric hack, it can be removed once bleu metric api is available
 class ValBleuHandler(EpochEnd):
+    '''Handler of validation Bleu score computation
+
+    This handler is similar to the ComputeBleuHandler. It computes the Bleu score on the
+    validation dataset
+    TODO: please remove this event handler after bleu metric is available in the api
+    '''
     def __init__(self, val_data,
                  val_tgt_vocab,
                  val_tgt_sentences,
@@ -352,15 +444,21 @@ def epoch_end(self, estimator, *args, **kwargs):
             real_translation_out[ind] = self.detokenizer(self._bpe_to_words(sentence))
         else:
             raise NotImplementedError
-        estimator.bleu_score, _, _, _, _ = self.compute_bleu_fn([self.val_tgt_sentences],
-                                                                real_translation_out,
-                                                                tokenized=self.tokenized,
-                                                                tokenizer=self.tokenizer,
-                                                                split_compound_word=self.split_compound_word,
-                                                                bpe=self.bpe)
+        estimator.bleu_score, _, _, _, _ = \
+            self.compute_bleu_fn([self.val_tgt_sentences],
+                                 real_translation_out,
+                                 tokenized=self.tokenized,
+                                 tokenizer=self.tokenizer,
+                                 split_compound_word=self.split_compound_word,
+                                 bpe=self.bpe)
         print(estimator.bleu_score)

 class MTTransformerLoggingHandler(LoggingHandler):
+    '''Logging handler for transformer training
+
+    Logging the training metrics for transformer training. This handler is introduced
+    because batches cannot be handled by the default LoggingHandler
+    '''
     def __init__(self, *args, **kwargs):
         super(MTTransformerLoggingHandler, self).__init__(*args, **kwargs)
@@ -372,7 +470,7 @@ def batch_end(self, estimator, *args, **kwargs):
         for batch in cur_batches:
             self.processed_samples += batch[0].shape[0]
         msg += '[Samples %s]' % (self.processed_samples)
-        self.log_interval_time += batch_time
+        self.log_interval_time += batch_time
         if self.batch_index % self.log_interval == 0:
             msg += 'time/interval: %.3fs ' % self.log_interval_time
             self.log_interval_time = 0

From b6faf295c78e808e6c2968272d03fd0016e290da Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Fri, 14 Feb 2020 15:24:33 +0000
Subject: [PATCH 20/26] fix typo

---
 src/gluonnlp/estimator/machine_translation_batch_processor.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/gluonnlp/estimator/machine_translation_batch_processor.py b/src/gluonnlp/estimator/machine_translation_batch_processor.py
index d02b07f3ed..7b2bdbcd50 100644
--- a/src/gluonnlp/estimator/machine_translation_batch_processor.py
+++ b/src/gluonnlp/estimator/machine_translation_batch_processor.py
@@ -101,7 +101,7 @@ class MTGNMTBatchProcessor(BatchProcessor):
     Batch training and validation on the GNMT network for the machine translation task.
     '''
     def __init__(self):
-        super(MTGNMTBatchProcess, self).__init__()
+        super(MTGNMTBatchProcessor, self).__init__()

     def fit_batch(self, estimator, train_batch, batch_axis=0):
         ctx = estimator.context[0]
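A brief aside on the checkpoint-averaging loop touched in patch 19: with `alpha = 1/(j + 1)`, the running mean has to be folded in with an in-place `+=`, following the recurrence avg_{j+1} = avg_j + (x_{j+1} - avg_j) / (j + 1). A tiny self-contained check of that recurrence:

.. code-block:: python

    import mxnet as mx

    avg = mx.nd.zeros((2,))
    checkpoints = [mx.nd.array([1., 1.]),
                   mx.nd.array([3., 3.]),
                   mx.nd.array([5., 5.])]
    for j, param in enumerate(checkpoints):
        alpha = 1. / (j + 1)
        avg[:] += alpha * (param - avg)   # incremental mean update
    print(avg.asnumpy())                  # [3. 3.], the mean of 1, 3, 5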
- ''' + """ def __init__(self): super(MTGNMTBatchProcessor, self).__init__() diff --git a/src/gluonnlp/estimator/machine_translation_estimator.py b/src/gluonnlp/estimator/machine_translation_estimator.py index eb824ad91a..334d6061dd 100644 --- a/src/gluonnlp/estimator/machine_translation_estimator.py +++ b/src/gluonnlp/estimator/machine_translation_estimator.py @@ -23,7 +23,7 @@ __all__ = ['MachineTranslationEstimator'] class MachineTranslationEstimator(Estimator): - '''Estimator class for machine translation tasks + """Estimator class for machine translation tasks Facilitates training and validation on machine translation tasks Parameters @@ -50,7 +50,7 @@ class MachineTranslationEstimator(Estimator): None, it will use the same loss function as self.loss batch_processor: BatchProcessor BatchProcessor provides customized fit_batch() and evaluate_batch() methods - ''' + """ def __init__(self, net, loss, train_metrics=None, val_metrics=None, diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index 8d6f95e1e4..fd51ac47a4 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -38,7 +38,7 @@ 'MTTransformerLoggingHandler'] class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd): - '''Transformer average parameter update handler + """Transformer average parameter update handler Update weighted average parameters of the transformer during training @@ -48,7 +48,7 @@ class MTTransformerParamUpdateHandler(EpochBegin, BatchEnd, EpochEnd): the starting epoch of performing average sgd update grad_interval : int The interval of update avarege model parameters - ''' + """ def __init__(self, avg_start, grad_interval=1): self.batch_id = 0 self.grad_interval = grad_interval @@ -79,7 +79,7 @@ def epoch_end(self, estimator, *args, **kwargs): self._update_avg_param(estimator) class MTGNMTLearningRateHandler(EpochEnd): - '''GNMT learning rate update handler + """GNMT learning rate update handler dynamically adjust the learning rate during GNMT training @@ -89,7 +89,7 @@ class MTGNMTLearningRateHandler(EpochEnd): total number of epoches for GNMT training lr_update_factor : float the decaying factor of learning rate - ''' + """ def __init__(self, epochs, lr_update_factor): self.epoch_id = 0 self.epochs = epochs @@ -102,7 +102,7 @@ def epoch_end(self, estimator, *args, **kwargs): self.epoch_id += 1 class TransformerLearningRateHandler(EpochBegin, BatchBegin): - '''Transformer learning rate update handler + """Transformer learning rate update handler dynamically adjust the learning rate during transformer training @@ -116,7 +116,7 @@ class TransformerLearningRateHandler(EpochBegin, BatchBegin): number of warmup steps used in training schedule grad_interval : int the interval of updating learning rate - ''' + """ def __init__(self, lr, num_units=512, warmup_steps=4000, @@ -140,7 +140,7 @@ def batch_begin(self, estimator, *args, **kwargs): self.batch_id += 1 class MTGNMTGradientUpdateHandler(GradientUpdateHandler): - '''Gradient update handler of GNMT training + """Gradient update handler of GNMT training clip gradient if gradient norm exceeds some threshold during GNMT training @@ -149,7 +149,7 @@ class MTGNMTGradientUpdateHandler(GradientUpdateHandler): clip : float gradient norm threshold. If gradient norm exceeds this value, it should be scaled down to the valid range. 
- ''' + """ def __init__(self, clip): super(MTGNMTGradientUpdateHandler, self).__init__() self.clip = clip @@ -164,9 +164,9 @@ class TransformerGradientAccumulationHandler(GradientUpdateHandler, TrainBegin, EpochBegin, EpochEnd): - '''Gradient accumulation handler for transformer training + """Gradient accumulation handler for transformer training - Accumulates gradients of the network for a few iterations, and updates + Accumulates gradients of the network for a few iterations, and updates network parameters with the accumulated gradients Parameters @@ -177,7 +177,7 @@ class TransformerGradientAccumulationHandler(GradientUpdateHandler, number of tokens per gpu in a minibatch rescale_loss : float normalization constant - ''' + """ def __init__(self, grad_interval=1, batch_size=1024, rescale_loss=100): @@ -213,7 +213,7 @@ def epoch_end(self, estimator, *args, **kwargs): self._update_gradient(estimator) class MTTransformerMetricHandler(MetricHandler, BatchBegin): - '''Metric update handler for transformer training + """Metric update handler for transformer training Reset the local metric stats for every few iterations and include the LengthNormalizedLoss for metrics update @@ -223,7 +223,7 @@ class MTTransformerMetricHandler(MetricHandler, BatchBegin): ---------- grad_interval : int interval of resetting local metrics during transformer training - ''' + """ def __init__(self, *args, grad_interval=None, **kwargs): super(MTTransformerMetricHandler, self).__init__(*args, **kwargs) self.grad_interval = grad_interval @@ -252,7 +252,7 @@ def batch_end(self, estimator, *args, **kwargs): metric.update(label, pred) class MTCheckpointHandler(CheckpointHandler, TrainEnd): - '''Checkpoint handler for machine translation tasks training + """Checkpoint handler for machine translation tasks training save model parameter checkpoint and average parameter checkpoint during transformer or GNMT training @@ -267,7 +267,7 @@ class MTCheckpointHandler(CheckpointHandler, TrainEnd): performing average sgd on last average_start epochs epochs : int total epochs of machine translation model training - ''' + """ def __init__(self, *args, average_checkpoint=None, num_averages=None, @@ -317,12 +317,12 @@ def train_end(self, estimator, *args, **kwargs): class ComputeBleuHandler(BatchEnd, EpochEnd): - '''Bleu score computation handler + """Bleu score computation handler this event handler serves as a temporary workaround for computing Bleu score for estimator training. TODO: please remove this event handler after bleu metrics is merged to api - ''' + """ def __init__(self, tgt_vocab, tgt_sentence, @@ -387,12 +387,12 @@ def epoch_end(self, estimator, *args, **kwargs): print(estimator.bleu_score) class ValBleuHandler(EpochEnd): - '''Handler of validation Bleu score computation + """Handler of validation Bleu score computation - This handler is similar to the ComputeBleuHandler. It computes the Bleu score on the + This handler is similar to the ComputeBleuHandler. It computes the Bleu score on the validation dataset TODO: please remove this event handler after bleu metric is available in the api - ''' + """ def __init__(self, val_data, val_tgt_vocab, val_tgt_sentences, @@ -454,11 +454,11 @@ def epoch_end(self, estimator, *args, **kwargs): print(estimator.bleu_score) class MTTransformerLoggingHandler(LoggingHandler): - '''Logging handler for transformer training + """Logging handler for transformer training - Logging the training metrics for transformer training. 
This handler is introduced - due to batch cannot be handled by default LoggingHandler - ''' + Logging the training metrics for transformer training. This handler is introduced + due to batch cannot be handled by default LoggingHandler + """ def __init__(self, *args, **kwargs): super(MTTransformerLoggingHandler, self).__init__(*args, **kwargs) From cc9665cc3079ac2c1d44288f1259478c6487dd44 Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 14 Feb 2020 17:05:31 +0000 Subject: [PATCH 22/26] fix typo --- src/gluonnlp/estimator/machine_translation_event_handler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/gluonnlp/estimator/machine_translation_event_handler.py b/src/gluonnlp/estimator/machine_translation_event_handler.py index fd51ac47a4..8885f066c9 100644 --- a/src/gluonnlp/estimator/machine_translation_event_handler.py +++ b/src/gluonnlp/estimator/machine_translation_event_handler.py @@ -296,7 +296,7 @@ def train_end(self, estimator, *args, **kwargs): save_path = os.path.join(self.model_dir, 'average.params') mx.nd.save(save_path, estimator.avg_param) if self.average_checkpoint: - for j in range(args.num_averages): + for j in range(self.num_averages): params = mx.nd.load(os.path.join(self.model_dir, 'epoch{:d}.params'.format(self.epochs - j - 1))) alpha = 1. / (j + 1) From d7b51fac0835af3103a2a1489783ef423ca0804b Mon Sep 17 00:00:00 2001 From: Zhuanghua Liu Date: Fri, 14 Feb 2020 17:16:33 +0000 Subject: [PATCH 23/26] fix init file lint error --- src/gluonnlp/estimator/__init__.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py index 3dca823d2c..1672dff82b 100644 --- a/src/gluonnlp/estimator/__init__.py +++ b/src/gluonnlp/estimator/__init__.py @@ -17,6 +17,9 @@ # pylint: disable=eval-used, redefined-outer-name """ Gluon NLP Estimator Module """ +from . import machine_translation_estimator, machine_translation_event_handler +from . 
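The `args.num_averages` fix above illustrates a general constraint on these handlers: they execute inside `Estimator.fit`, where the script-level `args` namespace is out of scope, so any configuration a handler needs must be captured on `self` at construction time. A minimal illustration of the pattern (class and names hypothetical):

.. code-block:: python

    class AveragingConfig:
        """Illustrative only; mirrors how MTCheckpointHandler stores its config."""

        def __init__(self, num_averages, epochs):
            self.num_averages = num_averages   # captured, not read from globals
            self.epochs = epochs

        def checkpoints_to_average(self):
            # The last `num_averages` epoch checkpoints, as in the loop above.
            return ['epoch{:d}.params'.format(self.epochs - j - 1)
                    for j in range(self.num_averages)]

    print(AveragingConfig(3, 30).checkpoints_to_average())
    # ['epoch29.params', 'epoch28.params', 'epoch27.params']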
From d7b51fac0835af3103a2a1489783ef423ca0804b Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Fri, 14 Feb 2020 17:16:33 +0000
Subject: [PATCH 23/26] fix init file lint error

---
 src/gluonnlp/estimator/__init__.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/gluonnlp/estimator/__init__.py b/src/gluonnlp/estimator/__init__.py
index 3dca823d2c..1672dff82b 100644
--- a/src/gluonnlp/estimator/__init__.py
+++ b/src/gluonnlp/estimator/__init__.py
@@ -17,6 +17,9 @@
 # pylint: disable=eval-used, redefined-outer-name

 """ Gluon NLP Estimator Module """
+from . import machine_translation_estimator, machine_translation_event_handler
+from . import machine_translation_batch_processor
+
 from .machine_translation_estimator import *
 from .machine_translation_event_handler import *
 from .machine_translation_batch_processor import *
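The explicit `from . import ...` lines added above matter because the `__all__` aggregation at the bottom of this file refers to the submodules by name; a star import alone does not reliably bind `machine_translation_estimator` and friends in the package namespace, which is what the lint error flagged. A quick sanity check of the aggregated exports, assuming the patched package is installed:

.. code-block:: python

    import gluonnlp.estimator as est

    # Every public name from the three submodules should be importable from
    # the package root; a few names from this series as a spot check.
    expected = {'MachineTranslationEstimator',
                'MTTransformerBatchProcessor',
                'MTGNMTBatchProcessor',
                'MTCheckpointHandler'}
    assert expected.issubset(set(est.__all__))
    print(len(est.__all__), 'names exported from gluonnlp.estimator')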
From 448934b45b7ff601ee64834fda02ddeb024cfa6c Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Mon, 17 Feb 2020 05:03:29 +0000
Subject: [PATCH 24/26] resolve import lint errors

---
 .../train_gnmt_estimator.py        | 22 +++++++++----------
 .../train_transformer_estimator.py |  7 ++----
 2 files changed, 13 insertions(+), 16 deletions(-)

diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py
index f7c9629ca6..766319034d 100644
--- a/scripts/machine_translation/train_gnmt_estimator.py
+++ b/scripts/machine_translation/train_gnmt_estimator.py
@@ -34,29 +34,29 @@
 # pylint:disable=redefined-outer-name,logging-format-interpolation

 import argparse
-import time
 import random
 import os
 import logging

 import numpy as np
 import mxnet as mx
 from mxnet import gluon
-import gluonnlp as nlp
+from mxnet.gluon.contrib.estimator import LoggingHandler, ValidationHandler
+
+import gluonnlp as nlp
 from gluonnlp.model.translation import NMTModel
 from gluonnlp.loss import MaskedSoftmaxCELoss
-from gnmt import get_gnmt_encoder_decoder
-from translation import BeamSearchTranslator
-from utils import logging_config
-from bleu import compute_bleu
-import dataprocessor
 from gluonnlp.metric import LengthNormalizedLoss
 from gluonnlp.estimator import MachineTranslationEstimator
 from gluonnlp.estimator import MTGNMTBatchProcessor, MTGNMTGradientUpdateHandler
 from gluonnlp.estimator import ComputeBleuHandler, ValBleuHandler
 from gluonnlp.estimator import MTTransformerMetricHandler, MTGNMTLearningRateHandler
-from gluonnlp.estimator import MTCheckpointHandler, MTTransformerMetricHandler
-from mxnet.gluon.contrib.estimator import LoggingHandler, ValidationHandler
+from gluonnlp.estimator import MTCheckpointHandler
+
+from gnmt import get_gnmt_encoder_decoder
+from translation import BeamSearchTranslator
+from utils import logging_config
+from bleu import compute_bleu
+import dataprocessor

 np.random.seed(100)
 random.seed(100)
@@ -190,8 +190,8 @@

 val_metric_handler = MTTransformerMetricHandler(metrics=gnmt_estimator.val_metrics)

 val_validation_handler = ValidationHandler(val_data=val_data_loader,
-                                           eval_fn=gnmt_estimator.evaluate,
-                                           event_handlers=val_metric_handler)
+                                           eval_fn=gnmt_estimator.evaluate,
+                                           event_handlers=val_metric_handler)

 logging_handler = LoggingHandler(log_interval=args.log_interval,
                                  metrics=gnmt_estimator.train_metrics)
diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
index 16216f315f..b46abba85e 100644
--- a/scripts/machine_translation/train_transformer_estimator.py
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -34,20 +34,18 @@

 import argparse
 import logging
-import math
 import os
 import random
-import time

 import numpy as np
 import mxnet as mx
 from mxnet import gluon
-import gluonnlp as nlp
+from mxnet.gluon.contrib.estimator import ValidationHandler
+
+import gluonnlp as nlp
 from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss
 from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder
 from gluonnlp.model.translation import NMTModel
-from gluonnlp.utils.parallel import Parallel
 import dataprocessor
 from bleu import _bpe_to_words, compute_bleu
 from translation import BeamSearchTranslator
 from utils import logging_config
@@ -59,7 +57,6 @@
 from gluonnlp.estimator import TransformerGradientAccumulationHandler, ComputeBleuHandler
 from gluonnlp.estimator import ValBleuHandler, MTCheckpointHandler
 from gluonnlp.estimator import MTTransformerLoggingHandler
-from mxnet.gluon.contrib.estimator import ValidationHandler

 np.random.seed(100)
 random.seed(100)

From 7faf836b6acc43a827d1a6c8fcddf98b43f07a76 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Mon, 17 Feb 2020 05:34:43 +0000
Subject: [PATCH 25/26] refine imports

---
 scripts/machine_translation/train_transformer_estimator.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
index b46abba85e..ea3d1c3570 100644
--- a/scripts/machine_translation/train_transformer_estimator.py
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -46,10 +46,6 @@
 from gluonnlp.loss import LabelSmoothing, MaskedSoftmaxCELoss
 from gluonnlp.model.transformer import ParallelTransformer, get_transformer_encoder_decoder
 from gluonnlp.model.translation import NMTModel
-import dataprocessor
-from bleu import _bpe_to_words, compute_bleu
-from translation import BeamSearchTranslator
-from utils import logging_config
 from gluonnlp.metric import LengthNormalizedLoss
 from gluonnlp.estimator import MachineTranslationEstimator
 from gluonnlp.estimator import MTTransformerBatchProcessor, MTTransformerParamUpdateHandler
@@ -58,6 +54,11 @@
 from gluonnlp.estimator import ValBleuHandler, MTCheckpointHandler
 from gluonnlp.estimator import MTTransformerLoggingHandler

+import dataprocessor
+from bleu import _bpe_to_words, compute_bleu
+from translation import BeamSearchTranslator
+from utils import logging_config
+
 np.random.seed(100)
 random.seed(100)
 mx.random.seed(10000)
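Patches 24 and 25 converge on the grouping pylint's wrong-import-order check expects: standard library first, then third-party, then the first-party package, then local script modules, with a blank line between groups. Sketched below for the transformer script; the local imports only resolve inside scripts/machine_translation, where dataprocessor and utils live.

.. code-block:: python

    import argparse                    # standard library
    import logging

    import numpy as np                 # third-party
    import mxnet as mx
    from mxnet.gluon.contrib.estimator import ValidationHandler

    import gluonnlp as nlp             # first-party package
    from gluonnlp.estimator import MachineTranslationEstimator

    import dataprocessor               # local script modules
    from utils import logging_config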
From eab2ef7ba4dbf4c15367ce775e402e1dc9579e99 Mon Sep 17 00:00:00 2001
From: Zhuanghua Liu
Date: Mon, 17 Feb 2020 09:00:42 +0000
Subject: [PATCH 26/26] disable pylint errors

---
 scripts/machine_translation/train_gnmt_estimator.py        | 3 +--
 scripts/machine_translation/train_transformer_estimator.py | 2 +-
 2 files changed, 2 insertions(+), 3 deletions(-)

diff --git a/scripts/machine_translation/train_gnmt_estimator.py b/scripts/machine_translation/train_gnmt_estimator.py
index 766319034d..8a4f152bb0 100644
--- a/scripts/machine_translation/train_gnmt_estimator.py
+++ b/scripts/machine_translation/train_gnmt_estimator.py
@@ -31,7 +31,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint:disable=redefined-outer-name,logging-format-interpolation
+# pylint:disable=redefined-outer-name,logging-format-interpolation,unexpected-keyword-arg

 import argparse
 import random
@@ -202,7 +202,6 @@
 gnmt_estimator.fit(train_data=train_data_loader,
                    val_data=val_data_loader,
                    epochs=args.epochs,
-                   #batches=5,
                    event_handlers=event_handlers,
                    batch_axis=0)

diff --git a/scripts/machine_translation/train_transformer_estimator.py b/scripts/machine_translation/train_transformer_estimator.py
index ea3d1c3570..ebea489773 100644
--- a/scripts/machine_translation/train_transformer_estimator.py
+++ b/scripts/machine_translation/train_transformer_estimator.py
@@ -30,7 +30,7 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint:disable=redefined-outer-name,logging-format-interpolation
+# pylint:disable=redefined-outer-name,logging-format-interpolation,unexpected-keyword-arg

 import argparse
 import logging