Skip to content

Commit 6beaaa3

Browse files
authored
Fixed IntelTensorFlow_AMX_BF16_Inference (#2397)
* using new tf & keras version * deleted models file, will download it from the Hub
1 parent e03f296 commit 6beaaa3

File tree

9 files changed

+326
-145
lines changed

9 files changed

+326
-145
lines changed

AI-and-Analytics/Features-and-Functionality/IntelTensorFlow_AMX_BF16_Inference/IntelTensorFlow_AMX_BF16_Inference.ipynb

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,9 @@
3232
"from datetime import datetime\n",
3333
"import requests\n",
3434
"from copy import deepcopy\n",
35-
"print(\"We are using Tensorflow version: \", tf.__version__)"
35+
"\n",
36+
"print(\"TF version:\", tf.__version__)\n",
37+
"print(\"Hub version:\", hub.__version__)"
3638
]
3739
},
3840
{
@@ -158,8 +160,34 @@
158160
"metadata": {},
159161
"outputs": [],
160162
"source": [
161-
"# import pre-trained fp_32 model\n",
162-
"fp32_model = tf.keras.models.load_model('models/my_saved_model_fp32')"
163+
"IMAGE_SIZE = (224, 224, 3)\n",
164+
"model_handle = \"https://www.kaggle.com/models/google/resnet-v1/TensorFlow2/50-feature-vector/2\"\n",
165+
"\n",
166+
"print(\"Building model with\", model_handle)\n",
167+
"fp32_model = tf.keras.Sequential([\n",
168+
" # Explicitly define the input shape so the model can be properly\n",
169+
" # loaded by the TFLiteConverter\n",
170+
" tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE),\n",
171+
" hub.KerasLayer(model_handle, trainable=False),\n",
172+
" tf.keras.layers.Dropout(rate=0.2),\n",
173+
" tf.keras.layers.Dense(len(class_names),\n",
174+
" kernel_regularizer=tf.keras.regularizers.l2(0.0001))\n",
175+
"])\n",
176+
"fp32_model.build((None,)+IMAGE_SIZE)\n",
177+
"fp32_model.summary()"
178+
]
179+
},
180+
{
181+
"cell_type": "code",
182+
"execution_count": null,
183+
"id": "a9e75958",
184+
"metadata": {},
185+
"outputs": [],
186+
"source": [
187+
"fp32_model.compile(\n",
188+
" optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9), \n",
189+
" loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
190+
" metrics=['accuracy'])"
163191
]
164192
},
165193
{
@@ -204,6 +232,24 @@
204232
" break"
205233
]
206234
},
235+
{
236+
"cell_type": "markdown",
237+
"id": "e1f94d18",
238+
"metadata": {},
239+
"source": [
240+
"Let's save the FP32 model to disk."
241+
]
242+
},
243+
{
244+
"cell_type": "code",
245+
"execution_count": null,
246+
"id": "924ab886",
247+
"metadata": {},
248+
"outputs": [],
249+
"source": [
250+
"fp32_model.save(\"models/my_saved_model_fp32\")"
251+
]
252+
},
207253
{
208254
"cell_type": "markdown",
209255
"id": "c1d45604",
@@ -318,7 +364,7 @@
318364
"plt.title(\"Resnet50 Inference Time\")\n",
319365
"plt.xlabel(\"Test Case\")\n",
320366
"plt.ylabel(\"Inference Time (seconds)\")\n",
321-
"plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_time, bf16_noAmx_inference_time, bf16_withAmx_inference_time])"
367+
"plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_time, bf16_noAmx_inference_time, bf16_withAmx_inference_time]);"
322368
]
323369
},
324370
{
@@ -347,7 +393,7 @@
347393
"plt.title(\"Intel® AMX Speedup\")\n",
348394
"plt.xlabel(\"Test Case\")\n",
349395
"plt.ylabel(\"Speedup\")\n",
350-
"plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [1, speedup_bf16_noAMX_from_fp32, speedup_bf16_withAMX_from_fp32])"
396+
"plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [1, speedup_bf16_noAMX_from_fp32, speedup_bf16_withAMX_from_fp32]);"
351397
]
352398
},
353399
{
@@ -372,7 +418,7 @@
372418
"plt.title(\"Resnet50 Inference Accuracy\")\n",
373419
"plt.xlabel(\"Test Case\")\n",
374420
"plt.ylabel(\"Inference Accuracy\")\n",
375-
"plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_accuracy, bf16_noAmx_inference_accuracy, bf16_withAmx_inference_accuracy])"
421+
"plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_accuracy, bf16_noAmx_inference_accuracy, bf16_withAmx_inference_accuracy]);"
376422
]
377423
},
378424
{

0 commit comments

Comments
 (0)