|
| 1 | +"""End-to-end example for SNN Toolbox. |
| 2 | +
|
| 3 | +This script sets up a small CNN using Keras and tensorflow, trains it for one |
| 4 | +epoch on MNIST, stores model and dataset in a temporary folder on disk, creates |
| 5 | +a configuration file for SNN toolbox, and finally calls the main function of |
| 6 | +SNN toolbox to convert the trained ANN to an SNN and run it using spiNNaker. |
| 7 | +""" |
| 8 | + |
| 9 | +import os |
| 10 | +import numpy as np |
| 11 | + |
| 12 | +import keras |
| 13 | +from keras import Input, Model |
| 14 | +from keras.layers import AveragePooling2D, Flatten, Dropout |
| 15 | +from keras_rewiring import Sparse, SparseConv2D, SparseDepthwiseConv2D |
| 16 | +from keras_rewiring.optimizers import NoisySGD |
| 17 | +from keras_rewiring.rewiring_callback import RewiringCallback |
| 18 | +from keras.datasets import mnist |
| 19 | +from keras.utils import np_utils |
| 20 | + |
| 21 | +from snntoolbox.bin.run import main |
| 22 | +from snntoolbox.utils.utils import import_configparser |
| 23 | + |
| 24 | + |
| 25 | +# WORKING DIRECTORY # |
| 26 | +##################### |
| 27 | + |
| 28 | +# Define path where model and output files will be stored. |
| 29 | +# The user is responsible for cleaning up this temporary directory. |
| 30 | +path_wd = os.path.abspath(os.path.join(os.path.dirname(os.path.realpath( |
| 31 | + __file__)), '..', 'temp')) |
| 32 | +# os.makedirs(path_wd) |
| 33 | + |
| 34 | +# GET DATASET # |
| 35 | +############### |
| 36 | + |
| 37 | +(x_train, y_train), (x_test, y_test) = mnist.load_data() |
| 38 | + |
| 39 | +# Normalize input so we can train ANN with it. |
| 40 | +# Will be converted back to integers for SNN layer. |
| 41 | +x_train = x_train / 255 |
| 42 | +x_test = x_test / 255 |
| 43 | + |
| 44 | +# Add a channel dimension. |
| 45 | +axis = 1 if keras.backend.image_data_format() == 'channels_first' else -1 |
| 46 | +x_train = np.expand_dims(x_train, axis) |
| 47 | +x_test = np.expand_dims(x_test, axis) |
| 48 | + |
| 49 | +# One-hot encode target vectors. |
| 50 | +y_train = np_utils.to_categorical(y_train, 10) |
| 51 | +y_test = np_utils.to_categorical(y_test, 10) |
| 52 | + |
| 53 | +# Save dataset so SNN toolbox can find it. |
| 54 | +np.savez_compressed(os.path.join(path_wd, 'x_test'), x_test) |
| 55 | +np.savez_compressed(os.path.join(path_wd, 'y_test'), y_test) |
| 56 | +# SNN toolbox will not do any training, but we save a subset of the training |
| 57 | +# set so the toolbox can use it when normalizing the network parameters. |
| 58 | +np.savez_compressed(os.path.join(path_wd, 'x_norm'), x_train[::10]) |
| 59 | + |
| 60 | +# SETUP REWIRING # |
| 61 | +################## |
| 62 | + |
| 63 | +deep_r = RewiringCallback(noise_coeff=10 ** -5) |
| 64 | + |
| 65 | +callback_list = [deep_r] |
| 66 | + |
| 67 | +# CREATE ANN # |
| 68 | +############## |
| 69 | + |
| 70 | +# This section creates a simple CNN using Keras, and trains it |
| 71 | +# with backpropagation. There are no spikes involved at this point. |
| 72 | + |
| 73 | +input_shape = x_train.shape[1:] |
| 74 | +input_layer = Input(input_shape) |
| 75 | + |
| 76 | +built_in_sparsity = [0.5] * 4 |
| 77 | + |
| 78 | +layer = SparseConv2D( |
| 79 | + filters=16, |
| 80 | + kernel_size=( |
| 81 | + 5, |
| 82 | + 5), |
| 83 | + strides=( |
| 84 | + 2, |
| 85 | + 2), |
| 86 | + activation='relu', |
| 87 | + use_bias=False, |
| 88 | + connectivity_level=built_in_sparsity.pop(0) or None)(input_layer) |
| 89 | +layer = SparseConv2D( |
| 90 | + filters=32, |
| 91 | + kernel_size=( |
| 92 | + 3, |
| 93 | + 3), |
| 94 | + activation='relu', |
| 95 | + use_bias=False, |
| 96 | + connectivity_level=built_in_sparsity.pop(0) or None)(layer) |
| 97 | +layer = AveragePooling2D()(layer) |
| 98 | +layer = SparseConv2D( |
| 99 | + filters=8, |
| 100 | + kernel_size=( |
| 101 | + 3, |
| 102 | + 3), |
| 103 | + padding='same', |
| 104 | + activation='relu', |
| 105 | + use_bias=False, |
| 106 | + connectivity_level=built_in_sparsity.pop(0) or None)(layer) |
| 107 | +layer = Flatten()(layer) |
| 108 | +layer = Dropout(0.01)(layer) |
| 109 | +layer = Sparse(units=10, |
| 110 | + activation='softmax', |
| 111 | + use_bias=False, |
| 112 | + connectivity_level=built_in_sparsity.pop(0) or None)(layer) |
| 113 | + |
| 114 | +model = Model(input_layer, layer) |
| 115 | + |
| 116 | +model.summary() |
| 117 | + |
| 118 | +model.compile( |
| 119 | + NoisySGD( |
| 120 | + lr=0.01), |
| 121 | + 'categorical_crossentropy', |
| 122 | + ['accuracy']) |
| 123 | + |
| 124 | +# Train model with backprop. |
| 125 | +model.fit(x_train, y_train, batch_size=64, epochs=5, verbose=2, |
| 126 | + validation_data=(x_test, y_test), |
| 127 | + callbacks=callback_list) |
| 128 | + |
| 129 | +# Store model so SNN Toolbox can find it. |
| 130 | +model_name = 'sparse_mnist_cnn' |
| 131 | +keras.models.save_model(model, os.path.join(path_wd, model_name + '.h5')) |
| 132 | + |
| 133 | +# SNN TOOLBOX CONFIGURATION # |
| 134 | +############################# |
| 135 | + |
| 136 | +# Create a config file with experimental setup for SNN Toolbox. |
| 137 | +configparser = import_configparser() |
| 138 | +config = configparser.ConfigParser() |
| 139 | + |
| 140 | +config['paths'] = { |
| 141 | + 'path_wd': path_wd, # Path to model. |
| 142 | + 'dataset_path': path_wd, # Path to dataset. |
| 143 | + 'filename_ann': model_name # Name of input model. |
| 144 | +} |
| 145 | + |
| 146 | +config['tools'] = { |
| 147 | + 'evaluate_ann': True, # Test ANN on dataset before conversion. |
| 148 | + # Normalize weights for full dynamic range. |
| 149 | + 'normalize': False, |
| 150 | + 'scale_weights_exp': True |
| 151 | +} |
| 152 | + |
| 153 | +config['simulation'] = { |
| 154 | + # Chooses execution backend of SNN toolbox. |
| 155 | + 'simulator': 'spiNNaker', |
| 156 | + 'duration': 50, # Number of time steps to run each sample. |
| 157 | + 'num_to_test': 5, # How many test samples to run. |
| 158 | + 'batch_size': 1, # Batch size for simulation. |
| 159 | + # SpiNNaker seems to require 0.1 for comparable results. |
| 160 | + 'dt': 0.1 |
| 161 | +} |
| 162 | + |
| 163 | +config['input'] = { |
| 164 | + 'poisson_input': True, # Images are encodes as spike trains. |
| 165 | + 'input_rate': 1000 |
| 166 | +} |
| 167 | + |
| 168 | +config['cell'] = { |
| 169 | + 'tau_syn_E': 0.01, |
| 170 | + 'tau_syn_I': 0.01 |
| 171 | +} |
| 172 | + |
| 173 | +config['output'] = { |
| 174 | + 'plot_vars': { # Various plots (slows down simulation). |
| 175 | + 'spiketrains', # Leave section empty to turn off plots. |
| 176 | + 'spikerates', |
| 177 | + 'activations', |
| 178 | + 'correlation', |
| 179 | + 'v_mem', |
| 180 | + 'error_t'} |
| 181 | +} |
| 182 | + |
| 183 | +# Store config file. |
| 184 | +config_filepath = os.path.join(path_wd, 'config') |
| 185 | +with open(config_filepath, 'w') as configfile: |
| 186 | + config.write(configfile) |
| 187 | + |
| 188 | +# RUN SNN TOOLBOX # |
| 189 | +################### |
| 190 | + |
| 191 | +main(config_filepath) |
0 commit comments