Skip to content
This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 44547c7

Browse files
authored
1 parent 34fae5c commit 44547c7

File tree

2 files changed

+105
-0
lines changed

2 files changed

+105
-0
lines changed

mms/README.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
Multi-model-server example
2+
==========================
3+
4+
https://github.com/awslabs/multi-model-server/
5+
6+
Assuming you are located in the root of the GluonNLP repo, you can run this
7+
example via:
8+
9+
::

    pip install --user multi-model-server
    curl https://dist-bert.s3.amazonaws.com/demo/finetune/sst.params -o mms/sst.params
    ~/.local/bin/model-archiver --model-name bert_sst --model-path mms --handler bert:handle --runtime python --export-path /tmp
    ~/.local/bin/multi-model-server --start --models bert_sst.mar --model-store /tmp
    curl -X POST http://127.0.0.1:8080/bert_sst/predict -F 'data=["Positive sentiment", "Negative sentiment"]'
16+
17+

mms/bert.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
import json
18+
import logging
19+
20+
import mxnet as mx
21+
import gluonnlp as nlp
22+
23+
24+
class BertHandler:
    """GluonNLP based BERT handler for multi-model-server.

    Lifecycle (per the multi-model-server custom-service contract):
    ``initialize`` is called once at model-load time to build the network and
    preprocessing pipeline; ``handle`` is then called per request with a
    JSON-encoded list of sentences and returns predicted class indices.
    """

    def __init__(self):
        self.error = None
        self._context = None
        self.initialized = False

    def initialize(self, context):
        """Build the BERT classifier and preprocessing pipeline.

        Called once by the model server during model loading.

        :param context: Initial context containing model server system
            properties; ``gpu_id`` selects the MXNet device (CPU when None).
        :return: None
        """
        self._context = context
        gpu_id = context.system_properties["gpu_id"]
        self._mx_ctx = mx.cpu() if gpu_id is None else mx.gpu(gpu_id)
        # Backbone only (pooler on, decoder/classifier heads off):
        # pretrained=False because the fine-tuned weights are loaded below.
        bert, vocab = nlp.model.get_model('bert_12_768_12',
                                          dataset_name='book_corpus_wiki_en_uncased',
                                          pretrained=False, ctx=self._mx_ctx, use_pooler=True,
                                          use_decoder=False, use_classifier=False)
        tokenizer = nlp.data.BERTTokenizer(vocab, lower=True)
        # Single-sentence transform: tokenize, clip/pad to 128 tokens.
        self.sentence_transform = nlp.data.BERTSentenceTransform(tokenizer, max_seq_length=128,
                                                                 vocab=vocab, pad=True, pair=False)
        self.batchify = nlp.data.batchify.Tuple(
            nlp.data.batchify.Pad(axis=0, pad_val=vocab[vocab.padding_token]),  # input
            nlp.data.batchify.Stack(),  # length
            nlp.data.batchify.Pad(axis=0, pad_val=0))  # segment
        # Set dropout to non-zero, to match pretrained model parameter names
        self.net = nlp.model.BERTClassifier(bert, dropout=0.1)
        # 'sst.params' is resolved relative to the extracted model archive
        # (packaged by model-archiver) — presumably the server chdirs there;
        # TODO(review): confirm working-directory assumption.
        self.net.load_parameters('sst.params', self._mx_ctx)
        self.net.hybridize()

        self.initialized = True

    def handle(self, batch, context):
        """Run sentiment inference on a JSON-encoded list of sentences.

        :param batch: Model-server request batch; must contain exactly one
            item whose ``data`` field is UTF-8 JSON encoding a list of
            strings.
        :param context: Model server context (unused here).
        :return: Single-element list holding the predicted class index for
            each input sentence.
        :raises json.JSONDecodeError, KeyError, AssertionError: when the
            request payload does not match the expected shape.
        """
        # we're just faking batch_size==1 but allow dynamic batch size. Ie the
        # actual batch size is the len of the first element.
        try:
            assert len(batch) == 1
            batch = json.loads(batch[0]["data"].decode('utf-8'))
        except (json.JSONDecodeError, KeyError, AssertionError):
            # Fix: emit the usage hint through logging (the module imported
            # logging but wrote to stdout with print, which the model
            # server's log capture drops), and re-raise with a bare `raise`
            # to preserve the original traceback instead of `raise e`.
            logging.getLogger(__name__).error(
                'call like: curl -X POST http://127.0.0.1:8080/bert_sst/predict '
                '-F \'data=["sentence 1", "sentence 2"]\'')
            raise
        model_input = self.batchify([self.sentence_transform(sentence) for sentence in batch])

        inputs, valid_length, token_types = [arr.as_in_context(self._mx_ctx) for arr in model_input]
        inference_output = self.net(inputs, token_types, valid_length.astype('float32'))
        inference_output = inference_output.as_in_context(mx.cpu())

        # argmax over the class axis -> per-sentence predicted label ids.
        return [mx.nd.softmax(inference_output).argmax(axis=1).astype('int').asnumpy().tolist()]
76+
77+
78+
# Module-level singleton required by the model-archiver `--handler
# bert:handle` contract: the server imports this module and invokes
# ``handle`` for every request.
_service = BertHandler()


def handle(data, context):
    """Entry point invoked by multi-model-server.

    Lazily initializes the shared ``BertHandler`` on the first call, then
    delegates inference to it. A ``None`` payload (server ping) yields
    ``None``.
    """
    if not _service.initialized:
        _service.initialize(context)
    return None if data is None else _service.handle(data, context)

0 commit comments

Comments
 (0)