|
| 1 | +# Licensed to the Apache Software Foundation (ASF) under one |
| 2 | +# or more contributor license agreements. See the NOTICE file |
| 3 | +# distributed with this work for additional information |
| 4 | +# regarding copyright ownership. The ASF licenses this file |
| 5 | +# to you under the Apache License, Version 2.0 (the |
| 6 | +# "License"); you may not use this file except in compliance |
| 7 | +# with the License. You may obtain a copy of the License at |
| 8 | +# |
| 9 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | +# |
| 11 | +# Unless required by applicable law or agreed to in writing, |
| 12 | +# software distributed under the License is distributed on an |
| 13 | +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| 14 | +# KIND, either express or implied. See the License for the |
| 15 | +# specific language governing permissions and limitations |
| 16 | +# under the License. |
| 17 | +import json |
| 18 | +import logging |
| 19 | + |
| 20 | +import mxnet as mx |
| 21 | +import gluonnlp as nlp |
| 22 | + |
| 23 | + |
class BertHandler:
    """GluonNLP-based BERT handler for MXNet Model Server.

    Builds a BERT-base classifier architecture, loads fine-tuned weights
    from 'sst.params', and serves requests whose 'data' field is a
    JSON-encoded list of sentences.
    """

    def __init__(self):
        self.error = None
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        self._context = context
        gpu_id = context.system_properties["gpu_id"]
        self._mx_ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id)
        # pretrained=False: only the architecture is built here; the
        # fine-tuned weights are loaded below from 'sst.params'.
        bert, vocab = nlp.model.get_model('bert_12_768_12',
                                          dataset_name='book_corpus_wiki_en_uncased',
                                          pretrained=False, ctx=self._mx_ctx, use_pooler=True,
                                          use_decoder=False, use_classifier=False)
        tokenizer = nlp.data.BERTTokenizer(vocab, lower=True)
        self.sentence_transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=128,
                                                                 vocab=vocab, pad=True, pair=False)
        self.batchify = nlp.data.batchify.Tuple(
            nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),  # input
            nlp.data.batchify.Stack(),  # length
            nlp.data.batchify.Pad(axis=0, pad_val=0))  # segment
        # Set dropout to non-zero, to match pretrained model parameter names
        self.net = nlp.model.BERTClassifier(bert, dropout=0.1)
        self.net.load_parameters('sst.params', self._mx_ctx)
        self.net.hybridize()

        self.initialized = True

    def handle(self, batch, context):
        """Run inference for one model-server request.

        :param batch: model-server batch; must hold exactly one request
            whose 'data' field is UTF-8 JSON encoding a list of sentences.
        :param context: model server context (unused here).
        :return: single-element list with the predicted class index per
            sentence.
        :raises ValueError: if the server-level batch size is not 1.
        """
        # We're just faking batch_size==1 but allow dynamic batch size; i.e.
        # the actual batch size is the length of the decoded sentence list.
        # NOTE: was `assert len(batch) == 1`, which vanishes under `python -O`;
        # validate explicitly instead.
        if len(batch) != 1:
            raise ValueError('expected a server batch of exactly 1 request, got %d' % len(batch))
        try:
            batch = json.loads(batch[0]["data"].decode('utf-8'))
        except (json.JSONDecodeError, KeyError) as e:
            # Use the module's logging setup instead of print so the usage
            # hint lands in the server log.
            logging.getLogger(__name__).error(
                'call like: curl -X POST http://127.0.0.1:8080/bert_sst/predict '
                '-F \'data=["sentence 1", "sentence 2"]\'')
            raise e
        model_input = self.batchify([self.sentence_transform(sentence) for sentence in batch])

        inputs, valid_length, token_types = [arr.as_in_context(self._mx_ctx) for arr in model_input]
        inference_output = self.net(inputs, token_types, valid_length.astype('float32'))
        # Copy results back to CPU before converting to Python lists.
        inference_output = inference_output.as_in_context(mx.cpu())

        return [mx.nd.softmax(inference_output).argmax(axis=1).astype('int').asnumpy().tolist()]
| 76 | + |
| 77 | + |
# Module-level singleton: one handler instance shared by all requests;
# heavy model loading is deferred to initialize() on the first call.
_service = BertHandler()
| 79 | + |
| 80 | + |
def handle(data, context):
    """Model-server entry point.

    Lazily initializes the shared BertHandler on first use, answers the
    warm-up ping (data is None) with None, and otherwise delegates
    inference to the singleton.
    """
    if not _service.initialized:
        _service.initialize(context)
    return None if data is None else _service.handle(data, context)
0 commit comments