|
32 | 32 | "from datetime import datetime\n",
|
33 | 33 | "import requests\n",
|
34 | 34 | "from copy import deepcopy\n",
|
35 |
| - "print(\"We are using Tensorflow version: \", tf.__version__)" |
| 35 | + "\n", |
| 36 | + "print(\"TF version:\", tf.__version__)\n", |
| 37 | + "print(\"Hub version:\", hub.__version__)" |
36 | 38 | ]
|
37 | 39 | },
|
38 | 40 | {
|
|
158 | 160 | "metadata": {},
|
159 | 161 | "outputs": [],
|
160 | 162 | "source": [
|
161 |
| - "# import pre-trained fp_32 model\n", |
162 |
| - "fp32_model = tf.keras.models.load_model('models/my_saved_model_fp32')" |
| 163 | + "IMAGE_SIZE = (224, 224, 3)\n", |
| 164 | + "model_handle = \"https://www.kaggle.com/models/google/resnet-v1/TensorFlow2/50-feature-vector/2\"\n", |
| 165 | + "\n", |
| 166 | + "print(\"Building model with\", model_handle)\n", |
| 167 | + "fp32_model = tf.keras.Sequential([\n", |
| 168 | + " # Explicitly define the input shape so the model can be properly\n", |
| 169 | + " # loaded by the TFLiteConverter\n", |
| 170 | + " tf.keras.layers.InputLayer(input_shape=IMAGE_SIZE),\n", |
| 171 | + " hub.KerasLayer(model_handle, trainable=False),\n", |
| 172 | + " tf.keras.layers.Dropout(rate=0.2),\n", |
| 173 | + " tf.keras.layers.Dense(len(class_names),\n", |
| 174 | + " kernel_regularizer=tf.keras.regularizers.l2(0.0001))\n", |
| 175 | + "])\n", |
| 176 | + "fp32_model.build((None,)+IMAGE_SIZE)\n", |
| 177 | + "fp32_model.summary()" |
| 178 | + ] |
| 179 | + }, |
| 180 | + { |
| 181 | + "cell_type": "code", |
| 182 | + "execution_count": null, |
| 183 | + "id": "a9e75958", |
| 184 | + "metadata": {}, |
| 185 | + "outputs": [], |
| 186 | + "source": [ |
| 187 | + "fp32_model.compile(\n", |
| 188 | + " optimizer=tf.keras.optimizers.SGD(learning_rate=0.005, momentum=0.9), \n", |
| 189 | + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n", |
| 190 | + " metrics=['accuracy'])" |
163 | 191 | ]
|
164 | 192 | },
|
165 | 193 | {
|
|
204 | 232 | " break"
|
205 | 233 | ]
|
206 | 234 | },
|
| 235 | + { |
| 236 | + "cell_type": "markdown", |
| 237 | + "id": "e1f94d18", |
| 238 | + "metadata": {}, |
| 239 | + "source": [ |
| 240 | + "Lets save the model . . . " |
| 241 | + ] |
| 242 | + }, |
| 243 | + { |
| 244 | + "cell_type": "code", |
| 245 | + "execution_count": null, |
| 246 | + "id": "924ab886", |
| 247 | + "metadata": {}, |
| 248 | + "outputs": [], |
| 249 | + "source": [ |
| 250 | + "fp32_model.save(\"models/my_saved_model_fp32\")" |
| 251 | + ] |
| 252 | + }, |
207 | 253 | {
|
208 | 254 | "cell_type": "markdown",
|
209 | 255 | "id": "c1d45604",
|
|
318 | 364 | "plt.title(\"Resnet50 Inference Time\")\n",
|
319 | 365 | "plt.xlabel(\"Test Case\")\n",
|
320 | 366 | "plt.ylabel(\"Inference Time (seconds)\")\n",
|
321 |
| - "plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_time, bf16_noAmx_inference_time, bf16_withAmx_inference_time])" |
| 367 | + "plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_time, bf16_noAmx_inference_time, bf16_withAmx_inference_time]);" |
322 | 368 | ]
|
323 | 369 | },
|
324 | 370 | {
|
|
347 | 393 | "plt.title(\"Intel® AMX Speedup\")\n",
|
348 | 394 | "plt.xlabel(\"Test Case\")\n",
|
349 | 395 | "plt.ylabel(\"Speedup\")\n",
|
350 |
| - "plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [1, speedup_bf16_noAMX_from_fp32, speedup_bf16_withAMX_from_fp32])" |
| 396 | + "plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [1, speedup_bf16_noAMX_from_fp32, speedup_bf16_withAMX_from_fp32]);" |
351 | 397 | ]
|
352 | 398 | },
|
353 | 399 | {
|
|
372 | 418 | "plt.title(\"Resnet50 Inference Accuracy\")\n",
|
373 | 419 | "plt.xlabel(\"Test Case\")\n",
|
374 | 420 | "plt.ylabel(\"Inference Accuracy\")\n",
|
375 |
| - "plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_accuracy, bf16_noAmx_inference_accuracy, bf16_withAmx_inference_accuracy])" |
| 421 | + "plt.bar([\"FP32\", \"BF16 with AVX512\", \"BF16 with Intel® AMX\"], [fp32_inference_accuracy, bf16_noAmx_inference_accuracy, bf16_withAmx_inference_accuracy]);" |
376 | 422 | ]
|
377 | 423 | },
|
378 | 424 | {
|
|
0 commit comments