Commit 8281254

add huggingface tutorial (#1508)
Signed-off-by: Guenther Schmuelling <[email protected]>
1 parent 122556e commit 8281254

2 files changed, with 218 additions and 7 deletions.

tutorials/README.md

Lines changed: 10 additions & 7 deletions
@@ -4,18 +4,21 @@
44

55
The following tutorials show how to convert various models to ONNX.
66

7-
[The original Bert model](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/BertTutorial.ipynb)<br/>
8-
By now this model is a bit old and it is much easier to use huggingface. To see how to convert the huggingface tensorflow models see [huggingface.py](https://github.com/onnx/tensorflow-onnx/blob/master/tests/huggingface.py)
7+
## Image Classifiers
8+
[efficientnet-edge](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/efficientnet-edge.ipynb)
9+
10+
[efficientnet-lite](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/efficientnet-lite.ipynb)
911

12+
[keras-resnet50](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/keras-resnet50.ipynb) - shows how to convert a keras model via python api
1013

14+
## Object Detectors
1115
[ssd-mobilenet](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/ConvertingSSDMobilenetToONNX.ipynb)
1216

1317
[efficientdet](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/efficientdet.ipynb)
1418

15-
[efficientnet-edge](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/efficientnet-edge.ipynb)
16-
17-
[efficientnet-lite](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/efficientnet-lite.ipynb)
19+
[mobiledet](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/mobiledet-tflite.ipynb) - shows how to convert a tflite model
1820

19-
[keras-resnet50](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/keras-resnet50.ipynb), shows how to convert a keras model via python api
21+
## NLP
22+
[Huggingface Bert Example](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/huggingface-bert.ipynb)
2023

21-
[mobiledet](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/mobiledet-tflite.ipynb), shows how to convert a tflite model
24+
[The original Tensorflow Bert model](https://github.com/onnx/tensorflow-onnx/blob/master/tutorials/BertTutorial.ipynb) - deprecated, use the Huggingface Bert example instead

tutorials/huggingface-bert.ipynb

Lines changed: 208 additions & 0 deletions
@@ -0,0 +1,208 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"# Converting a Huggingface model to ONNX with tf2onnx\n",
8+
"\n",
9+
"This is a simple example of how to convert a [huggingface](https://huggingface.co/) model to ONNX using [tf2onnx](https://github.com/onnx/tensorflow-onnx).\n",
10+
"\n",
11+
"We use the [TFBertForQuestionAnswering](https://huggingface.co/transformers/model_doc/bert.html#tfbertforquestionanswering) example from huggingface.\n",
12+
"\n",
13+
"Other models work similarly. Additional examples for other models are in our unit tests [here](https://github.com/onnx/tensorflow-onnx/blob/master/tests/huggingface.py)."
14+
]
15+
},
16+
{
17+
"cell_type": "markdown",
18+
"metadata": {},
19+
"source": [
20+
"## Install dependencies"
21+
]
22+
},
23+
{
24+
"cell_type": "code",
25+
"execution_count": null,
26+
"metadata": {},
27+
"outputs": [],
28+
"source": [
29+
"!pip install tensorflow transformers tf2onnx onnxruntime"
30+
]
31+
},
32+
{
33+
"cell_type": "markdown",
34+
"metadata": {},
35+
"source": [
36+
"## The keras code"
37+
]
38+
},
39+
{
40+
"cell_type": "code",
41+
"execution_count": 1,
42+
"metadata": {},
43+
"outputs": [],
44+
"source": [
45+
"import os\n",
46+
"os.environ['CUDA_VISIBLE_DEVICES'] = \"\"\n",
47+
"\n",
48+
"import warnings\n",
49+
"warnings.filterwarnings('ignore')\n",
50+
"\n",
51+
"import numpy as np\n",
52+
"import onnxruntime as rt\n",
53+
"import tensorflow as tf\n",
54+
"import tf2onnx"
55+
]
56+
},
57+
{
58+
"cell_type": "code",
59+
"execution_count": 2,
60+
"metadata": {},
61+
"outputs": [
62+
{
63+
"name": "stderr",
64+
"output_type": "stream",
65+
"text": []
66+
},
67+
{
68+
"data": {
69+
"text/plain": [
70+
"TFQuestionAnsweringModelOutput(loss=None, start_logits=<tf.Tensor: shape=(1, 16), dtype=float32, numpy=\n",
71+
"array([[ 0.27443457, 0.02250022, -0.32903647, -0.32448006, -0.26440915,\n",
72+
" -0.03356116, -0.11466929, -0.12272861, -0.23254037, -0.21369037,\n",
73+
" 0.02170385, -0.38734213, -0.14865303, -0.04804918, 0.02706608,\n",
74+
" -0.12273058]], dtype=float32)>, end_logits=<tf.Tensor: shape=(1, 16), dtype=float32, numpy=\n",
75+
"array([[-0.23549399, 0.11830041, -0.16875415, 0.04315909, 0.00721513,\n",
76+
" 0.20957005, 0.00850991, -0.49158442, 0.10791501, 0.07153591,\n",
77+
" 0.26274043, -0.15160318, -0.01847767, 0.03389414, 0.25666913,\n",
78+
" -0.49158433]], dtype=float32)>, hidden_states=None, attentions=None)"
79+
]
80+
},
81+
"execution_count": 2,
82+
"metadata": {},
83+
"output_type": "execute_result"
84+
}
85+
],
86+
"source": [
87+
"from transformers import BertTokenizer, TFBertForQuestionAnswering\n",
88+
"tokenizer = BertTokenizer.from_pretrained('bert-base-cased')\n",
89+
"model = TFBertForQuestionAnswering.from_pretrained('bert-base-cased')\n",
90+
"question, text = \"Who was Jim Henson?\", \"Jim Henson was a nice puppet\"\n",
91+
"input_dict = tokenizer(question, text, return_tensors='tf')\n",
92+
"tf_results = model(input_dict)\n",
93+
"tf_results"
94+
]
95+
},
96+
{
97+
"cell_type": "markdown",
98+
"metadata": {},
99+
"source": [
100+
"## Convert to ONNX"
101+
]
102+
},
103+
{
104+
"cell_type": "code",
105+
"execution_count": 3,
106+
"metadata": {},
107+
"outputs": [
108+
{
109+
"name": "stderr",
110+
"output_type": "stream",
111+
"text": []
112+
}
113+
],
114+
"source": [
115+
"# describe the inputs\n",
116+
"input_spec = (\n",
117+
" tf.TensorSpec((None, None), tf.int32, name=\"input_ids\"),\n",
118+
" tf.TensorSpec((None, None), tf.int32, name=\"token_type_ids\"),\n",
119+
" tf.TensorSpec((None, None), tf.int32, name=\"attention_mask\")\n",
120+
")\n",
121+
"\n",
122+
"# and convert\n",
123+
"_, _ = tf2onnx.convert.from_keras(model, input_signature=input_spec, opset=13, output_path=\"bert.onnx\")"
124+
]
125+
},
126+
{
127+
"cell_type": "markdown",
128+
"metadata": {},
129+
"source": [
130+
"## Test the ONNX model with onnxruntime"
131+
]
132+
},
133+
{
134+
"cell_type": "code",
135+
"execution_count": 4,
136+
"metadata": {},
137+
"outputs": [
138+
{
139+
"data": {
140+
"text/plain": [
141+
"[array([[ 0.27443478, 0.02250013, -0.32903633, -0.32448038, -0.26440892,\n",
142+
" -0.03356095, -0.11466938, -0.12272887, -0.2325401 , -0.21369015,\n",
143+
" 0.02170385, -0.3873423 , -0.148653 , -0.04804894, 0.02706566,\n",
144+
" -0.1227307 ]], dtype=float32),\n",
145+
" array([[-0.23549382, 0.11830062, -0.16875397, 0.0431588 , 0.00721494,\n",
146+
" 0.2095699 , 0.00850987, -0.49158436, 0.10791501, 0.07153573,\n",
147+
" 0.26274025, -0.15160298, -0.01847767, 0.03389416, 0.25666922,\n",
148+
" -0.49158415]], dtype=float32)]"
149+
]
150+
},
151+
"execution_count": 4,
152+
"metadata": {},
153+
"output_type": "execute_result"
154+
}
155+
],
156+
"source": [
157+
"# get the names we want as output\n",
158+
"output_names = list(tf_results.keys())\n",
159+
"\n",
160+
"# switch the input_dict to numpy\n",
161+
"input_dict_np = {k: v.numpy() for k, v in input_dict.items()}\n",
162+
"\n",
163+
"opt = rt.SessionOptions()\n",
164+
"sess = rt.InferenceSession(\"bert.onnx\", sess_options=opt)\n",
165+
"onnx_results = sess.run(output_names, input_dict_np)\n",
166+
"onnx_results"
167+
]
168+
},
169+
{
170+
"cell_type": "markdown",
171+
"metadata": {},
172+
"source": [
173+
"## Make sure tensorflow and onnxruntime results are the same"
174+
]
175+
},
176+
{
177+
"cell_type": "code",
178+
"execution_count": 5,
179+
"metadata": {},
180+
"outputs": [],
181+
"source": [
182+
"for i, name in enumerate(output_names):\n",
183+
" np.testing.assert_allclose(tf_results[name], onnx_results[i], rtol=1e-5, atol=1e-5)"
184+
]
185+
}
186+
],
187+
"metadata": {
188+
"kernelspec": {
189+
"display_name": "Python [conda env:root] *",
190+
"language": "python",
191+
"name": "conda-root-py"
192+
},
193+
"language_info": {
194+
"codemirror_mode": {
195+
"name": "ipython",
196+
"version": 3
197+
},
198+
"file_extension": ".py",
199+
"mimetype": "text/x-python",
200+
"name": "python",
201+
"nbconvert_exporter": "python",
202+
"pygments_lexer": "ipython3",
203+
"version": "3.7.3"
204+
}
205+
},
206+
"nbformat": 4,
207+
"nbformat_minor": 2
208+
}
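
The last notebook cell compares the TensorFlow and onnxruntime outputs with `np.testing.assert_allclose(..., rtol=1e-5, atol=1e-5)`. The sketch below restates that tolerance rule in plain Python so it is easier to see why outputs that differ in the seventh decimal place still count as equal. The helper name `within_tolerance` is hypothetical and not part of tf2onnx, numpy, or onnxruntime. It assumes flat lists of floats and uses the same element-wise rule numpy documents: `|actual - desired| <= atol + rtol * |desired|`.

```python
def within_tolerance(actual, desired, rtol=1e-5, atol=1e-5):
    """Return True if every pair satisfies |a - d| <= atol + rtol * |d|.

    Hypothetical helper for illustration only; it mirrors the element-wise
    rule used by numpy.testing.assert_allclose for flat sequences of floats.
    """
    if len(actual) != len(desired):
        raise ValueError("actual and desired must have the same length")
    return all(
        abs(a - d) <= atol + rtol * abs(d)
        for a, d in zip(actual, desired)
    )


# First four start_logits copied from the notebook outputs above.
tf_start = [0.27443457, 0.02250022, -0.32903647, -0.32448006]
onnx_start = [0.27443478, 0.02250013, -0.32903633, -0.32448038]

# The values differ by roughly 1e-7 to 3e-7, well inside the 1e-5 tolerance.
print(within_tolerance(tf_start, onnx_start))
```

A stricter `atol` and `rtol` would tighten the check. Small floating-point differences between the two runtimes are expected, so exact equality is usually too strict for this kind of comparison.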
