From 3a3f2c8836995774e7f9810a92b1d394e8210e4e Mon Sep 17 00:00:00 2001 From: Xie Date: Tue, 28 May 2019 16:28:55 -0700 Subject: [PATCH 1/5] amp --- .../train_transformer_amp.py | 435 ++++++++++++++++++ src/gluonnlp/model/transformer.py | 9 +- 2 files changed, 443 insertions(+), 1 deletion(-) create mode 100644 scripts/machine_translation/train_transformer_amp.py diff --git a/scripts/machine_translation/train_transformer_amp.py b/scripts/machine_translation/train_transformer_amp.py new file mode 100644 index 0000000000..39e5b636d0 --- /dev/null +++ b/scripts/machine_translation/train_transformer_amp.py @@ -0,0 +1,435 @@ +""" +Transformer +================================= + +This example shows how to implement the Transformer model with Gluon NLP Toolkit. + +@inproceedings{vaswani2017attention, + title={Attention is all you need}, + author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, + Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia}, + booktitle={Advances in Neural Information Processing Systems}, + pages={6000--6010}, + year={2017} +} +""" + +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint:disable=redefined-outer-name,logging-format-interpolation + +import argparse +import time +import random +import os +import logging +import math +import numpy as np +import mxnet as mx +from mxnet import gluon +import gluonnlp as nlp + +from gluonnlp.model.translation import NMTModel +from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer +from gluonnlp.utils.parallel import Parallel +from translation import BeamSearchTranslator +from loss import SoftmaxCEMaskedLoss, LabelSmoothing +from utils import logging_config +from bleu import _bpe_to_words, compute_bleu +import dataprocessor + +from mxnet.contrib import amp + + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + +parser = argparse.ArgumentParser(description='Neural Machine Translation Example.' 
+ 'We train the Transformer Model') +parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.') +parser.add_argument('--src_lang', type=str, default='en', help='Source language') +parser.add_argument('--tgt_lang', type=str, default='de', help='Target language') +parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit') +parser.add_argument('--num_units', type=int, default=512, help='Dimension of the embedding ' + 'vectors and states.') +parser.add_argument('--hidden_size', type=int, default=2048, + help='Dimension of the hidden state in position-wise feed-forward networks.') +parser.add_argument('--dropout', type=float, default=0.1, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--epsilon', type=float, default=0.1, + help='epsilon parameter for label smoothing') +parser.add_argument('--num_layers', type=int, default=6, + help='number of layers in the encoder and decoder') +parser.add_argument('--num_heads', type=int, default=8, + help='number of heads in multi-head attention') +parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention') +parser.add_argument('--batch_size', type=int, default=1024, + help='Batch size. Number of tokens per gpu in a minibatch') +parser.add_argument('--beam_size', type=int, default=4, help='Beam size') +parser.add_argument('--lp_alpha', type=float, default=0.6, + help='Alpha used in calculating the length penalty') +parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty') +parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size') +parser.add_argument('--num_buckets', type=int, default=10, help='Bucket number') +parser.add_argument('--bucket_scheme', type=str, default='constant', + help='Strategy for generating bucket keys. It supports: ' + '"constant": all the buckets have the same width; ' + '"linear": the width of bucket increases linearly; ' + '"exp": the width of bucket increases exponentially') +parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the ' + 'throughput of the bucketing') +parser.add_argument('--src_max_len', type=int, default=-1, help='Maximum length of the source ' + 'sentence, -1 means no clipping') +parser.add_argument('--tgt_max_len', type=int, default=-1, help='Maximum length of the target ' + 'sentence, -1 means no clipping') +parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') +parser.add_argument('--local_sgd', type=int, default=0, help='the number of local iterations of local SGD') +parser.add_argument('--local_sgd_regularization', type=float, default=0, help='the regularization weight of local SGD') +parser.add_argument('--local_sgd_regularization_interval', type=int, default=0, help='the interval of regularization of local SGD') +parser.add_argument('--lr', type=float, default=1.0, help='Initial learning rate') +parser.add_argument('--warmup_steps', type=float, default=4000, + help='number of warmup steps used in NOAM\'s stepsize schedule') +parser.add_argument('--num_accumulated', type=int, default=1, + help='Number of steps to accumulate the gradients. 
' + 'This is useful to mimic large batch training with limited gpu memory') +parser.add_argument('--magnitude', type=float, default=3.0, + help='Magnitude of Xavier initialization') +parser.add_argument('--average_checkpoint', action='store_true', + help='Turn on to perform final testing based on ' + 'the average of last few checkpoints') +parser.add_argument('--num_averages', type=int, default=5, + help='Perform final testing based on the ' + 'average of last num_averages checkpoints. ' + 'This is only used if average_checkpoint is True') +parser.add_argument('--average_start', type=int, default=5, + help='Perform average SGD on last average_start epochs') +parser.add_argument('--full', action='store_true', + help='In default, we use the test dataset in' + ' http://statmt.org/wmt14/test-filtered.tgz.' + ' When the option full is turned on, we use the test dataset in' + ' http://statmt.org/wmt14/test-full.tgz') +parser.add_argument('--bleu', type=str, default='tweaked', + help='Schemes for computing bleu score. It can be: ' + '"tweaked": it uses similar steps in get_ende_bleu.sh in tensor2tensor ' + 'repository, where compound words are put in ATAT format; ' + '"13a": This uses official WMT tokenization and produces the same results' + ' as official script (mteval-v13a.pl) used by WMT; ' + '"intl": This use international tokenization in mteval-v14a.pl') +parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') +parser.add_argument('--save_dir', type=str, default='transformer_out', + help='directory path to save the final model and training log') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.' + '(using single gpu is suggested)') +args = parser.parse_args() +logging_config(args.save_dir) +logging.info(args) + +amp.init() + +data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab \ + = dataprocessor.load_translation_data(dataset=args.dataset, bleu=args.bleu, args=args) + +dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt')) +dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt')) + +data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False) +data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_val)]) +data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_test)]) + +ctx = [mx.cpu()] if args.gpus is None or args.gpus == '' else \ + [mx.gpu(int(x)) for x in args.gpus.split(',')] +num_ctxs = len(ctx) + +data_train_lengths, data_val_lengths, data_test_lengths = [dataprocessor.get_data_lengths(x) + for x in + [data_train, data_val, data_test]] + +if args.src_max_len <= 0 or args.tgt_max_len <= 0: + max_len = np.max( + [np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0), + np.max(data_test_lengths, axis=0)], + axis=0) +if args.src_max_len > 0: + src_max_len = args.src_max_len +else: + src_max_len = max_len[0] +if args.tgt_max_len > 0: + tgt_max_len = args.tgt_max_len +else: + tgt_max_len = max_len[1] +encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, + hidden_size=args.hidden_size, + dropout=args.dropout, + num_layers=args.num_layers, + num_heads=args.num_heads, + max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), + scaled=args.scaled) +model = 
NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + share_embed=args.dataset != 'TOY', embed_size=args.num_units, + tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_') +model.initialize(init=mx.init.Xavier(magnitude=args.magnitude), ctx=ctx) +static_alloc = True +model.hybridize(static_alloc=static_alloc) +# logging.info(model) + +translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, + scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, + K=args.lp_k), + max_length=200) +logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) + +label_smoothing = LabelSmoothing(epsilon=args.epsilon, units=len(tgt_vocab)) +label_smoothing.hybridize(static_alloc=static_alloc) + +loss_function = SoftmaxCEMaskedLoss(sparse_label=False) +loss_function.hybridize(static_alloc=static_alloc) + +test_loss_function = SoftmaxCEMaskedLoss() +test_loss_function.hybridize(static_alloc=static_alloc) + +trainer = gluon.Trainer(model.collect_params(), args.optimizer, + {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9}, local_sgd=args.local_sgd, local_sgd_regularization=args.local_sgd_regularization, local_sgd_regularization_interval=args.local_sgd_regularization_interval) + +rescale_loss = 100 +parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss) +# parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss, amp=amp, trainer=trainer) +detokenizer = nlp.data.SacreMosesDetokenizer() + + +def evaluate(data_loader, context=ctx[0]): + """Evaluate given the data loader + + Parameters + ---------- + data_loader : DataLoader + + Returns + ------- + avg_loss : float + Average loss + real_translation_out : list of list of str + The translation output + """ + translation_out = [] + all_inst_ids = [] + avg_loss_denom = 0 + avg_loss = 0.0 + for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \ + in enumerate(data_loader): + src_seq = src_seq.as_in_context(context) + tgt_seq = tgt_seq.as_in_context(context) + src_valid_length = src_valid_length.as_in_context(context) + tgt_valid_length = tgt_valid_length.as_in_context(context) + # Calculating Loss + out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) + loss = test_loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar() + all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist()) + avg_loss += loss * (tgt_seq.shape[1] - 1) + avg_loss_denom += (tgt_seq.shape[1] - 1) + # Translate + samples, _, sample_valid_length = \ + translator.translate(src_seq=src_seq, src_valid_length=src_valid_length) + max_score_sample = samples[:, 0, :].asnumpy() + sample_valid_length = sample_valid_length[:, 0].asnumpy() + for i in range(max_score_sample.shape[0]): + translation_out.append( + [tgt_vocab.idx_to_token[ele] for ele in + max_score_sample[i][1:(sample_valid_length[i] - 1)]]) + avg_loss = avg_loss / avg_loss_denom + real_translation_out = [None for _ in range(len(all_inst_ids))] + for ind, sentence in zip(all_inst_ids, translation_out): + if args.bleu == 'tweaked': + real_translation_out[ind] = sentence + elif args.bleu == '13a' or args.bleu == 'intl': + real_translation_out[ind] = detokenizer(_bpe_to_words(sentence), + return_str=True) + else: + raise NotImplementedError + return avg_loss, real_translation_out + + +def train(): + """Training function.""" + + train_data_loader, val_data_loader, test_data_loader \ + = 
dataprocessor.make_dataloader(data_train, data_val, data_test, args, + use_average_length=True, num_shards=len(ctx)) + + if args.bleu == 'tweaked': + bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY') + split_compound_word = bpe + tokenized = True + elif args.bleu == '13a' or args.bleu == 'intl': + bpe = False + split_compound_word = False + tokenized = False + else: + raise NotImplementedError + + best_valid_bleu = 0.0 + step_num = 0 + warmup_steps = args.warmup_steps + grad_interval = args.num_accumulated + model.collect_params().setattr('grad_req', 'add') + average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start) + average_param_dict = None + model.collect_params().zero_grad() + parallel = Parallel(num_ctxs, parallel_model) + for epoch_id in range(args.epochs): + log_avg_loss = 0 + log_wc = 0 + loss_denom = 0 + step_loss = 0 + log_start_time = time.time() + epoch_start_time = time.time() + for batch_id, seqs \ + in enumerate(train_data_loader): + + # if epoch_id == 0 and batch_id == 100: + # # amp + # mx.nd.waitall() + # logging.info('[Epoch {} Batch {}/{}] Activate amp'.format(epoch_id, batch_id + 1)) + # amp.init_trainer(trainer) + # parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss, amp=amp, trainer=trainer) + # parallel = Parallel(num_ctxs, parallel_model) + + if batch_id % grad_interval == 0: + step_num += 1 + new_lr = args.lr / math.sqrt(args.num_units) \ + * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5)) + trainer.set_learning_rate(new_lr) + src_wc, tgt_wc, bs = np.sum([(shard[2].sum(), shard[3].sum(), shard[0].shape[0]) + for shard in seqs], axis=0) + seqs = [[seq.as_in_context(context) for seq in shard] + for context, shard in zip(ctx, seqs)] + Ls = [] + for seq in seqs: + parallel.put((seq, args.batch_size)) + Ls = [parallel.get() for _ in range(len(ctx))] + src_wc = src_wc.asscalar() + tgt_wc = tgt_wc.asscalar() + loss_denom += tgt_wc - bs + if batch_id % grad_interval == grad_interval - 1 or\ + batch_id == len(train_data_loader) - 1: + if average_param_dict is None: + average_param_dict = {k: v.data(ctx[0]).copy() for k, v in + model.collect_params().items()} + trainer.step(float(loss_denom) / args.batch_size / 100.0) + param_dict = model.collect_params() + param_dict.zero_grad() + if step_num > average_start: + alpha = 1. 
/ max(1, step_num - average_start) + for name, average_param in average_param_dict.items(): + average_param[:] += alpha * (param_dict[name].data(ctx[0]) - average_param) + step_loss += sum([L.asscalar() for L in Ls]) + if batch_id % grad_interval == grad_interval - 1 or\ + batch_id == len(train_data_loader) - 1: + log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0 + loss_denom = 0 + step_loss = 0 + log_wc += src_wc + tgt_wc + if (batch_id + 1) % (args.log_interval * grad_interval) == 0: + wps = log_wc / (time.time() - log_start_time) + logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' + 'throughput={:.2f}K wps, wc={:.2f}K' + .format(epoch_id, batch_id + 1, len(train_data_loader), + log_avg_loss / args.log_interval, + np.exp(log_avg_loss / args.log_interval), + wps / 1000, log_wc / 1000)) + log_start_time = time.time() + log_avg_loss = 0 + log_wc = 0 + if args.local_sgd > 1: + # synchronous model parameters for local sgd + trainer.allreduce_params() + mx.nd.waitall() + logging.info('[Epoch {}] time={:.2f}s'.format(epoch_id, time.time()-epoch_start_time)) + valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) + valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe) + logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' + .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) + test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) + test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe) + logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}' + .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100)) + dataprocessor.write_sentences(valid_translation_out, + os.path.join(args.save_dir, + 'epoch{:d}_valid_out.txt').format(epoch_id)) + dataprocessor.write_sentences(test_translation_out, + os.path.join(args.save_dir, + 'epoch{:d}_test_out.txt').format(epoch_id)) + if valid_bleu_score > best_valid_bleu: + best_valid_bleu = valid_bleu_score + save_path = os.path.join(args.save_dir, 'valid_best.params') + logging.info('Save best parameters to {}'.format(save_path)) + model.save_parameters(save_path) + save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)) + model.save_parameters(save_path) + save_path = os.path.join(args.save_dir, 'average.params') + mx.nd.save(save_path, average_param_dict) + if args.average_checkpoint: + for j in range(args.num_averages): + params = mx.nd.load(os.path.join(args.save_dir, + 'epoch{:d}.params'.format(args.epochs - j - 1))) + alpha = 1. 
/ (j + 1) + for k, v in model._collect_params_with_prefix().items(): + for c in ctx: + v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) + save_path = os.path.join(args.save_dir, + 'average_checkpoint_{}.params'.format(args.num_averages)) + model.save_parameters(save_path) + elif args.average_start > 0: + for k, v in model.collect_params().items(): + v.set_data(average_param_dict[k]) + save_path = os.path.join(args.save_dir, 'average.params') + model.save_parameters(save_path) + else: + model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx) + valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) + valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, + tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, + split_compound_word=split_compound_word) + logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' + .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) + test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) + test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, + tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, + split_compound_word=split_compound_word) + logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}' + .format(test_loss, np.exp(test_loss), test_bleu_score * 100)) + dataprocessor.write_sentences(valid_translation_out, + os.path.join(args.save_dir, 'best_valid_out.txt')) + dataprocessor.write_sentences(test_translation_out, + os.path.join(args.save_dir, 'best_test_out.txt')) + + +if __name__ == '__main__': + train() diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py index 75980440b2..12a83a63c3 100644 --- a/src/gluonnlp/model/transformer.py +++ b/src/gluonnlp/model/transformer.py @@ -1265,11 +1265,13 @@ class ParallelTransformer(Parallelizable): rescale_loss : float The scale to which the loss is rescaled to avoid gradient explosion. 
""" - def __init__(self, model, label_smoothing, loss_function, rescale_loss): + def __init__(self, model, label_smoothing, loss_function, rescale_loss, amp=None, trainer=None): self._model = model self._label_smoothing = label_smoothing self._loss = loss_function self._rescale_loss = rescale_loss + self._amp = amp + self._trainer = trainer def forward_backward(self, x): """Perform forward and backward computation for a batch of src seq and dst seq""" @@ -1280,5 +1282,10 @@ def forward_backward(self, x): smoothed_label = self._label_smoothing(tgt_seq[:, 1:]) ls = self._loss(out, smoothed_label, tgt_valid_length - 1).sum() ls = (ls * (tgt_seq.shape[1] - 1)) / batch_size / self._rescale_loss + if self._amp is not None: + with self._amp.scale_loss(ls, self._trainer) as scaled_loss: + mx.autograd.backward(scaled_loss) + return scaled_loss ls.backward() return ls + From 8be9c8bd523e1ed62976c7c187e6f26fb181c280 Mon Sep 17 00:00:00 2001 From: Xie Date: Wed, 7 Aug 2019 14:38:00 -0700 Subject: [PATCH 2/5] revert to the original version --- .../train_transformer_amp.py | 435 ------------------ src/gluonnlp/model/transformer.py | 9 +- 2 files changed, 1 insertion(+), 443 deletions(-) delete mode 100644 scripts/machine_translation/train_transformer_amp.py diff --git a/scripts/machine_translation/train_transformer_amp.py b/scripts/machine_translation/train_transformer_amp.py deleted file mode 100644 index 39e5b636d0..0000000000 --- a/scripts/machine_translation/train_transformer_amp.py +++ /dev/null @@ -1,435 +0,0 @@ -""" -Transformer -================================= - -This example shows how to implement the Transformer model with Gluon NLP Toolkit. - -@inproceedings{vaswani2017attention, - title={Attention is all you need}, - author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, - Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia}, - booktitle={Advances in Neural Information Processing Systems}, - pages={6000--6010}, - year={2017} -} -""" - -# coding: utf-8 - -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. 
-# pylint:disable=redefined-outer-name,logging-format-interpolation - -import argparse -import time -import random -import os -import logging -import math -import numpy as np -import mxnet as mx -from mxnet import gluon -import gluonnlp as nlp - -from gluonnlp.model.translation import NMTModel -from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer -from gluonnlp.utils.parallel import Parallel -from translation import BeamSearchTranslator -from loss import SoftmaxCEMaskedLoss, LabelSmoothing -from utils import logging_config -from bleu import _bpe_to_words, compute_bleu -import dataprocessor - -from mxnet.contrib import amp - - -np.random.seed(100) -random.seed(100) -mx.random.seed(10000) - -parser = argparse.ArgumentParser(description='Neural Machine Translation Example.' - 'We train the Transformer Model') -parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.') -parser.add_argument('--src_lang', type=str, default='en', help='Source language') -parser.add_argument('--tgt_lang', type=str, default='de', help='Target language') -parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit') -parser.add_argument('--num_units', type=int, default=512, help='Dimension of the embedding ' - 'vectors and states.') -parser.add_argument('--hidden_size', type=int, default=2048, - help='Dimension of the hidden state in position-wise feed-forward networks.') -parser.add_argument('--dropout', type=float, default=0.1, - help='dropout applied to layers (0 = no dropout)') -parser.add_argument('--epsilon', type=float, default=0.1, - help='epsilon parameter for label smoothing') -parser.add_argument('--num_layers', type=int, default=6, - help='number of layers in the encoder and decoder') -parser.add_argument('--num_heads', type=int, default=8, - help='number of heads in multi-head attention') -parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention') -parser.add_argument('--batch_size', type=int, default=1024, - help='Batch size. Number of tokens per gpu in a minibatch') -parser.add_argument('--beam_size', type=int, default=4, help='Beam size') -parser.add_argument('--lp_alpha', type=float, default=0.6, - help='Alpha used in calculating the length penalty') -parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty') -parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size') -parser.add_argument('--num_buckets', type=int, default=10, help='Bucket number') -parser.add_argument('--bucket_scheme', type=str, default='constant', - help='Strategy for generating bucket keys. 
It supports: ' - '"constant": all the buckets have the same width; ' - '"linear": the width of bucket increases linearly; ' - '"exp": the width of bucket increases exponentially') -parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the ' - 'throughput of the bucketing') -parser.add_argument('--src_max_len', type=int, default=-1, help='Maximum length of the source ' - 'sentence, -1 means no clipping') -parser.add_argument('--tgt_max_len', type=int, default=-1, help='Maximum length of the target ' - 'sentence, -1 means no clipping') -parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') -parser.add_argument('--local_sgd', type=int, default=0, help='the number of local iterations of local SGD') -parser.add_argument('--local_sgd_regularization', type=float, default=0, help='the regularization weight of local SGD') -parser.add_argument('--local_sgd_regularization_interval', type=int, default=0, help='the interval of regularization of local SGD') -parser.add_argument('--lr', type=float, default=1.0, help='Initial learning rate') -parser.add_argument('--warmup_steps', type=float, default=4000, - help='number of warmup steps used in NOAM\'s stepsize schedule') -parser.add_argument('--num_accumulated', type=int, default=1, - help='Number of steps to accumulate the gradients. ' - 'This is useful to mimic large batch training with limited gpu memory') -parser.add_argument('--magnitude', type=float, default=3.0, - help='Magnitude of Xavier initialization') -parser.add_argument('--average_checkpoint', action='store_true', - help='Turn on to perform final testing based on ' - 'the average of last few checkpoints') -parser.add_argument('--num_averages', type=int, default=5, - help='Perform final testing based on the ' - 'average of last num_averages checkpoints. ' - 'This is only used if average_checkpoint is True') -parser.add_argument('--average_start', type=int, default=5, - help='Perform average SGD on last average_start epochs') -parser.add_argument('--full', action='store_true', - help='In default, we use the test dataset in' - ' http://statmt.org/wmt14/test-filtered.tgz.' - ' When the option full is turned on, we use the test dataset in' - ' http://statmt.org/wmt14/test-full.tgz') -parser.add_argument('--bleu', type=str, default='tweaked', - help='Schemes for computing bleu score. It can be: ' - '"tweaked": it uses similar steps in get_ende_bleu.sh in tensor2tensor ' - 'repository, where compound words are put in ATAT format; ' - '"13a": This uses official WMT tokenization and produces the same results' - ' as official script (mteval-v13a.pl) used by WMT; ' - '"intl": This use international tokenization in mteval-v14a.pl') -parser.add_argument('--log_interval', type=int, default=100, metavar='N', - help='report interval') -parser.add_argument('--save_dir', type=str, default='transformer_out', - help='directory path to save the final model and training log') -parser.add_argument('--gpus', type=str, - help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.' 
- '(using single gpu is suggested)') -args = parser.parse_args() -logging_config(args.save_dir) -logging.info(args) - -amp.init() - -data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab \ - = dataprocessor.load_translation_data(dataset=args.dataset, bleu=args.bleu, args=args) - -dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt')) -dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt')) - -data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False) -data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) - for i, ele in enumerate(data_val)]) -data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) - for i, ele in enumerate(data_test)]) - -ctx = [mx.cpu()] if args.gpus is None or args.gpus == '' else \ - [mx.gpu(int(x)) for x in args.gpus.split(',')] -num_ctxs = len(ctx) - -data_train_lengths, data_val_lengths, data_test_lengths = [dataprocessor.get_data_lengths(x) - for x in - [data_train, data_val, data_test]] - -if args.src_max_len <= 0 or args.tgt_max_len <= 0: - max_len = np.max( - [np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0), - np.max(data_test_lengths, axis=0)], - axis=0) -if args.src_max_len > 0: - src_max_len = args.src_max_len -else: - src_max_len = max_len[0] -if args.tgt_max_len > 0: - tgt_max_len = args.tgt_max_len -else: - tgt_max_len = max_len[1] -encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, - hidden_size=args.hidden_size, - dropout=args.dropout, - num_layers=args.num_layers, - num_heads=args.num_heads, - max_src_length=max(src_max_len, 500), - max_tgt_length=max(tgt_max_len, 500), - scaled=args.scaled) -model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, - share_embed=args.dataset != 'TOY', embed_size=args.num_units, - tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_') -model.initialize(init=mx.init.Xavier(magnitude=args.magnitude), ctx=ctx) -static_alloc = True -model.hybridize(static_alloc=static_alloc) -# logging.info(model) - -translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, - scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, - K=args.lp_k), - max_length=200) -logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) - -label_smoothing = LabelSmoothing(epsilon=args.epsilon, units=len(tgt_vocab)) -label_smoothing.hybridize(static_alloc=static_alloc) - -loss_function = SoftmaxCEMaskedLoss(sparse_label=False) -loss_function.hybridize(static_alloc=static_alloc) - -test_loss_function = SoftmaxCEMaskedLoss() -test_loss_function.hybridize(static_alloc=static_alloc) - -trainer = gluon.Trainer(model.collect_params(), args.optimizer, - {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9}, local_sgd=args.local_sgd, local_sgd_regularization=args.local_sgd_regularization, local_sgd_regularization_interval=args.local_sgd_regularization_interval) - -rescale_loss = 100 -parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss) -# parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss, amp=amp, trainer=trainer) -detokenizer = nlp.data.SacreMosesDetokenizer() - - -def evaluate(data_loader, context=ctx[0]): - """Evaluate given the data loader - - Parameters - ---------- - data_loader : DataLoader - - Returns - ------- - 
avg_loss : float - Average loss - real_translation_out : list of list of str - The translation output - """ - translation_out = [] - all_inst_ids = [] - avg_loss_denom = 0 - avg_loss = 0.0 - for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \ - in enumerate(data_loader): - src_seq = src_seq.as_in_context(context) - tgt_seq = tgt_seq.as_in_context(context) - src_valid_length = src_valid_length.as_in_context(context) - tgt_valid_length = tgt_valid_length.as_in_context(context) - # Calculating Loss - out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) - loss = test_loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar() - all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist()) - avg_loss += loss * (tgt_seq.shape[1] - 1) - avg_loss_denom += (tgt_seq.shape[1] - 1) - # Translate - samples, _, sample_valid_length = \ - translator.translate(src_seq=src_seq, src_valid_length=src_valid_length) - max_score_sample = samples[:, 0, :].asnumpy() - sample_valid_length = sample_valid_length[:, 0].asnumpy() - for i in range(max_score_sample.shape[0]): - translation_out.append( - [tgt_vocab.idx_to_token[ele] for ele in - max_score_sample[i][1:(sample_valid_length[i] - 1)]]) - avg_loss = avg_loss / avg_loss_denom - real_translation_out = [None for _ in range(len(all_inst_ids))] - for ind, sentence in zip(all_inst_ids, translation_out): - if args.bleu == 'tweaked': - real_translation_out[ind] = sentence - elif args.bleu == '13a' or args.bleu == 'intl': - real_translation_out[ind] = detokenizer(_bpe_to_words(sentence), - return_str=True) - else: - raise NotImplementedError - return avg_loss, real_translation_out - - -def train(): - """Training function.""" - - train_data_loader, val_data_loader, test_data_loader \ - = dataprocessor.make_dataloader(data_train, data_val, data_test, args, - use_average_length=True, num_shards=len(ctx)) - - if args.bleu == 'tweaked': - bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY') - split_compound_word = bpe - tokenized = True - elif args.bleu == '13a' or args.bleu == 'intl': - bpe = False - split_compound_word = False - tokenized = False - else: - raise NotImplementedError - - best_valid_bleu = 0.0 - step_num = 0 - warmup_steps = args.warmup_steps - grad_interval = args.num_accumulated - model.collect_params().setattr('grad_req', 'add') - average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start) - average_param_dict = None - model.collect_params().zero_grad() - parallel = Parallel(num_ctxs, parallel_model) - for epoch_id in range(args.epochs): - log_avg_loss = 0 - log_wc = 0 - loss_denom = 0 - step_loss = 0 - log_start_time = time.time() - epoch_start_time = time.time() - for batch_id, seqs \ - in enumerate(train_data_loader): - - # if epoch_id == 0 and batch_id == 100: - # # amp - # mx.nd.waitall() - # logging.info('[Epoch {} Batch {}/{}] Activate amp'.format(epoch_id, batch_id + 1)) - # amp.init_trainer(trainer) - # parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss, amp=amp, trainer=trainer) - # parallel = Parallel(num_ctxs, parallel_model) - - if batch_id % grad_interval == 0: - step_num += 1 - new_lr = args.lr / math.sqrt(args.num_units) \ - * min(1. 
/ math.sqrt(step_num), step_num * warmup_steps ** (-1.5)) - trainer.set_learning_rate(new_lr) - src_wc, tgt_wc, bs = np.sum([(shard[2].sum(), shard[3].sum(), shard[0].shape[0]) - for shard in seqs], axis=0) - seqs = [[seq.as_in_context(context) for seq in shard] - for context, shard in zip(ctx, seqs)] - Ls = [] - for seq in seqs: - parallel.put((seq, args.batch_size)) - Ls = [parallel.get() for _ in range(len(ctx))] - src_wc = src_wc.asscalar() - tgt_wc = tgt_wc.asscalar() - loss_denom += tgt_wc - bs - if batch_id % grad_interval == grad_interval - 1 or\ - batch_id == len(train_data_loader) - 1: - if average_param_dict is None: - average_param_dict = {k: v.data(ctx[0]).copy() for k, v in - model.collect_params().items()} - trainer.step(float(loss_denom) / args.batch_size / 100.0) - param_dict = model.collect_params() - param_dict.zero_grad() - if step_num > average_start: - alpha = 1. / max(1, step_num - average_start) - for name, average_param in average_param_dict.items(): - average_param[:] += alpha * (param_dict[name].data(ctx[0]) - average_param) - step_loss += sum([L.asscalar() for L in Ls]) - if batch_id % grad_interval == grad_interval - 1 or\ - batch_id == len(train_data_loader) - 1: - log_avg_loss += step_loss / loss_denom * args.batch_size * 100.0 - loss_denom = 0 - step_loss = 0 - log_wc += src_wc + tgt_wc - if (batch_id + 1) % (args.log_interval * grad_interval) == 0: - wps = log_wc / (time.time() - log_start_time) - logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' - 'throughput={:.2f}K wps, wc={:.2f}K' - .format(epoch_id, batch_id + 1, len(train_data_loader), - log_avg_loss / args.log_interval, - np.exp(log_avg_loss / args.log_interval), - wps / 1000, log_wc / 1000)) - log_start_time = time.time() - log_avg_loss = 0 - log_wc = 0 - if args.local_sgd > 1: - # synchronous model parameters for local sgd - trainer.allreduce_params() - mx.nd.waitall() - logging.info('[Epoch {}] time={:.2f}s'.format(epoch_id, time.time()-epoch_start_time)) - valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) - valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, - tokenized=tokenized, tokenizer=args.bleu, - split_compound_word=split_compound_word, - bpe=bpe) - logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' - .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) - test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) - test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, - tokenized=tokenized, tokenizer=args.bleu, - split_compound_word=split_compound_word, - bpe=bpe) - logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}' - .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100)) - dataprocessor.write_sentences(valid_translation_out, - os.path.join(args.save_dir, - 'epoch{:d}_valid_out.txt').format(epoch_id)) - dataprocessor.write_sentences(test_translation_out, - os.path.join(args.save_dir, - 'epoch{:d}_test_out.txt').format(epoch_id)) - if valid_bleu_score > best_valid_bleu: - best_valid_bleu = valid_bleu_score - save_path = os.path.join(args.save_dir, 'valid_best.params') - logging.info('Save best parameters to {}'.format(save_path)) - model.save_parameters(save_path) - save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)) - model.save_parameters(save_path) - save_path = os.path.join(args.save_dir, 'average.params') - mx.nd.save(save_path, average_param_dict) - if 
args.average_checkpoint: - for j in range(args.num_averages): - params = mx.nd.load(os.path.join(args.save_dir, - 'epoch{:d}.params'.format(args.epochs - j - 1))) - alpha = 1. / (j + 1) - for k, v in model._collect_params_with_prefix().items(): - for c in ctx: - v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) - save_path = os.path.join(args.save_dir, - 'average_checkpoint_{}.params'.format(args.num_averages)) - model.save_parameters(save_path) - elif args.average_start > 0: - for k, v in model.collect_params().items(): - v.set_data(average_param_dict[k]) - save_path = os.path.join(args.save_dir, 'average.params') - model.save_parameters(save_path) - else: - model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx) - valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) - valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, - tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, - split_compound_word=split_compound_word) - logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' - .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) - test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) - test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, - tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, - split_compound_word=split_compound_word) - logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}' - .format(test_loss, np.exp(test_loss), test_bleu_score * 100)) - dataprocessor.write_sentences(valid_translation_out, - os.path.join(args.save_dir, 'best_valid_out.txt')) - dataprocessor.write_sentences(test_translation_out, - os.path.join(args.save_dir, 'best_test_out.txt')) - - -if __name__ == '__main__': - train() diff --git a/src/gluonnlp/model/transformer.py b/src/gluonnlp/model/transformer.py index 6f219586be..81f0e2851a 100644 --- a/src/gluonnlp/model/transformer.py +++ b/src/gluonnlp/model/transformer.py @@ -1298,13 +1298,11 @@ class ParallelTransformer(Parallelizable): rescale_loss : float The scale to which the loss is rescaled to avoid gradient explosion. 
""" - def __init__(self, model, label_smoothing, loss_function, rescale_loss, amp=None, trainer=None): + def __init__(self, model, label_smoothing, loss_function, rescale_loss): self._model = model self._label_smoothing = label_smoothing self._loss = loss_function self._rescale_loss = rescale_loss - self._amp = amp - self._trainer = trainer def forward_backward(self, x): """Perform forward and backward computation for a batch of src seq and dst seq""" @@ -1315,10 +1313,5 @@ def forward_backward(self, x): smoothed_label = self._label_smoothing(tgt_seq[:, 1:]) ls = self._loss(out, smoothed_label, tgt_valid_length - 1).sum() ls = (ls * (tgt_seq.shape[1] - 1)) / batch_size / self._rescale_loss - if self._amp is not None: - with self._amp.scale_loss(ls, self._trainer) as scaled_loss: - mx.autograd.backward(scaled_loss) - return scaled_loss ls.backward() return ls - From bb3813e7bfb09590cfa4edab99f7aa66f5bb9ee3 Mon Sep 17 00:00:00 2001 From: Xie Date: Wed, 7 Aug 2019 15:34:58 -0700 Subject: [PATCH 3/5] local adam for transformer --- .../train_transformer_local_sgd.py | 498 ++++++++++++++++++ .../transformer_local_sgd.py | 338 ++++++++++++ 2 files changed, 836 insertions(+) create mode 100644 scripts/machine_translation/train_transformer_local_sgd.py create mode 100644 scripts/machine_translation/transformer_local_sgd.py diff --git a/scripts/machine_translation/train_transformer_local_sgd.py b/scripts/machine_translation/train_transformer_local_sgd.py new file mode 100644 index 0000000000..f8c449d4aa --- /dev/null +++ b/scripts/machine_translation/train_transformer_local_sgd.py @@ -0,0 +1,498 @@ +""" +Transformer +================================= + +This example shows how to implement the Transformer model with Gluon NLP Toolkit. + +@inproceedings{vaswani2017attention, + title={Attention is all you need}, + author={Vaswani, Ashish and Shazeer, Noam and Parmar, Niki and Uszkoreit, Jakob and Jones, + Llion and Gomez, Aidan N and Kaiser, Lukasz and Polosukhin, Illia}, + booktitle={Advances in Neural Information Processing Systems}, + pages={6000--6010}, + year={2017} +} +""" + +# coding: utf-8 + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint:disable=redefined-outer-name,logging-format-interpolation + +import argparse +import time +import random +import os +import logging +import math +import numpy as np +import mxnet as mx +from mxnet import gluon +import gluonnlp as nlp + +from gluonnlp.loss import MaskedSoftmaxCELoss, LabelSmoothing +from gluonnlp.model.translation import NMTModel +from gluonnlp.model.transformer import get_transformer_encoder_decoder, ParallelTransformer +from gluonnlp.utils.parallel import Parallel +from translation import BeamSearchTranslator + +from utils import logging_config +from bleu import _bpe_to_words, compute_bleu +import dataprocessor + +from transformer_local_sgd import LocalSGDTrainer + +np.random.seed(100) +random.seed(100) +mx.random.seed(10000) + +parser = argparse.ArgumentParser(description='Neural Machine Translation Example.' + 'We train the Transformer Model') +parser.add_argument('--dataset', type=str, default='WMT2016BPE', help='Dataset to use.') +parser.add_argument('--src_lang', type=str, default='en', help='Source language') +parser.add_argument('--tgt_lang', type=str, default='de', help='Target language') +parser.add_argument('--epochs', type=int, default=10, help='upper epoch limit') +parser.add_argument('--num_units', type=int, default=512, help='Dimension of the embedding ' + 'vectors and states.') +parser.add_argument('--hidden_size', type=int, default=2048, + help='Dimension of the hidden state in position-wise feed-forward networks.') +parser.add_argument('--dropout', type=float, default=0.1, + help='dropout applied to layers (0 = no dropout)') +parser.add_argument('--epsilon', type=float, default=0.1, + help='epsilon parameter for label smoothing') +parser.add_argument('--num_layers', type=int, default=6, + help='number of layers in the encoder and decoder') +parser.add_argument('--num_heads', type=int, default=8, + help='number of heads in multi-head attention') +parser.add_argument('--scaled', action='store_true', help='Turn on to use scale in attention') +parser.add_argument('--batch_size', type=int, default=1024, + help='Batch size. Number of tokens per gpu in a minibatch') +parser.add_argument('--beam_size', type=int, default=4, help='Beam size') +parser.add_argument('--lp_alpha', type=float, default=0.6, + help='Alpha used in calculating the length penalty') +parser.add_argument('--lp_k', type=int, default=5, help='K used in calculating the length penalty') +parser.add_argument('--test_batch_size', type=int, default=256, help='Test batch size') +parser.add_argument('--num_buckets', type=int, default=10, help='Bucket number') +parser.add_argument('--bucket_scheme', type=str, default='constant', + help='Strategy for generating bucket keys. 
It supports: ' + '"constant": all the buckets have the same width; ' + '"linear": the width of bucket increases linearly; ' + '"exp": the width of bucket increases exponentially') +parser.add_argument('--bucket_ratio', type=float, default=0.0, help='Ratio for increasing the ' + 'throughput of the bucketing') +parser.add_argument('--src_max_len', type=int, default=-1, help='Maximum length of the source ' + 'sentence, -1 means no clipping') +parser.add_argument('--tgt_max_len', type=int, default=-1, help='Maximum length of the target ' + 'sentence, -1 means no clipping') +parser.add_argument('--optimizer', type=str, default='adam', help='optimization algorithm') +parser.add_argument('--lr', type=float, default=1.0, help='Initial learning rate') +parser.add_argument('--warmup_steps', type=float, default=4000, + help='number of warmup steps used in NOAM\'s stepsize schedule') +parser.add_argument('--num_accumulated', type=int, default=1, + help='Number of steps to accumulate the gradients. ' + 'This is useful to mimic large batch training with limited gpu memory') +parser.add_argument('--magnitude', type=float, default=3.0, + help='Magnitude of Xavier initialization') +parser.add_argument('--average_checkpoint', action='store_true', + help='Turn on to perform final testing based on ' + 'the average of last few checkpoints') +parser.add_argument('--num_averages', type=int, default=5, + help='Perform final testing based on the ' + 'average of last num_averages checkpoints. ' + 'This is only used if average_checkpoint is True') +parser.add_argument('--average_start', type=int, default=5, + help='Perform average SGD on last average_start epochs') +parser.add_argument('--full', action='store_true', + help='In default, we use the test dataset in' + ' http://statmt.org/wmt14/test-filtered.tgz.' + ' When the option full is turned on, we use the test dataset in' + ' http://statmt.org/wmt14/test-full.tgz') +parser.add_argument('--bleu', type=str, default='tweaked', + help='Schemes for computing bleu score. It can be: ' + '"tweaked": it uses similar steps in get_ende_bleu.sh in tensor2tensor ' + 'repository, where compound words are put in ATAT format; ' + '"13a": This uses official WMT tokenization and produces the same results' + ' as official script (mteval-v13a.pl) used by WMT; ' + '"intl": This use international tokenization in mteval-v14a.pl') +parser.add_argument('--log_interval', type=int, default=100, metavar='N', + help='report interval') +parser.add_argument('--save_dir', type=str, default='transformer_out', + help='directory path to save the final model and training log') +parser.add_argument('--gpus', type=str, + help='list of gpus to run, e.g. 0 or 0,2,5. empty means using cpu.' 
+ '(using single gpu is suggested)') + +parser.add_argument('--local_sgd_interval', type=int, default=0, help='the number of local iterations of local SGD (initial value)') +parser.add_argument('--local_sgd_epochs', type=str, default=None, help='the epoch that local SGD changes') +parser.add_argument('--local_sgd_schedule', type=str, default=None, help='the schedule of local SGD') +parser.add_argument('--local_sgd_regularization', type=float, default=0, help='the regularization weight of local SGD') +parser.add_argument('--local_sgd_regularization_interval', type=int, default=0, help='the interval of regularization of local SGD') +parser.add_argument('--start_epoch', type=int, default=0, help='read from checkpoint') +parser.add_argument('--save_checkpoint', action='store_true', + help='Save checkpoint for each epoch') + +args = parser.parse_args() +logging_config(args.save_dir) +logging.info(args) + + +data_train, data_val, data_test, val_tgt_sentences, test_tgt_sentences, src_vocab, tgt_vocab \ + = dataprocessor.load_translation_data(dataset=args.dataset, bleu=args.bleu, args=args) + +dataprocessor.write_sentences(val_tgt_sentences, os.path.join(args.save_dir, 'val_gt.txt')) +dataprocessor.write_sentences(test_tgt_sentences, os.path.join(args.save_dir, 'test_gt.txt')) + +data_train = data_train.transform(lambda src, tgt: (src, tgt, len(src), len(tgt)), lazy=False) +data_val = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_val)]) +data_test = gluon.data.SimpleDataset([(ele[0], ele[1], len(ele[0]), len(ele[1]), i) + for i, ele in enumerate(data_test)]) + +ctx = [mx.cpu()] if args.gpus is None or args.gpus == '' else \ + [mx.gpu(int(x)) for x in args.gpus.split(',')] +num_ctxs = len(ctx) + +data_train_lengths, data_val_lengths, data_test_lengths = [dataprocessor.get_data_lengths(x) + for x in + [data_train, data_val, data_test]] + +if args.src_max_len <= 0 or args.tgt_max_len <= 0: + max_len = np.max( + [np.max(data_train_lengths, axis=0), np.max(data_val_lengths, axis=0), + np.max(data_test_lengths, axis=0)], + axis=0) +if args.src_max_len > 0: + src_max_len = args.src_max_len +else: + src_max_len = max_len[0] +if args.tgt_max_len > 0: + tgt_max_len = args.tgt_max_len +else: + tgt_max_len = max_len[1] +encoder, decoder = get_transformer_encoder_decoder(units=args.num_units, + hidden_size=args.hidden_size, + dropout=args.dropout, + num_layers=args.num_layers, + num_heads=args.num_heads, + max_src_length=max(src_max_len, 500), + max_tgt_length=max(tgt_max_len, 500), + scaled=args.scaled) +model = NMTModel(src_vocab=src_vocab, tgt_vocab=tgt_vocab, encoder=encoder, decoder=decoder, + share_embed=args.dataset != 'TOY', embed_size=args.num_units, + tie_weights=args.dataset != 'TOY', embed_initializer=None, prefix='transformer_') +model.initialize(init=mx.init.Xavier(magnitude=args.magnitude), ctx=ctx) +static_alloc = True +model.hybridize(static_alloc=static_alloc) +# logging.info(model) + +translator = BeamSearchTranslator(model=model, beam_size=args.beam_size, + scorer=nlp.model.BeamSearchScorer(alpha=args.lp_alpha, + K=args.lp_k), + max_length=200) +logging.info('Use beam_size={}, alpha={}, K={}'.format(args.beam_size, args.lp_alpha, args.lp_k)) + +label_smoothing = LabelSmoothing(epsilon=args.epsilon, units=len(tgt_vocab)) +label_smoothing.hybridize(static_alloc=static_alloc) + +loss_function = MaskedSoftmaxCELoss(sparse_label=False) +loss_function.hybridize(static_alloc=static_alloc) + +test_loss_function = MaskedSoftmaxCELoss() 
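# --- Editor's note (not part of the patch series) ------------------------------
# The training loss above is built with sparse_label=False because LabelSmoothing
# turns each reference token index into a soft distribution over the vocabulary,
# while this evaluation loss keeps ordinary sparse (index) labels. With smoothing
# epsilon and vocabulary size V, one common convention maps index 2 (for, say,
# epsilon=0.1 and V=4) to roughly [0.025, 0.025, 0.925, 0.025] instead of
# [0, 0, 1, 0]; the exact normalization depends on the LabelSmoothing
# implementation.
# -------------------------------------------------------------------------------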
+test_loss_function.hybridize(static_alloc=static_alloc) + +rescale_loss = 100 +parallel_model = ParallelTransformer(model, label_smoothing, loss_function, rescale_loss) +detokenizer = nlp.data.SacreMosesDetokenizer() + + +def evaluate(data_loader, context=ctx[0]): + """Evaluate given the data loader + + Parameters + ---------- + data_loader : DataLoader + + Returns + ------- + avg_loss : float + Average loss + real_translation_out : list of list of str + The translation output + """ + translation_out = [] + all_inst_ids = [] + avg_loss_denom = 0 + avg_loss = 0.0 + for _, (src_seq, tgt_seq, src_valid_length, tgt_valid_length, inst_ids) \ + in enumerate(data_loader): + src_seq = src_seq.as_in_context(context) + tgt_seq = tgt_seq.as_in_context(context) + src_valid_length = src_valid_length.as_in_context(context) + tgt_valid_length = tgt_valid_length.as_in_context(context) + # Calculating Loss + out, _ = model(src_seq, tgt_seq[:, :-1], src_valid_length, tgt_valid_length - 1) + loss = test_loss_function(out, tgt_seq[:, 1:], tgt_valid_length - 1).mean().asscalar() + all_inst_ids.extend(inst_ids.asnumpy().astype(np.int32).tolist()) + avg_loss += loss * (tgt_seq.shape[1] - 1) + avg_loss_denom += (tgt_seq.shape[1] - 1) + # Translate + samples, _, sample_valid_length = \ + translator.translate(src_seq=src_seq, src_valid_length=src_valid_length) + max_score_sample = samples[:, 0, :].asnumpy() + sample_valid_length = sample_valid_length[:, 0].asnumpy() + for i in range(max_score_sample.shape[0]): + translation_out.append( + [tgt_vocab.idx_to_token[ele] for ele in + max_score_sample[i][1:(sample_valid_length[i] - 1)]]) + avg_loss = avg_loss / avg_loss_denom + real_translation_out = [None for _ in range(len(all_inst_ids))] + for ind, sentence in zip(all_inst_ids, translation_out): + if args.bleu == 'tweaked': + real_translation_out[ind] = sentence + elif args.bleu == '13a' or args.bleu == 'intl': + real_translation_out[ind] = detokenizer(_bpe_to_words(sentence)) + else: + raise NotImplementedError + return avg_loss, real_translation_out + + +def train(): + """Training function.""" + + # local sgd + if args.local_sgd_epochs is not None: + local_sgd_epochs = [int(i) for i in args.local_sgd_epochs.split(',')] + else: + local_sgd_epochs = None + if args.local_sgd_schedule is not None: + local_sgd_schedule = [int(i) for i in args.local_sgd_schedule.split(',')] + else: + local_sgd_schedule = None + local_sgd_interval = args.local_sgd_interval + + trainer = LocalSGDTrainer(model.collect_params(), args.optimizer, + {'learning_rate': args.lr, 'beta2': 0.98, 'epsilon': 1e-9}, + local_sgd_interval=local_sgd_interval, local_sgd_regularization=args.local_sgd_regularization, + local_sgd_regularization_interval=args.local_sgd_regularization_interval) + + train_data_loader, val_data_loader, test_data_loader \ + = dataprocessor.make_dataloader(data_train, data_val, data_test, args, + use_average_length=True, num_shards=len(ctx)) + + if args.bleu == 'tweaked': + bpe = bool(args.dataset != 'IWSLT2015' and args.dataset != 'TOY') + split_compound_word = bpe + tokenized = True + elif args.bleu == '13a' or args.bleu == 'intl': + bpe = False + split_compound_word = False + tokenized = False + else: + raise NotImplementedError + + best_valid_bleu = 0.0 + step_num = 0 + warmup_steps = args.warmup_steps + grad_interval = args.num_accumulated + model.collect_params().setattr('grad_req', 'add') + average_start = (len(train_data_loader) // grad_interval) * (args.epochs - args.average_start) + average_param_dict_list = None + 
model.collect_params().zero_grad() + parallel = Parallel(num_ctxs, parallel_model) + + average_counter = 0 + + if args.start_epoch > 0: + param_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(args.start_epoch-1)) + logging.info('Loading parameters from %s', param_path) + nlp.utils.load_parameters(model, param_path, ctx=ctx) + state_path = os.path.join(args.save_dir, 'epoch{:d}.states'.format(args.start_epoch-1)) + logging.info('Loading states from %s', state_path) + nlp.utils.load_states(trainer, state_path) + step_num = int(math.ceil(len(train_data_loader) * args.start_epoch / grad_interval)) + 1 + + for epoch_id in range(args.start_epoch, args.epochs): + log_avg_loss = 0 + log_wc = 0 + loss_denom = np.zeros(num_ctxs, dtype='float32') + step_loss = 0 + log_start_time = time.time() + epoch_start_time = time.time() + + if local_sgd_epochs is not None and epoch_id in local_sgd_epochs: + new_local_sgd_interval = local_sgd_schedule[local_sgd_epochs.index(epoch_id)] + if new_local_sgd_interval <= 1: + new_local_sgd_interval = 1 + if args.start_epoch != epoch_id: + if not trainer._is_states_initialized: + trainer.init_states() + trainer.allreduce_states() + trainer._local_sgd_interval = new_local_sgd_interval + local_sgd_interval = new_local_sgd_interval + + for batch_id, seqs \ + in enumerate(train_data_loader): + + if (local_sgd_interval > 1 or local_sgd_epochs is not None) and epoch_id == 0 and batch_id == grad_interval: + # local sgd, synchronous momentum + trainer.init_states() + if batch_id % grad_interval == 0: + step_num += 1 + new_lr = args.lr / math.sqrt(args.num_units) \ + * min(1. / math.sqrt(step_num), step_num * warmup_steps ** (-1.5)) + trainer.set_learning_rate(new_lr) + + # src_wc, tgt_wc, bs = np.sum([(shard[2].sum(), shard[3].sum(), shard[0].shape[0]) + # for shard in seqs], axis=0) + src_wc = np.array([shard[2].sum().asscalar() for shard in seqs]) + tgt_wc = np.array([shard[3].sum().asscalar() for shard in seqs]) + bs = np.array([shard[0].shape[0] for shard in seqs]) + + seqs = [[seq.as_in_context(context) for seq in shard] + for context, shard in zip(ctx, seqs)] + Ls = [] + for seq in seqs: + parallel.put((seq, args.batch_size)) + Ls = [parallel.get() for _ in range(len(ctx))] + loss_denom = loss_denom + tgt_wc - bs + is_sync = False + if batch_id % grad_interval == grad_interval - 1 or\ + batch_id == len(train_data_loader) - 1: + step_size = loss_denom / args.batch_size / 100.0 + is_sync = trainer.step(step_size.tolist()) + param_dict = model.collect_params() + param_dict.zero_grad() + if step_num > average_start: + average_counter += 1 + if average_param_dict_list is None: + average_param_dict_list = [{k: v.data(c).copy() for k, v in + model.collect_params().items()} for c in ctx] + else: + alpha = 1. 
/ average_counter + for i in range(len(ctx)): + for name, average_param in average_param_dict_list[i].items(): + average_param[:] += alpha * (param_dict[name].data(ctx[i]) - average_param) + for L in Ls: + step_loss += L.as_in_context(mx.cpu()) + if batch_id % grad_interval == grad_interval - 1 or\ + batch_id == len(train_data_loader) - 1: + log_avg_loss += step_loss / np.asscalar(loss_denom.sum()) * args.batch_size * 100.0 + loss_denom = np.zeros(num_ctxs, dtype='float32') + step_loss = 0 + log_wc += np.asscalar(src_wc.sum() + tgt_wc.sum()) + if (batch_id + 1) % (args.log_interval * grad_interval) == 0: + wps = log_wc / (time.time() - log_start_time) + log_avg_loss_scalar = log_avg_loss.asscalar() + logging.info('[Epoch {} Batch {}/{}] loss={:.4f}, ppl={:.4f}, ' + 'throughput={:.2f}K wps, wc={:.2f}K, lr={:.8f}' + .format(epoch_id, batch_id + 1, len(train_data_loader), + log_avg_loss_scalar / args.log_interval, + np.exp(log_avg_loss_scalar / args.log_interval), + wps / 1000, log_wc / 1000, new_lr)) + log_start_time = time.time() + log_avg_loss = 0 + log_wc = 0 + if local_sgd > 1 and not is_sync: + # synchronous model parameters for local sgd + trainer.allreduce_params() + trainer.allreduce_states() + mx.nd.waitall() + logging.info('[Epoch {}] time={:.2f}s'.format(epoch_id, time.time()-epoch_start_time)) + if epoch_id >= 5: + valid_loss, valid_translation_out = evaluate(val_data_loader, ctx[0]) + valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe) + logging.info('[Epoch {}] valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' + .format(epoch_id, valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) + test_loss, test_translation_out = evaluate(test_data_loader, ctx[0]) + test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, + tokenized=tokenized, tokenizer=args.bleu, + split_compound_word=split_compound_word, + bpe=bpe) + logging.info('[Epoch {}] test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}' + .format(epoch_id, test_loss, np.exp(test_loss), test_bleu_score * 100)) + dataprocessor.write_sentences(valid_translation_out, + os.path.join(args.save_dir, + 'epoch{:d}_valid_out.txt').format(epoch_id)) + dataprocessor.write_sentences(test_translation_out, + os.path.join(args.save_dir, + 'epoch{:d}_test_out.txt').format(epoch_id)) + if valid_bleu_score > best_valid_bleu: + best_valid_bleu = valid_bleu_score + save_path = os.path.join(args.save_dir, 'valid_best.params') + logging.info('Save best parameters to {}'.format(save_path)) + model.save_parameters(save_path) + else: + save_path = os.path.join(args.save_dir, 'valid_best.params') + logging.info('Save best parameters to {}'.format(save_path)) + model.save_parameters(save_path) + if args.save_checkpoint: + param_save_path = os.path.join(args.save_dir, 'epoch{:d}.params'.format(epoch_id)) + logging.info('Save current parameters to {}'.format(param_save_path)) + nlp.utils.save_parameters(model, param_save_path) + states_save_path = os.path.join(args.save_dir, 'epoch{:d}.states'.format(epoch_id)) + logging.info('Save current states to {}'.format(states_save_path)) + nlp.utils.save_states(trainer, states_save_path) + + save_path = os.path.join(args.save_dir, 'average.params') + mx.nd.save(save_path, average_param_dict_list[0]) + if args.average_checkpoint: + for j in range(args.num_averages): + params = mx.nd.load(os.path.join(args.save_dir, + 
'epoch{:d}.params'.format(args.epochs - j - 1))) + alpha = 1. / (j + 1) + for k, v in model._collect_params_with_prefix().items(): + for c in ctx: + v.data(c)[:] += alpha * (params[k].as_in_context(c) - v.data(c)) + save_path = os.path.join(args.save_dir, + 'average_checkpoint_{}.params'.format(args.num_averages)) + model.save_parameters(save_path) + elif args.average_start > 0: + param_dict = model.collect_params() + for i in range(len(ctx)): + for name, average_param in average_param_dict_list[i].items(): + param_dict[name].data(ctx[i])[:] = average_param + trainer.allreduce_params() + mx.nd.waitall() + save_path = os.path.join(args.save_dir, 'average.params') + model.save_parameters(save_path) + else: + model.load_parameters(os.path.join(args.save_dir, 'valid_best.params'), ctx) + valid_loss, valid_translation_out = evaluate(val_data_loader) + valid_bleu_score, _, _, _, _ = compute_bleu([val_tgt_sentences], valid_translation_out, + tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, + split_compound_word=split_compound_word) + logging.info('Best model valid Loss={:.4f}, valid ppl={:.4f}, valid bleu={:.2f}' + .format(valid_loss, np.exp(valid_loss), valid_bleu_score * 100)) + test_loss, test_translation_out = evaluate(test_data_loader) + test_bleu_score, _, _, _, _ = compute_bleu([test_tgt_sentences], test_translation_out, + tokenized=tokenized, tokenizer=args.bleu, bpe=bpe, + split_compound_word=split_compound_word) + logging.info('Best model test Loss={:.4f}, test ppl={:.4f}, test bleu={:.2f}' + .format(test_loss, np.exp(test_loss), test_bleu_score * 100)) + dataprocessor.write_sentences(valid_translation_out, + os.path.join(args.save_dir, 'best_valid_out.txt')) + dataprocessor.write_sentences(test_translation_out, + os.path.join(args.save_dir, 'best_test_out.txt')) + + + +if __name__ == '__main__': + train() \ No newline at end of file diff --git a/scripts/machine_translation/transformer_local_sgd.py b/scripts/machine_translation/transformer_local_sgd.py new file mode 100644 index 0000000000..841854767c --- /dev/null +++ b/scripts/machine_translation/transformer_local_sgd.py @@ -0,0 +1,338 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# coding: utf-8 +# pylint: disable=line-too-long +"""Parameter optimizer.""" + +from mxnet import optimizer as opt +from mxnet.model import _create_kvstore, _create_sparse_kvstore +from mxnet.gluon.parameter import ParameterDict, Parameter + +import mxnet as mx +import types +import warnings +import math + +class LocalSGDTrainer(mx.gluon.Trainer): + """Local Adam optimizer for Transformer. + + Parameters + ---------- + local_sgd_interval : int, default 1 + If local_sgd_interval<=1, run fully synchronous SGD, + otherwise, sync params and states for every local_sgd steps. 
+ local_sgd_regularization : float, default 0 + The weight of local regularization, within the range [0, 1) + local_sgd_regularization_interval : int, default 0 + If larger than 0, add the regularization term to the local solver + after every local_sgd_regularization_interval steps + """ + def __init__(self, params, optimizer, optimizer_params=None, kvstore='device', + local_sgd_interval=1, local_sgd_regularization=0, local_sgd_regularization_interval=0): + + super(LocalSGDTrainer, self).__init__( + params, optimizer, optimizer_params=optimizer_params, kvstore=kvstore, update_on_kvstore=False) + + # _scale is used to check and set rescale_grad for optimizer in Trainer.step() + # function. Normalizing it by Horovod size, which is equivalent to performing + # average in allreduce, has better performance. + if local_sgd_interval is None or local_sgd_interval <= 1: + self._local_sgd_interval = 1 + else: + self._local_sgd_interval = local_sgd_interval + self._local_sgd_counter = 0 + update_on_kvstore = False + self._local_sgd_regularization = local_sgd_regularization + self._local_sgd_regularization_interval = local_sgd_regularization_interval + self._local_sgd_regularization_counter = 0 + self._is_states_initialized = False + + def _init_kvstore(self): + """Create kvstore.""" + config = self._kvstore_params + # configure kvstore, update_on_kvstore and self._distributed on three cases: + if self._contains_sparse_weight: + # If weight is sparse, kvstore must be present and the weight must be updated on kvstore. + # The training loop is the following: + # - row_sparse_pull(sparse_weight) + # - forward() + # - backward() + # - push_and_update(grad) + # - pull(weight) + kvstore, update_on_kvstore = _create_sparse_kvstore(config['kvstore']) + self._distributed = 'dist' in kvstore.type + # raise err if user provides unsupported configs + if config['update_on_kvstore'] is False: + raise ValueError("Cannot set update_on_kvstore=False when sparse weights " + "are present.") + + elif self._contains_sparse_grad: + # For single node training with dense weight and sparse grad, + # we prefer update_on_kvstore=False because this is usually faster. + # This means we push and pull sparse gradients, and we do not store weight in kvstore. + # The training loop is the following: + # - forward() + # - backward() + # - push(grad) + # - pull(grad) + # - update(grad, weight) + # + # For multi-node training with dense weight and sparse grad, + # only update_on_kvstore=True is supported, due to the fact that + # kv.row_sparse_pull(grad) is not implemented. + # Therefore, we push sparse gradients and pull dense weights. + # The training loop contains: + # - forward() + # - backward() + # - push_and_update(grad) + # - pull(weight) + arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} + kvstore, _ = _create_kvstore(config['kvstore'], len(self._contexts), arg_arrays) + self._distributed = 'dist' in kvstore.type if kvstore else False + update_on_kvstore = self._distributed + # raise err if user provides unsupported configs + if config['update_on_kvstore'] is not None: + if config['update_on_kvstore'] is False and self._distributed: + raise ValueError("Cannot set update_on_kvstore=False on dist kvstore " + "when sparse gradients are present.") + update_on_kvstore = config['update_on_kvstore'] + + else: + # Training with dense weight and dense gradients. 
+ # The only unsupported mode is async with update_on_kvstore=False + arg_arrays = {param.name: param.data(self._contexts[0]) for param in self._params} + if self._local_sgd_interval > 1: + # local sgd + state_arrays = {param.name+'_state': param.data(self._contexts[0]) for param in self._params} + arg_arrays.update(state_arrays) + kvstore, update_on_kvstore = _create_kvstore(config['kvstore'], len(self._contexts), + arg_arrays) + self._distributed = 'dist' in kvstore.type if kvstore else False + if self._distributed and 'async' in kvstore.type: + update_on_kvstore = True + # raise err if user provides unsupported configs + if config['update_on_kvstore'] is False: + raise ValueError("Please set update_on_kvstore=True " + "when training in async mode.") + if config['update_on_kvstore'] is not None: + update_on_kvstore = config['update_on_kvstore'] + + # set grad compression and optimizers + if kvstore: + if self._compression_params: + kvstore.set_gradient_compression(self._compression_params) + if update_on_kvstore: + # optimizer preferably needs to be set before init for multiprecision + kvstore.set_optimizer(self._optimizer) + self._kvstore = kvstore + self._update_on_kvstore = update_on_kvstore + else: + self._kvstore = None + self._update_on_kvstore = None + + self._kv_initialized = True + + def reset_adam_counter(self, t): + print(self._updaters[1].optimizer._index_update_count) + print(t) + # for i, param in enumerate(self._params): + # if param.grad_req != 'null': + # for updater in self._updaters: + # updater.optimizer._index_update_count[i] = t + + def init_states(self): + """Initialize states (momentum for sgd_mon, or mean/var for adam) in the KVStore, for local sgd + """ + assert self._kv_initialized, "Cannot initialize states in KVStore " \ + "when KVStore is not initialized." + if self._kvstore and self._is_states_initialized == False: + for i, param in enumerate(self._params): + if param.grad_req != 'null': + if isinstance(self._updaters[0].states[i], (tuple, list)): + # for some optimizers, there are multiple states (mean, variance), such as Adam + # TODO(xcong) there might be some other side cases + for j in range(len(self._updaters[0].states[i])): + state_arrays = [updater.states[i][j] for updater in self._updaters] + self._kvstore.init(i+len(self._params)*(j+1), self._updaters[0].states[i][j]) + else: + state_arrays = [updater.states[i] for updater in self._updaters] + self._kvstore.init(i+len(self._params), self._updaters[0].states[i]) + self._is_states_initialized = True + + def step(self, batch_sizes, ignore_stale_grad=False): + """Makes one step of parameter update. Should be called after + `autograd.backward()` and outside of `record()` scope. + + For normal parameter updates, `step()` should be used, which internally calls + `allreduce_grads()` and then `update()`. However, if you need to get the reduced + gradients to perform certain transformation, such as in gradient clipping, then + you may want to manually call `allreduce_grads()` and `update()` separately. + + Parameters + ---------- + batch_sizes : [int] + Batch size of data processed. Gradient will be normalized by `1/batch_size`. + Set this to 1 if you normalized loss manually with `loss = mean(loss)`. + ignore_stale_grad : bool, optional, default=False + If true, ignores Parameters with stale gradient (gradient that has not + been updated by `backward` after last step) and skip update. 
+ """ + # rescale_grad = self._scale / batch_size + # self._check_and_rescale_grad(rescale_grad) + + # rescale the grads + for i, param in enumerate(self._params): + if param.grad_req != 'null': + for j in range(len(batch_sizes)): + param.list_grad()[j] /= batch_sizes[j] + + if not self._kv_initialized: + self._init_kvstore() + if self._params_to_init: + self._init_params() + + if self._local_sgd_interval == 1: + # if not local sgd + self._allreduce_grads() + + if self._local_sgd_interval > 1 and self._local_sgd_counter == 0 and self._local_sgd_regularization > 0: + # regularization for local sgd + self._local_sgd_regularization_params = [] + for i, param in enumerate(self._params): + if param.grad_req != 'null' and param._stype == 'default': + self._local_sgd_regularization_params.append([self._local_sgd_regularization * x.copy() for x in param.list_data()]) + else: + self._local_sgd_regularization_params.append([]) + + self._update(ignore_stale_grad) + + if self._local_sgd_interval > 1 and self._local_sgd_regularization > 0: + # regularization for local sgd + # TODO(xcong): use param.name instead of the indices + mixing_weight = (1 - self._local_sgd_regularization) + self._local_sgd_regularization_counter += 1 + if self._local_sgd_regularization_interval == 0 or self._local_sgd_regularization_interval == self._local_sgd_regularization_counter: + self._local_sgd_regularization_counter = 0 + for i, param in enumerate(self._params): + if param.grad_req != 'null' and param._stype == 'default': + for j, data in enumerate(param.list_data()): + data *= mixing_weight + data += self._local_sgd_regularization_params[i][j] + + if self._local_sgd_interval > 1: + # local sgd + self._local_sgd_counter += 1 + if self._local_sgd_counter == self._local_sgd_interval: + self._local_sgd_counter = 0 + # synchronization + self._allreduce_params() + if self._is_states_initialized: + self._allreduce_states() + # indicate that the parameters are synchronized in the current iteration + return True + return False + return True + + def allreduce_params(self): + """For each parameter, reduce the gradients from different contexts. + + Should be called after `autograd.backward()`, outside of `record()` scope, + and before `trainer.update()`. + + For normal parameter updates, `step()` should be used, which internally calls + `allreduce_grads()` and then `update()`. However, if you need to get the reduced + gradients to perform certain transformation, such as in gradient clipping, then + you may want to manually call `allreduce_grads()` and `update()` separately. + """ + if not self._kv_initialized: + self._init_kvstore() + if self._params_to_init: + self._init_params() + + self._allreduce_params() + + def _allreduce_params(self): + # print("_allreduce_params") + if self._kvstore: + for i, param in enumerate(self._params): + if param.grad_req != 'null': + self._kvstore.push(i, param.list_data(), priority=-i) + if param._stype == 'default': + self._kvstore.pull(i, param.list_data(), priority=-i) + # take average + # assume that every worker has the same number of gpus/contexts + num_workers = self._kvstore.num_workers * len(param.list_data()) + for data in param.list_data(): + data /= num_workers + else: + raise ValueError("Cannot pull row_sparse parameters for local SGD") + + def allreduce_states(self): + """For each parameter, reduce the gradients from different contexts. + + Should be called after `autograd.backward()`, outside of `record()` scope, + and before `trainer.update()`. 
+ + For normal parameter updates, `step()` should be used, which internally calls + `allreduce_grads()` and then `update()`. However, if you need to get the reduced + gradients to perform certain transformation, such as in gradient clipping, then + you may want to manually call `allreduce_grads()` and `update()` separately. + """ + if not self._kv_initialized: + self._init_kvstore() + + if not self._is_states_initialized: + raise ValueError("States are not initiallized") + self._allreduce_states() + + def _allreduce_states(self): + # print("_allreduce_states") + if self._kvstore: + # for i, param in enumerate(self._params): + for i, param in reversed(list(enumerate(self._params))): + if param.grad_req != 'null': + if isinstance(self._updaters[0].states[i], (tuple, list)): + # for some optimizers, there are multiple states (mean, variance), such as Adam + for j in range(len(self._updaters[0].states[i])): + state_arrays = [updater.states[i][j] for updater in self._updaters] + idx = i+len(self._params)*(j+1) + self._kvstore.push(idx, state_arrays, priority=i-len(self._params)*2) + if param._stype == 'default': + self._kvstore.pull(idx, state_arrays, priority=i-len(self._params)*2) + # take average + # assume that every worker has the same number of gpus/contexts + num_workers = float(self._kvstore.num_workers * len(state_arrays)) + for state in state_arrays: + state /= num_workers + + else: + raise ValueError("Cannot pull row_sparse parameters for local SGD") + else: + state_arrays = [updater.states[i] for updater in self._updaters] + idx = i+len(self._params) + self._kvstore.push(idx, state_arrays, priority=i-len(self._params)*2) + if param._stype == 'default': + self._kvstore.pull(idx, state_arrays, priority=i-len(self._params)*2) + # take average + # assume that every worker has the same number of gpus/contexts + num_workers = self._kvstore.num_workers * len(state_arrays) + for state in state_arrays: + state /= num_workers + else: + raise ValueError("Cannot pull row_sparse parameters for local SGD") + From eb5757f8a33cdf7e3833f6a4b586d4831a8bdc2d Mon Sep 17 00:00:00 2001 From: Xie Date: Wed, 7 Aug 2019 16:16:42 -0700 Subject: [PATCH 4/5] local adam for transformer --- scripts/machine_translation/train_transformer_local_sgd.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/machine_translation/train_transformer_local_sgd.py b/scripts/machine_translation/train_transformer_local_sgd.py index f8c449d4aa..30d21a120c 100644 --- a/scripts/machine_translation/train_transformer_local_sgd.py +++ b/scripts/machine_translation/train_transformer_local_sgd.py @@ -408,7 +408,7 @@ def train(): log_avg_loss = 0 log_wc = 0 if local_sgd > 1 and not is_sync: - # synchronous model parameters for local sgd + # synchronize model parameters before evaluation trainer.allreduce_params() trainer.allreduce_states() mx.nd.waitall() From ba36ab6719ee173961ac32b6722800a455497743 Mon Sep 17 00:00:00 2001 From: Xie Date: Wed, 7 Aug 2019 16:57:05 -0700 Subject: [PATCH 5/5] fix some typo --- scripts/machine_translation/train_transformer_local_sgd.py | 2 +- scripts/machine_translation/transformer_local_sgd.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/scripts/machine_translation/train_transformer_local_sgd.py b/scripts/machine_translation/train_transformer_local_sgd.py index 30d21a120c..9f600da64f 100644 --- a/scripts/machine_translation/train_transformer_local_sgd.py +++ b/scripts/machine_translation/train_transformer_local_sgd.py @@ -407,7 +407,7 @@ def train(): 
log_start_time = time.time() log_avg_loss = 0 log_wc = 0 - if local_sgd > 1 and not is_sync: + if local_sgd_interval > 1 and not is_sync: # synchronize model parameters before evaluation trainer.allreduce_params() trainer.allreduce_states() diff --git a/scripts/machine_translation/transformer_local_sgd.py b/scripts/machine_translation/transformer_local_sgd.py index 841854767c..1c4206ec5f 100644 --- a/scripts/machine_translation/transformer_local_sgd.py +++ b/scripts/machine_translation/transformer_local_sgd.py @@ -35,7 +35,7 @@ class LocalSGDTrainer(mx.gluon.Trainer): ---------- local_sgd_interval : int, default 1 If local_sgd_interval<=1, run fully synchronous SGD, - otherwise, sync params and states for every local_sgd steps. + otherwise, sync params and states for every local_sgd_interval steps. local_sgd_regularization : float, default 0 The weight of local regularization, within the range [0, 1) local_sgd_regularization_interval : int, default 0
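
Taken together, the series replaces the stock gluon.Trainer with a LocalSGDTrainer that lets each context take several optimizer steps on its own copy of the parameters and only periodically averages the parameters (and the Adam mean/variance states) through the KVStore. A minimal NumPy sketch of that control flow, with a toy quadratic objective standing in for per-worker data shards (worker count, interval, and learning rate are illustrative assumptions, not values from the patch):

import numpy as np

num_workers, local_sgd_interval, lr = 4, 8, 0.05
rng = np.random.default_rng(0)
params = [rng.normal(size=10) for _ in range(num_workers)]   # one replica per worker

def local_grad(w, worker_id):
    # each worker optimizes a slightly different objective, mimicking sharded data
    return 2.0 * (w - worker_id)

for step in range(1, 65):
    for k in range(num_workers):
        params[k] -= lr * local_grad(params[k], k)           # local update, no communication
    if step % local_sgd_interval == 0:                       # the role of allreduce_params()
        avg = np.mean(params, axis=0)
        params = [avg.copy() for _ in range(num_workers)]

print(np.round(np.mean(params, axis=0), 3))                  # replicas agree right after a sync

Fully synchronous training is the special case local_sgd_interval == 1, which is how the trainer falls back to the ordinary allreduce-per-step behaviour in its step() method.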