Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 8890701

Browse files
bfinerannatuan
andauthored
Add batch dim for TF 2.1 (#131) (#132)
* Add batch dim for TF 2.1 * Clean up * Enable native keras support for notebook Co-authored-by: Tuan Nguyen <[email protected]>
1 parent 0e5e9d3 commit 8890701

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

notebooks/keras_classification.ipynb

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"## Step 1 - Requirements\n",
2525
"To run this notebook, you will need the following packages already installed:\n",
2626
"* SparseML and SparseZoo;\n",
27-
"* Tensorflow >=2.2, which includes Keras and TensorBoard;\n",
27+
"* Tensorflow >=2.1, which includes Keras and TensorBoard;\n",
2828
"* keras2onnx.\n",
2929
"\n",
3030
"You can install any package that is not already present via `pip`."
@@ -60,7 +60,7 @@
6060
"outputs": [],
6161
"source": [
6262
"import os\n",
63-
"from tensorflow import keras\n",
63+
"from sparseml.keras.utils import keras\n",
6464
"from sparsezoo.models import Zoo\n",
6565
"\n",
6666
"# Root directory for the notebook artifacts\n",
@@ -122,6 +122,8 @@
122122
"metadata": {},
123123
"outputs": [],
124124
"source": [
125+
"import numpy\n",
126+
"\n",
125127
"# Number of classes\n",
126128
"num_classes = 10\n",
127129
"\n",
@@ -132,6 +134,10 @@
132134
"x_train = x_train.astype('float32') / 255\n",
133135
"x_test = x_test.astype('float32') / 255\n",
134136
"\n",
137+
"# Add batch dimension (for older TF versions)\n",
138+
"x_train = numpy.expand_dims(x_train, -1)\n",
139+
"x_test = numpy.expand_dims(x_test, -1)\n",
140+
"\n",
135141
"y_train = keras.utils.to_categorical(y_train, num_classes)\n",
136142
"y_test = keras.utils.to_categorical(y_test, num_classes)\n",
137143
"\n",
@@ -418,9 +424,9 @@
418424
],
419425
"metadata": {
420426
"kernelspec": {
421-
"display_name": "Python (pypi_sparseml)",
427+
"display_name": "Python (keras_pruning)",
422428
"language": "python",
423-
"name": "pypi_sparseml"
429+
"name": "keras_pruning"
424430
},
425431
"language_info": {
426432
"codemirror_mode": {

0 commit comments

Comments
 (0)