32 |    32 | "import tensorflow_hub as hub\n",
33 |    33 | "from datetime import datetime\n",
34 |    34 | "import requests\n",
35 |       | - "import tf_keras\n",
36 |    35 | "print(\"We are using Tensorflow version: \", tf.__version__)"
37 |    36 | ]
38 |    37 | },

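This hunk removes `import tf_keras`, the standalone Keras 2 compatibility package, in favor of the `tf.keras` API that ships inside TensorFlow itself. A minimal sanity check, offered as a sketch rather than part of the notebook:

```python
import tensorflow as tf

# Hypothetical check, not part of the notebook: after dropping the standalone
# tf_keras package, confirm the Keras API bundled with TensorFlow is in use.
print("We are using Tensorflow version: ", tf.__version__)
print("Bundled Keras version: ", tf.keras.__version__)
```
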
99  |   98 | "outputs": [],
100 |   99 | "source": [
101 |  100 | "tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)\n",
102 |      | - "data_root = tf_keras.utils.get_file(\n",
    |  101 | + "data_root = tf.keras.utils.get_file(\n",
103 |  102 | " 'flower_photos',\n",
104 |  103 | " 'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',\n",
105 |  104 | " untar=True)\n",

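For context on what this cell relies on: `get_file` downloads the archive once, caches it under `~/.keras/datasets` by default, and with `untar=True` returns the path to the extracted directory, so re-running the cell reuses the local copy. A self-contained sketch:

```python
import pathlib
import tensorflow as tf

# Sketch: the returned path points at the extracted flower_photos directory,
# whose sub-directories are the five flower classes.
data_root = tf.keras.utils.get_file(
    'flower_photos',
    'https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz',
    untar=True)
image_count = len(list(pathlib.Path(data_root).glob('*/*.jpg')))
print(f"{image_count} images extracted under {data_root}")
```
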
108 |  107 | "img_height = 224\n",
109 |  108 | "img_width = 224\n",
110 |  109 | "\n",
111 |      | - "train_ds = tf_keras.utils.image_dataset_from_directory(\n",
    |  110 | + "train_ds = tf.keras.utils.image_dataset_from_directory(\n",
112 |  111 | " str(data_root),\n",
113 |  112 | " validation_split=0.2,\n",
114 |  113 | " subset=\"training\",\n",

117 |  116 | " batch_size=batch_size\n",
118 |  117 | ")\n",
119 |  118 | "\n",
120 |      | - "val_ds = tf_keras.utils.image_dataset_from_directory(\n",
    |  119 | + "val_ds = tf.keras.utils.image_dataset_from_directory(\n",
121 |  120 | " str(data_root),\n",
122 |  121 | " validation_split=0.2,\n",
123 |  122 | " subset=\"validation\",\n",

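The diff collapses the rows between `subset=\"training\"` and `batch_size=batch_size`, which is where the remaining arguments live. For the 80/20 split to be complementary, both calls must pass the same `seed` and `validation_split`. A sketch of one complete call, where `seed=123` and the `image_size` argument are assumptions standing in for the elided rows:

```python
# Both the "training" and "validation" calls must share seed and
# validation_split, otherwise the two subsets can overlap.
train_ds = tf.keras.utils.image_dataset_from_directory(
    str(data_root),
    validation_split=0.2,
    subset="training",
    seed=123,                            # assumed value; must match val_ds
    image_size=(img_height, img_width),  # assumed; the 224x224 set above
    batch_size=batch_size
)
class_names = train_ds.class_names  # later used for num_classes
```
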
147 |  146 | "metadata": {},
148 |  147 | "outputs": [],
149 |  148 | "source": [
150 |      | - "normalization_layer = tf_keras.layers.Rescaling(1./255)\n",
    |  149 | + "normalization_layer = tf.keras.layers.Rescaling(1./255)\n",
151 |  150 | "train_ds = train_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.\n",
152 |  151 | "val_ds = val_ds.map(lambda x, y: (normalization_layer(x), y)) # Where x—images, y—labels.\n",
153 |  152 | "\n",

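An optional sanity check, not in the notebook, confirms the rescaling did what it should:

```python
import numpy as np

# After Rescaling(1./255) the pixel values should land in [0, 1]
# rather than the raw [0, 255] range.
image_batch, labels_batch = next(iter(train_ds))
print(np.min(image_batch), np.max(image_batch))  # expect roughly 0.0 and 1.0
```
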
221 |  220 | "id": "70b3eb9b",
222 |  221 | "metadata": {},
223 |  222 | "source": [
224 |      | - "Attach the last fully connected classification layer in a **tf_keras.Sequential** model."
    |  223 | + "Attach the last fully connected classification layer in a **tf.keras.Sequential** model."
225 |  224 | ]
226 |  225 | },
227 |  226 | {

233 |  232 | "source": [
234 |  233 | "num_classes = len(class_names)\n",
235 |  234 | "\n",
236 |      | - "fp32_model = tf_keras.Sequential([\n",
    |  235 | + "fp32_model = tf.keras.Sequential([\n",
237 |  236 | " feature_extractor_layer,\n",
238 |      | - " tf_keras.layers.Dense(num_classes)\n",
    |  237 | + " tf.keras.layers.Dense(num_classes)\n",
239 |  238 | "])\n",
240 |  239 | "\n",
241 |  240 | "if arch == 'SPR':\n",
242 |  241 | " # Create a deep copy of the model to train the bf16 model separately to compare accuracy\n",
243 |      | - " bf16_model = tf_keras.models.clone_model(fp32_model)\n",
    |  242 | + " bf16_model = tf.keras.models.clone_model(fp32_model)\n",
244 |  243 | "\n",
245 |  244 | "fp32_model.summary()"
246 |  245 | ]

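One caveat worth knowing here: the in-code comment calls this "a deep copy", but `tf.keras.models.clone_model` reproduces the architecture with freshly initialized weights. If the bf16 run should start from exactly the same parameters as the fp32 run, the weights have to be copied across explicitly, as in this sketch (not in the notebook):

```python
# clone_model rebuilds the architecture with newly initialized weights;
# set_weights makes both models start from identical parameters.
bf16_model = tf.keras.models.clone_model(fp32_model)
bf16_model.set_weights(fp32_model.get_weights())
```
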
260 |  259 | "metadata": {},
261 |  260 | "outputs": [],
262 |  261 | "source": [
263 |      | - "class TimeHistory(tf_keras.callbacks.Callback):\n",
    |  262 | + "class TimeHistory(tf.keras.callbacks.Callback):\n",
264 |  263 | " def on_train_begin(self, logs={}):\n",
265 |  264 | " self.times = []\n",
266 |  265 | " self.throughput = []\n",

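The hunk only shows the opening of the callback; the rest of `TimeHistory` lies outside the diff. As a sketch of the pattern such a callback follows, with `TimeHistorySketch` as a hypothetical stand-in:

```python
import time
import tensorflow as tf

class TimeHistorySketch(tf.keras.callbacks.Callback):
    def on_train_begin(self, logs=None):
        self.times = []
        self.throughput = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_start = time.time()

    def on_epoch_end(self, epoch, logs=None):
        elapsed = time.time() - self.epoch_start
        self.times.append(elapsed)
        # images/sec, assuming batch_size is in scope and Keras reports steps
        steps = self.params.get('steps') or 0
        self.throughput.append(steps * batch_size / elapsed if elapsed else 0.0)
```
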
290 |  289 | "outputs": [],
291 |  290 | "source": [
292 |  291 | "fp32_model.compile(\n",
293 |      | - " optimizer=tf_keras.optimizers.SGD(),\n",
294 |      | - " loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    |  292 | + " optimizer=tf.keras.optimizers.SGD(),\n",
    |  293 | + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
295 |  294 | " metrics=['acc'])"
296 |  295 | ]
297 |  296 | },

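Both compiles pass `from_logits=True` because the final `Dense(num_classes)` layer has no softmax activation: the model emits raw logits, and the loss applies the softmax internally, which is more numerically stable. At inference time the probabilities are recovered explicitly, as in this illustrative snippet:

```python
# Illustrative only: convert logits to class probabilities at inference time.
logits = fp32_model(image_batch)
probs = tf.nn.softmax(logits)
print(class_names[int(tf.argmax(probs[0]))])
```
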
374 |  373 | "if arch == 'SPR':\n",
375 |  374 | " # Compile\n",
376 |  375 | " bf16_model.compile(\n",
377 |      | - " optimizer=tf_keras.optimizers.SGD(),\n",
378 |      | - " loss=tf_keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
    |  376 | + " optimizer=tf.keras.optimizers.SGD(),\n",
    |  377 | + " loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
379 |  378 | " metrics=['acc'])\n",
380 |  379 | " \n",
381 |  380 | " # Train\n",

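The `arch == 'SPR'` branch targets 4th-gen Intel Xeon (Sapphire Rapids), whose AMX units accelerate bfloat16. How bf16 is actually switched on for `bf16_model` is not shown in this diff; one mechanism used in Intel's TensorFlow material, offered here purely as an assumption, is the oneDNN auto-mixed-precision graph option:

```python
# Assumption, not shown in this diff: rewrite eligible ops to bfloat16 at
# runtime via oneDNN auto mixed precision. The notebook's actual bf16
# mechanism may differ.
tf.config.optimizer.set_experimental_options(
    {'auto_mixed_precision_onednn_bfloat16': True})
```
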