diff --git a/docs/examples/machine_translation/CustomNMTModel.py b/docs/examples/machine_translation/CustomNMTModel.py new file mode 100644 index 0000000000..daf115d3aa --- /dev/null +++ b/docs/examples/machine_translation/CustomNMTModel.py @@ -0,0 +1,162 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""class ONNXNMTModel, used in transformer_onnx_based.md""" + +import numpy as np +import mxnet as mx +import onnxruntime + +class ONNXNMTModel: + """This class mimics the actual NMTModel class defined here: + https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/translation.py#L28 + """ + class ONNXRuntimeSession: + """This class is used to wrap the onnxruntime sessions of the components in the + transforman model, namely: src_embed, encoder, tgt_embed, one_step_ahead_decoder, + and tgt_proj. + """ + def __init__(self, onnx_file): + """Init the onnxruntime session. Performace tuning code can be added here. + Parameters + ---------- + onnx_file : str + """ + ses_opt = onnxruntime.SessionOptions() + ses_opt.log_severity_level = 3 + self.session = onnxruntime.InferenceSession(onnx_file, ses_opt) + + def __call__(self, *onnx_inputs): + """Notice that the inputs here are MXNet NDArrays. We first convert them to numpy + ndarrays, run inference, and then convert the outputs back to MXNet NDArrays. + Parameters + ---------- + onnx_inputs: list of NDArrays + Returns + ------- + list of NDArrays + """ + input_dict = dict((self.session.get_inputs()[i].name, onnx_inputs[i].asnumpy()) + for i in range(len(onnx_inputs))) + outputs = self.session.run(None, input_dict) + if len(outputs) == 1: + return mx.nd.array(outputs[0]) + return [mx.nd.array(output) for output in outputs] + + class DummyDecoder: + """This Dummy Decoder mimics the actualy decoder defined here: + https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/transformer.py#L724 + For inference we only need to define init_state_from_encoder() + """ + def init_state_from_encoder(self, encoder_outputs, encoder_valid_length=None): + """Initialize the state from the encoder outputs. Refer to the original function here: + https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/transformer.py#L621 + Parameters + ---------- + encoder_outputs : list + encoder_valid_length : NDArray or None + Returns + ------- + decoder_states : list + The decoder states, includes: + - mem_value : NDArray + - mem_masks : NDArray or None + """ + mem_value = encoder_outputs + decoder_states = [mem_value] + mem_length = mem_value.shape[1] + if encoder_valid_length is not None: + dtype = encoder_valid_length.dtype + ctx = encoder_valid_length.context + mem_masks = mx.nd.broadcast_lesser( + mx.nd.arange(mem_length, ctx=ctx, dtype=dtype).reshape((1, -1)), + encoder_valid_length.reshape((-1, 1))) + decoder_states.append(mem_masks) + else: + decoder_states.append(None) + return decoder_states + + def __init__(self, tgt_vocab, src_embed_onnx_file, encoder_onnx_file, tgt_embed_onnx_file, + one_step_ahead_decoder_onnx_file, tgt_proj_onnx_file): + """Init the ONNXNMTModel. For inference we need the following components of the original + transformer model: src_embed, encoder, tgt_embed, one_step_ahead_decoder, and tgt_proj. + Parameters + ---------- + tgt_vocab : Vocab + Target vocabulary. + src_embed_onnx_file: str + encoder_onnx_file: str + tgt_embed_onnx_file: str + one_step_ahead_decoder_onnx_file: str + tgt_proj_onnx_file: str + """ + self.tgt_vocab = tgt_vocab + self.src_embed = self.ONNXRuntimeSession(src_embed_onnx_file) + self.encoder = self.ONNXRuntimeSession(encoder_onnx_file) + self.tgt_embed = self.ONNXRuntimeSession(tgt_embed_onnx_file) + self.one_step_ahead_decoder = self.ONNXRuntimeSession(one_step_ahead_decoder_onnx_file) + self.tgt_proj = self.ONNXRuntimeSession(tgt_proj_onnx_file) + self.decoder = self.DummyDecoder() + + def encode(self, inputs, states=None, valid_length=None): + """Encode the input sequence. Refer to the original function here: + https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/translation.py#L132 + Parameters + ---------- + inputs : NDArray + states : list of NDArrays or None, default None + valid_length : NDArray or None, default None + Returns + ------- + outputs : list + Outputs of the encoder. + """ + return self.encoder(self.src_embed(inputs), valid_length), None + + def decode_step(self, step_input, decoder_states): + """One step decoding of the translation model. Refer to the original function here: + https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/translation.py#L171 + Parameters + ---------- + step_input : NDArray + Shape (batch_size,) + states : list of NDArrays + Returns + ------- + step_output : NDArray + Shape (batch_size, C_out) + states : list + step_additional_outputs : list + Additional outputs of the step, e.g, the attention weights + """ + step_input = self.tgt_embed(step_input) + + # Refer to https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/transformer.py#L819 + if len(decoder_states) == 3: # step_input from prior call is included + last_embeds, _, _ = decoder_states + inputs = mx.nd.concat(last_embeds, mx.nd.expand_dims(step_input, axis=1), dim=1) + decoder_states = decoder_states[1:] + else: + inputs = mx.nd.expand_dims(step_input, axis=1) + + # Refer to https://github.com/dmlc/gluon-nlp/blob/v0.10.0/src/gluonnlp/model/transformer.py#L834 + step_output = self.one_step_ahead_decoder(decoder_states[1], inputs, decoder_states[0]) + decoder_states = [inputs] + decoder_states + step_additional_outputs = None + + step_output = self.tgt_proj(step_output) + + return step_output, decoder_states, step_additional_outputs \ No newline at end of file diff --git a/docs/examples/machine_translation/transformer_onnx_based.md b/docs/examples/machine_translation/transformer_onnx_based.md new file mode 100644 index 0000000000..89edd0ae8f --- /dev/null +++ b/docs/examples/machine_translation/transformer_onnx_based.md @@ -0,0 +1,238 @@ +# Running MXNet-trained Transformer with ONNXRuntime + +In [Using Pre-trained Transformer](https://nlp.gluon.ai/examples/machine_translation/transformer.html) we have seen how to run a pretrained MXNet transformer model for end-2-end machine translation. In this blog, we are going to export the transformer model to the ONNX format, run inference with ONNXRuntime, and achieve the same end-to-end translation as before. + +## Setup + +```{.python .input} +import warnings +warnings.filterwarnings('ignore') + +import numpy as np +import mxnet as mx +import gluonnlp as nlp +# make sure gluonnlp version is >= 0.7.0 +nlp.utils.check_version('0.7.0') + +# use cpu context to load the model +ctx = mx.cpu(0) +print('ctx: ', ctx) +``` + +## Load the Pre-trained Transformer + +```{.python .input} +# load the model +wmt_model_name = 'transformer_en_de_512' +wmt_transformer_model, wmt_src_vocab, wmt_tgt_vocab = \ + nlp.model.get_model(wmt_model_name, + dataset_name='WMT2014', + pretrained=True, + ctx=ctx) + +# we are using mixed vocab of EN-DE, so the source and target language vocab are the same +print('EN size: ', len(wmt_src_vocab), '\nDE size: ', len(wmt_tgt_vocab)) +``` + +## Save the Components of the Transformer as MXNet Models + +Note that the Transformer, which is an instance of class `NMTModel`, is not a monolith, but a collection of several smaller components such as `src_embed`, `encoder`, `tgt_embed`, `one_step_ahead_decoder`, and `tgt_proj`. Those components are by themselves MXNet hybrid models. In `NMTModel`, there are high-level member functions such as `encode` and `decode_step`, and they will in turn call the finer components in combination. For example, `encode` internally uses both `src_embed` and `encoder`. In the cell below we are going to create some dummy input data and call `encode` and`decode_step`. This is to make sure that we run foward path on all the hybridized components at least once. Then, we can [save the architecture and the parameters](https://mxnet.apache.org/versions/1.8.0/api/python/docs/tutorials/packages/gluon/blocks/save_load_params.html) of those components. + +```{.python .input} +# save the necessary components of transformer +import os + +wmt_transformer_model.hybridize(static_alloc=False) + +# the transformer model consists of several components, among which we need the following to run inference +print(type(wmt_transformer_model.src_embed)) +print(type(wmt_transformer_model.encoder)) +print(type(wmt_transformer_model.tgt_embed)) +print(type(wmt_transformer_model.one_step_ahead_decoder)) +print(type(wmt_transformer_model.tgt_proj)) + +# define some dummy data +batch = 1 +seq_length = 16 +C_in = 512 +C_out = 512 +src = mx.nd.random.uniform(0, 36794, shape=(batch, seq_length), dtype='float32') +step_input = mx.nd.random.uniform(0, 36794, shape=(batch,), dtype='float32') +src_valid_length = mx.nd.array([seq_length] * batch, dtype='float32') + +# run forward once with the following functions so taht we can export the components +## encode() internally calls src_embed and encoder +encoder_outputs, _ = wmt_transformer_model.encode(src, valid_length=src_valid_length) +## init_state_from_encoder() helps prepare decoder_states +decoder_states = wmt_transformer_model.decoder.init_state_from_encoder(encoder_outputs, + src_valid_length) +## decode_step() internally calls tgt_embed, one_step_ahead_decoder, tgt_proj +_, _, _ = wmt_transformer_model.decode_step(step_input, decoder_states) + +# export the components +base_path = './components' +if not os.path.exists(base_path): + os.makedirs(base_path) +for component in ['src_embed', 'encoder', 'tgt_embed', 'one_step_ahead_decoder', 'tgt_proj']: + prefix = "%s/%s" %(base_path, component) + component = getattr(wmt_transformer_model, component) + component.export(prefix) + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + +print('Files under ./components \n', os.listdir(base_path)) +``` + +## Export the MXNet Transformer Components to the ONNX Format + +Now that we have saved the necessary components of the transformer model, we are going to export each of them to the ONNX format. Here, notice that we are using the dynamic input feature of mx2onnx as the input batch and sequence lenghth can vary depending on different input sentences or beam search widths. + +Please note that MXNet version 1.9 or above is required for this step. + +```{.python .input} +# export the transformer components to ONNX models + +from mxnet import onnx as mx2onnx + +def export_to_onnx(prefix, input_shapes, input_types, **kwargs): + sym_file = "%s-symbol.json" % prefix + params_file = "%s-0000.params" % prefix + onnx_file = "%s.onnx" % prefix + return mx2onnx.export_model(sym_file, params_file, input_shapes, input_types, + onnx_file, **kwargs) + +# export src_embed +prefix = "%s/src_embed" %base_path +input_shapes = [(batch, seq_length)] +dynamic_input_shapes = [(batch, 'seq_length')] +input_types = [np.float32] +onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True, + dynamic_input_shapes=dynamic_input_shapes) + +# export encoder +prefix = "%s/encoder" %base_path +input_shapes = [(batch, seq_length, C_in), (batch,)] +dynamic_input_shapes = [(batch, 'seq_length', C_in), (batch,)] +input_types = [np.float32, np.float32] +onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True, + dynamic_input_shapes=dynamic_input_shapes) + +# export tgt_embed +prefix = "%s/tgt_embed" %base_path +input_shapes = [(batch,)] +dynamic_input_shapes = [('batch',)] +input_types = [np.int32] +onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True, + dynamic_input_shapes=dynamic_input_shapes) + +# export one_step_ahead_decoder +prefix = "%s/one_step_ahead_decoder" %base_path +# mem_masks, decoder_inputs, mem_value +input_shapes = [(batch, seq_length), (batch, 1, C_in), (batch, seq_length, C_out)] +dynamic_input_shapes = [('batch', 'seq_length'), ('batch', 'cur_step_seq_length', C_in), + ('batch', 'seq_length', C_out)] +input_types = [np.float32, np.float32, np.float32] +onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True, + dynamic_input_shapes=dynamic_input_shapes) + +# export tgt_proj +prefix = "%s/tgt_proj" %base_path +input_shapes = [(batch, C_out)] +dynamic_input_shapes = [('batch', C_out)] +input_types = [np.float32] +onnx_file = export_to_onnx(prefix, input_shapes, input_types, dynamic=True, + dynamic_input_shapes=dynamic_input_shapes) + +print('Files under ./components \n', os.listdir(base_path)) +``` + +## Beam Search Hyper-Parameters + +Now that we have the exported transformer components (in the ONNX format), we can run them with any runtime framework that supports ONNX models and use a custom beam search implementation. However, in this tutorial we are going to stick to the same beam seach as in the origina GluonNLP [transformer tutorial](https://nlp.gluon.ai/examples/machine_translation/transformer.html). Let's review the hyper parameters before we proceed. + +```{.python .input} +import hyperparameters as hparams +# check the hyper-parameters +print('beam_size:'.ljust(12), hparams.beam_size) +print('lp_alpha:'.ljust(12), hparams.lp_alpha) +print('lp_k:'.ljust(12), hparams.lp_k) +``` + +## Define a Translator with a Custom NMTModel + +in the original [transformer tutorial](https://nlp.gluon.ai/examples/machine_translation/transformer.html) we use a `BeamSearchTranslator` to run end-to-end machine translation task. `BeamSearchTranslator` would take in a `NMTModel` which our pre-trained transformer model is an instance of, and use it to make predictions word by word. + +```python +# in class BeamSearchTranslator +class BeamSearchTranslator: + ...... + def _decode_logprob(self, step_input, states): + out, states, _ = self._model.decode_step(step_input, states) + return mx.nd.log_softmax(out), states + + def translate(self, src_seq, src_valid_length): + batch_size = src_seq.shape[0] + encoder_outputs, _ = self._model.encode(src_seq, valid_length=src_valid_length) + decoder_states = self._model.decoder.init_state_from_encoder(encoder_outputs, + src_valid_length) + inputs = mx.nd.full(shape=(batch_size,), ctx=src_seq.context, dtype=np.float32, + val=self._model.tgt_vocab.token_to_idx[ + self._model.tgt_vocab.bos_token]) + samples, scores, sample_valid_length = self._sampler(inputs, decoder_states) + return samples, scores, sample_valid_length +``` + +Here we can see that a `BeamSearchTranslator` will make calls to `encode`, `decode_step`, `decoder.init_state_from_encoder`, `tgt_vocab.token_to_idx`, which are functions or objects defined in `NMTModel`. This means, if we can define a customized `NMTModel` class, say `ONNXNMTModel`, and define those same interfaces, then it would be compatible with `BeamSearchTranslator`. Whether within this customized `ONNXNMTModel` we call the original MXNet model or use the exported ONNX models with ONNXRuntime, it would not matter from the `BeamSearchTranslator`'s perspecrive. If you are intrested in this customized class, you can refer to `CustomNMTModel.py` for the full implementation. + +```{.python .input} +import nmt +import utils +import CustomNMTModel + +# detokenizer +wmt_detokenizer = nlp.data.SacreMosesDetokenizer() + +# create a custom ONNXNMTModel +onnxnmtmodel = CustomNMTModel.ONNXNMTModel(wmt_transformer_model.tgt_vocab, + 'components/src_embed.onnx', + 'components/encoder.onnx', + 'components/tgt_embed.onnx', + 'components/one_step_ahead_decoder.onnx', + 'components/tgt_proj.onnx') + +# define beam search translator +onnx_wmt_translator = nmt.translation.BeamSearchTranslator( + # note here that we are using an ONNXNMTModel object to replace the actual + # transformer model which is an NMTModel object + model=onnxnmtmodel, # wmt_transformer_model, + beam_size=hparams.beam_size, + scorer=nlp.model.BeamSearchScorer(alpha=hparams.lp_alpha, K=hparams.lp_k), + max_length=200) + + +print(type(wmt_transformer_model)) +print(type(onnxnmtmodel)) +``` + +## Machine Translation powered by MXNet+GluonNLP... Plus MX2ONNX+ONNXRuntime + +Now that we have a `BeamSearchTranslator` defined, we can run end-to-end inference on it! + +```{.python .input} +# input English sentence +print('Translate the following English sentence into German:') +sample_src_seq = 'I am a software engineer and I love to play football .' +print('[\'' + sample_src_seq + '\']') + +# run end2end inference using onnxruntime + gluonnlp beam search +sample_tgt_seq = utils.translate(onnx_wmt_translator, + sample_src_seq, + wmt_src_vocab, + wmt_tgt_vocab, + wmt_detokenizer, + ctx) + +# output German sentence +print('The German translation is:') +print(sample_tgt_seq) +```