|
4 | 4 | "cell_type": "markdown",
|
5 | 5 | "metadata": {},
|
6 | 6 | "source": [
|
7 |
| - "## How to run a bert model under ONNX\n", |
| 7 | + "# Converting a Tensorflow Bert model to ONNX\n", |
8 | 8 | "\n",
|
9 |
| - "This tutorial shows how to convert the original tensorflow bert model to ONNX. In this example we use a bert model that is fine tuned for squad-1.1 on top of [BERT-Base, Uncased](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip).\n", |
| 9 | + "This tutorial shows how to convert the original Tensorflow Bert model to ONNX. \n", |
| 10 | + "In this example we fine tune Bert for squad-1.1 on top of [BERT-Base, Uncased](https://storage.googleapis.com/bert_models/2018_10_18/uncased_L-12_H-768_A-12.zip).\n", |
10 | 11 | "\n",
|
11 |
| - "To keep this tuturial at a resonable size we reuse tokenizer and utilities defined in the bert source tree for onnx.\n", |
12 |
| - "We used the following versions:\n", |
| 12 | + "Since this tutorial cares mostly about the conversion process we reuse tokenizer and utilities defined in the Bert source tree as much as possible.\n", |
| 13 | + "\n", |
| 14 | + "This should work with all versions supported by the [tensorflow-onnx converter](https://github.com/onnx/tensorflow-onnx), we used the following versions while writing the tutorial:\n", |
13 | 15 | "```\n",
|
14 | 16 | "tensorflow-gpu: 1.13.1\n",
|
15 | 17 | "onnx: 1.5.1\n",
|
16 | 18 | "tf2onnx: 1.5.1\n",
|
17 | 19 | "onnxruntime: 0.4\n",
|
18 | 20 | "```\n",
|
19 | 21 | "\n",
|
20 |
| - "The steps to convert the models:\n", |
21 |
| - "1. setup our environment\n", |
22 |
| - "2. clone the tensorflow bert model from https://github.com/google-research/bert\n", |
23 |
| - "3. download the pretrained model and the squad-1.1 dataset\n", |
24 |
| - "4. fine tune on squad\n", |
25 |
| - "5. export the inference graph as saved_model format\n", |
26 |
| - "6. convert the saved_model to onnx\n", |
27 |
| - "7. run the converted model in onnxruntime" |
| 22 | + "To make the fine tuning work on my Gtx-1080 gpu, we changed the MAX_SEQ_LENGTH to 256 and used a training batch size of 8." |
28 | 23 | ]
|
29 | 24 | },
|
30 | 25 | {
|
31 | 26 | "cell_type": "markdown",
|
32 | 27 | "metadata": {},
|
33 | 28 | "source": [
|
34 |
| - "## Step 1\n", |
35 |
| - "Before we start, lets setup some varibales where to find things." |
| 29 | + "## Step 1 - define some environment variables\n", |
| 30 | + "Before we start, lets setup some variables where to find things." |
36 | 31 | ]
|
37 | 32 | },
|
38 | 33 | {
|
39 | 34 | "cell_type": "code",
|
40 |
| - "execution_count": 4, |
| 35 | + "execution_count": 1, |
41 | 36 | "metadata": {},
|
42 | 37 | "outputs": [],
|
43 | 38 | "source": [
|
|
62 | 57 | "cell_type": "markdown",
|
63 | 58 | "metadata": {},
|
64 | 59 | "source": [
|
65 |
| - "## Step 2 \n", |
66 |
| - "Clone https://github.com/google-research/bert" |
| 60 | + "## Step 2 - clone the Bert github repository" |
67 | 61 | ]
|
68 | 62 | },
|
69 | 63 | {
|
|
92 | 86 | "cell_type": "markdown",
|
93 | 87 | "metadata": {},
|
94 | 88 | "source": [
|
95 |
| - "## Step 3\n", |
96 |
| - "Download the pretrained bert model and the squad-1.1 dataset" |
| 89 | + "## Step 3 - download the pretrained Bert model and squad-1.1 dataset" |
97 | 90 | ]
|
98 | 91 | },
|
99 | 92 | {
|
|
112 | 105 | "!wget -O squad-1.1/evaluate-v1.1.json https://rajpurkar.github.io/SQuAD-explorer/dataset/evaluate-v1.1.json "
|
113 | 106 | ]
|
114 | 107 | },
|
115 |
| - { |
116 |
| - "cell_type": "code", |
117 |
| - "execution_count": null, |
118 |
| - "metadata": {}, |
119 |
| - "outputs": [], |
120 |
| - "source": [ |
121 |
| - "!mkdir squad-1.1 out\n", |
122 |
| - "!wget -O squad-1.1/train-v1.1.json https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v1.1.json \n", |
123 |
| - "!wget -O squad-1.1/dev-v1.1.json https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v1.1.json \n", |
124 |
| - "!wget -O squad-1.1/evaluate-v1.1.json https://rajpurkar.github.io/SQuAD-explorer/dataset/evaluate-v1.1.json " |
125 |
| - ] |
126 |
| - }, |
127 | 108 | {
|
128 | 109 | "cell_type": "markdown",
|
129 | 110 | "metadata": {},
|
130 | 111 | "source": [
|
131 |
| - "## Step 4\n", |
132 |
| - "Fine tune the bert model on squad-1.1. This is the same as described in the bert repository. We use a smaller MAX_SEQ_LENGTH and batch size so this trains nicely on a Gtx1080. If you already have a fined tuned model you can just copy it into the ```out``` folder." |
| 112 | + "## Step 4 - fine tune the Bert model for squad-1.1\n", |
| 113 | + "This is the same as described in the [Bert repository](https://github.com/google-research/bert). You need to do this only once.\n" |
133 | 114 | ]
|
134 | 115 | },
|
135 | 116 | {
|
|
164 | 145 | "cell_type": "markdown",
|
165 | 146 | "metadata": {},
|
166 | 147 | "source": [
|
167 |
| - "## Step 5\n", |
168 |
| - "With a fined tuned model in hands we want to create a inference graph for it and save it to saved_model format." |
| 148 | + "## Step 5 - create the inference graph and save it\n", |
| 149 | + "With a fined tuned model in hands we want to create the inference graph for it and save it as saved_model format.\n", |
| 150 | + "\n", |
| 151 | + "***We assune that after 2 epochs the checkpoint is model.ckpt-21899 - if the following code does not find it, check the $OUT directory for the higest checkpoint***." |
169 | 152 | ]
|
170 | 153 | },
|
171 | 154 | {
|
|
297 | 280 | }
|
298 | 281 | ],
|
299 | 282 | "source": [
|
| 283 | + "# N is the number of examples we are evaluating. On the CPU this might take a bit.\n", |
| 284 | + "# During development you can set N to some more practical\n", |
300 | 285 | "N = len(eval_features)\n",
|
301 |
| - "N = 100\n", |
302 | 286 | "\n",
|
303 | 287 | "all_results = []\n",
|
304 | 288 | "for result in estimator.predict(predict_input_fn, yield_single_examples=True):\n",
|
305 | 289 | " if len(all_results) % 1000 == 0:\n",
|
306 |
| - " print(\"example: %d\" % (len(all_results)))\n", |
| 290 | + " print(\"sample: %d\" % (len(all_results)))\n", |
307 | 291 | " unique_id = int(result[\"unique_ids\"])\n",
|
308 | 292 | " start_logits = [float(x) for x in result[\"start_logits\"].flat]\n",
|
309 | 293 | " end_logits = [float(x) for x in result[\"end_logits\"].flat]\n",
|
|
346 | 330 | " }\n",
|
347 | 331 | " return tf.estimator.export.ServingInputReceiver(receiver_tensors, receiver_tensors)\n",
|
348 | 332 | "\n",
|
349 |
| - "#estimator._export_to_tpu = False\n", |
350 | 333 | "path = estimator.export_savedmodel(os.path.join(OUT, \"export\"), serving_input_fn)\n",
|
351 | 334 | "os.environ['LAST_SAVED_MODEL'] = path.decode('utf-8')"
|
352 | 335 | ]
|
|
366 | 349 | "metadata": {},
|
367 | 350 | "outputs": [],
|
368 | 351 | "source": [
|
369 |
| - "# install tf2onnx if needed\n", |
370 |
| - "!pip install tf2onnx" |
| 352 | + "# install the latest version of tf2onnx if needed\n", |
| 353 | + "!pip install -U tf2onnx" |
371 | 354 | ]
|
372 | 355 | },
|
373 | 356 | {
|
|
394 | 377 | ],
|
395 | 378 | "source": [
|
396 | 379 | "# convert model\n",
|
| 380 | + "# because we still have a tensorflow session open in this notebook, force the converter to use the CPU.\n", |
| 381 | + "#\n", |
397 | 382 | "!CUDA_VISIBLE_DEVICES='' python -m tf2onnx.convert --saved-model $LAST_SAVED_MODEL --output $OUT/bert.onnx --opset 8"
|
398 | 383 | ]
|
399 | 384 | },
|
|
408 | 393 | "cell_type": "markdown",
|
409 | 394 | "metadata": {},
|
410 | 395 | "source": [
|
411 |
| - "Lets look at the inputs to the ONNX model. The input 'unique_ids' is special and creates some issue in onnx: the input is passed directly to the output and in tensorflow both have the same name. In ONNX that is not supported and the converter creates a name. We need to use that created name so we remember it." |
| 396 | + "Lets look at the inputs to the ONNX model. The input 'unique_ids' is special and creates some issue in ONNX: the input is passed directly to the output and in Tensorflow both have the same name. In ONNX that is not supported and the converter creates a new name for the input. We need to use that created name so we remember it." |
412 | 397 | ]
|
413 | 398 | },
|
414 | 399 | {
|
|
456 | 441 | "source": [
|
457 | 442 | "RawResult = collections.namedtuple(\"RawResult\", [\"unique_id\", \"start_logits\", \"end_logits\"])\n",
|
458 | 443 | "\n",
|
459 |
| - "batch_size = 1\n", |
460 |
| - "N = len(eval_features)\n", |
461 |
| - "N = 100\n", |
462 |
| - "\n", |
463 | 444 | "all_results = []\n",
|
464 | 445 | "for idx in range(0, N):\n",
|
465 | 446 | " item = eval_features[idx]\n",
|
466 | 447 | " # this is using batch_size=1\n",
|
| 448 | + " # feed the input data as int64\n", |
467 | 449 | " data = {\"unique_ids_raw_output___9:0\": np.array([item.unique_id], dtype=np.int64),\n",
|
468 | 450 | " \"input_ids:0\": np.array([item.input_ids], dtype=np.int64),\n",
|
469 | 451 | " \"input_mask:0\": np.array([item.input_mask], dtype=np.int64),\n",
|
470 | 452 | " \"segment_ids:0\": np.array([item.segment_ids], dtype=np.int64)}\n",
|
471 | 453 | " result = sess.run([\"unique_ids:0\", \"unstack:0\", \"unstack:1\"], data)\n",
|
472 | 454 | " unique_id = result[0][0]\n",
|
473 |
| - " start_logits = result[1][0]\n", |
474 |
| - " end_logits = result[2][0]\n", |
475 |
| - " start_logits = [float(x) for x in start_logits.flat]\n", |
476 |
| - " end_logits = [float(x) for x in end_logits.flat]\n", |
477 |
| - "\n", |
478 |
| - " # all_results.append(RawResult(unique_id=unique_id, start_logits=result[0][0][i], end_logits=result[1][0][i]))\n", |
| 455 | + " start_logits = [float(x) for x in result[1][0].flat]\n", |
| 456 | + " end_logits = [float(x) for x in result[2][0].flat]\n", |
479 | 457 | " all_results.append(RawResult(unique_id=unique_id, start_logits=start_logits, end_logits=end_logits))\n",
|
480 | 458 | " if unique_id % 1000 == 0:\n",
|
481 |
| - " print(\"example: %d\" % (len(all_results)))\n", |
| 459 | + " print(\"sample: %d\" % (len(all_results)))\n", |
482 | 460 | " if len(all_results) >= N:\n",
|
483 | 461 | " break\n",
|
484 | 462 | "\n",
|
|
493 | 471 | "cell_type": "markdown",
|
494 | 472 | "metadata": {},
|
495 | 473 | "source": [
|
496 |
| - "Compare some results between tensorflow and ONNX:" |
| 474 | + "Compare some results between Tensorflow and ONNX:" |
497 | 475 | ]
|
498 | 476 | },
|
499 | 477 | {
|
|
568 | 546 | "!head -20 $OUT/onnx_predictions.json"
|
569 | 547 | ]
|
570 | 548 | },
|
| 549 | + { |
| 550 | + "cell_type": "markdown", |
| 551 | + "metadata": {}, |
| 552 | + "source": [ |
| 553 | + "## Summary\n", |
| 554 | + "\n", |
| 555 | + "That was all it takes to convert a relativly complex model from Tensorflow to ONNX. \n", |
| 556 | + "\n", |
| 557 | + "You find more documentation about tensorflow-onnx [here](https://github.com/onnx/tensorflow-onnx)." |
| 558 | + ] |
| 559 | + }, |
571 | 560 | {
|
572 | 561 | "cell_type": "code",
|
573 | 562 | "execution_count": null,
|
|