From 4b69577578a340b48a2d6eda20840418d21c05da Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Fri, 23 Dec 2022 10:40:58 +0530 Subject: [PATCH 1/4] added ForwardForward py file example for PR review Signed-off-by: Suvaditya Mukherjee --- examples/vision/forwardforward.py | 352 ++++++++++++++++++++++++++++++ 1 file changed, 352 insertions(+) create mode 100644 examples/vision/forwardforward.py diff --git a/examples/vision/forwardforward.py b/examples/vision/forwardforward.py new file mode 100644 index 0000000000..534e677f25 --- /dev/null +++ b/examples/vision/forwardforward.py @@ -0,0 +1,352 @@ +""" +Title: Using Forward-Forward Algorithm for Image Classification +Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes) +Date created: 2022/12/21 +Last modified: 2022/12/23 +Description: Training a Dense-layer based model using the Forward-Forward algorithm. + +""" + +""" +## Introduction + +The following example explores how to use the Forward-Forward algorithm to perform +training instead of the traditionally-used method of backpropagation, as proposed by +[Prof. Geoffrey Hinton](https://www.cs.toronto.edu/~hinton/FFA13.pdf) +The concept was inspired by the understanding behind [Boltzmann +Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation involves +calculating loss via a cost function and propagating the error across the network. On the +other hand, the FF Algorithm suggests the analogy of neurons which get "excited" based on +looking at a certain recognized combination of an image and its correct corresponding +label. +This method takes certain inspiration from the biological learning process that occurs in +the cortex. A significant advantage that this method brings is the fact that +backpropagation through the network does not need to be performed anymore, and that +weight updates are local to the layer itself. +As this is yet still an experimental method, it does not yield state-of-the-art results. +But with proper tuning, it is supposed to come close to the same. +Through this example, we will examine a process that allows us to implement the +Forward-Forward algorithm within the layers themselves, instead of the traditional method +of relying on the global loss functions and optimizers. +The process is as follows: +- Perform necessary imports +- Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) +- Visualize Random samples from the MNIST dataset +- Define a `FFDense` Layer to override `call` and implement a custom `forwardforward` +method which performs weight updates. +- Define a `FFNetwork` Layer to override `train_step`, `predict` and implement 2 custom +functions for per-sample prediction and overlaying labels +- Convert MNIST from `NumPy` arrays to `tf.data.Dataset` +- Fit the network +- Visualize results +- Perform inference testing + +As this example requires the customization of certain core functions with +`tf.keras.layers.Layer` and `tf.keras.models.Model`, refer to the following resources for +a primer on how to do so +- [Customizing what happens in +`model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit) +- [Making new Layers and Models via +subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models) +""" + +""" +## Setup imports +""" + +import tensorflow as tf +from tensorflow import keras +from tqdm.notebook import trange, tqdm +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import accuracy_score +import random + +tf.config.run_functions_eagerly(True) + +""" +## Load dataset and visualize + +We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset +in the form of `NumPy` arrays. We then arrange it in the form of the train and test +splits. + +Following loading the dataset, we select 4 random samples from within the training set +and visualize them using `matplotlib.pyplot` +""" + +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +print("4 Random Training samples and labels") +idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4) + +img1 = (x_train[idx1], y_train[idx1]) +img2 = (x_train[idx2], y_train[idx2]) +img3 = (x_train[idx3], y_train[idx3]) +img4 = (x_train[idx4], y_train[idx4]) + +imgs = [img1, img2, img3, img4] + +plt.figure(figsize=(10, 10)) + +for idx, item in enumerate(imgs): + image, label = item[0], item[1] + plt.subplot(2, 2, idx + 1) + plt.imshow(image, cmap="gray") + plt.title(f"Label : {label}") +plt.show() + +""" +## Define `FFDense` Custom Layer + +In this custom layer, we have a base `tf.keras.layers.Dense` object which acts as the +base `Dense` layer within. Since weight updates will happen within the layer itself, we +add an `tf.keras.optimizers.Optimizer` object that is accepted from the user. Here, we +use `Adam` as our optimizer with a rather higher learning rate of `0.03`. +Following the algorithm's specifics, we must set a `threshold` parameter that will be +used to make the positive-negative decision in each prediction. This is set to a default +of 2.0 +As the epochs are localized to the layer itself, we also set a `num_epochs` parameter +(default at 2000). + +We override the `call` function in order to perform a normalization over the complete +input space followed by running it through the base `Dense` layer as would happen in a +normal `Dense` layer call. + +We implement the Forward-Forward algorithm which accepts 2 kinds of input tensors, each +representing the positive and negative samples respectively. We write a custom training +loop here with the use of `tf.GradientTape()`, within which we calculate a loss per +sample by taking the distance of the prediction from the threshold to understand the +error and taking its mean to get a `mean_loss` metric. +With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable +base `Dense` layer and apply them using the layer's local optimizer. + +Finally, we return the `call` result as the `Dense` results of the positive and negative +samples while also returning the last `mean_loss` metric and all the loss values over a +certain all-epoch run. +""" + + +class FFDense(tf.keras.layers.Layer): + def __init__( + self, + units, + optimizer, + num_epochs=2000, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super(FFDense, self).__init__() + self.dense = keras.layers.Dense( + units=units, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + **kwargs, + ) + self.relu = keras.layers.ReLU() + self.optimizer = optimizer + self.threshold = 2.0 + self.num_epochs = num_epochs + + def call(self, x): + x_norm = tf.norm(x, ord=2, axis=1, keepdims=True) + x_norm = x_norm + 1e-4 + x_dir = x / x_norm + res = self.dense(x_dir) + return self.relu(res) + + def forwardforward(self, x_pos, x_neg): + loss_list = [] + for i in trange(self.num_epochs): + with tf.GradientTape() as tape: + g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1) + g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1) + + loss = tf.math.log( + 1 + + tf.math.exp( + tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0) + ) + ) + mean_loss = tf.math.reduce_mean(loss) + loss_list.append(mean_loss.numpy()) + gradients = tape.gradient(mean_loss, self.trainable_weights) + self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights)) + return ( + tf.stop_gradient(self.call(x_pos)), + tf.stop_gradient(self.call(x_neg)), + mean_loss, + loss_list, + ) + + +""" +## Define the `FFNetwork` Custom Model + +With our custom layer defined, we also need to override the `train_step` method and +define a custom `tf.keras.models.Model` that works with our `FFDense` layer. + +For this algorithm, we must 'embed' the labels onto the original image. To do so, we +exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We +use that as a label space in order to visually one-hot-encode the labels within the image +itself. This action is performed by the `overlay_y_on_x` function. + +We break down the prediction function with a per-sample prediction function which is then +called over the entire test set by the overriden `predict()` function. The prediction is +performed here with the help of measuring the `excitation` of the neurons per layer for +each image. This is then summed over all layers to calculate a network-wide 'goodness +score'. The label with the highest 'goodness score' is then chosen as the sample +prediction. + +The `train_step` function is overriden to act as the main controlling loop for running +training on each layer as per the number of epochs per layer. +""" + + +class FFNetwork(keras.Model): + def __init__(self, dims, layer_optimizer=keras.optimizers.Adam(learning_rate=0.03)): + super().__init__() + self.layer_optimizer = layer_optimizer + self.mean_loss = keras.metrics.Mean() + self.flatten_layer = keras.layers.Flatten() + self.layer_list = [keras.Input(shape=(dims[0],))] + for d in range(len(dims) - 1): + self.layer_list += [FFDense(dims[d + 1], optimizer=self.layer_optimizer)] + + @tf.function() + def overlay_y_on_x(self, X, y): + x_res = X.numpy() + x_npy = X.numpy() + x_res[:, :10] *= 0.0 + if not isinstance(y, int): + y_npy = y.numpy() + x_res[range(x_npy.shape[0]), y.numpy()] = x_npy.max() + else: + x_res[range(x_npy.shape[0]), y] = x_npy.max() + return tf.convert_to_tensor(x_res) + + @tf.function() + def predict_one_sample(self, x): + goodness_per_label = [] + x = tf.expand_dims(x, axis=0) + for label in range(10): + h = self.overlay_y_on_x(x, label) + h = self.flatten_layer(h) + goodness = [] + for layer_idx in range(1, len(self.layer_list)): + layer = self.layer_list[layer_idx] + h = layer(h) + goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)] + goodness_per_label += [ + tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1) + ] + goodness_per_label = tf.concat(goodness_per_label, 1) + return tf.argmax(goodness_per_label, 1) + + def predict(self, data): + x = data + preds = list() + for idx in trange(x.shape[0]): + sample = x[idx] + result = self.predict_one_sample(sample) + preds.append(result) + return np.asarray(preds, dtype=int) + + def train_step(self, data): + x, y = data + x = self.flatten_layer(x) + perm_array = tf.range(start=0, limit=x.get_shape()[0], delta=1) + x_pos = self.overlay_y_on_x(x, y) + y_numpy = y.numpy() + random_y_tensor = y_numpy[tf.random.shuffle(perm_array)] + x_neg = self.overlay_y_on_x(x, tf.convert_to_tensor(random_y_tensor)) + h_pos, h_neg = x_pos, x_neg + for idx, layer in enumerate(self.layers): + if idx == 0: + print("Input layer : No training") + continue + print(f"Training layer {idx+1} now : ") + if isinstance(layer, FFDense): + h_pos, h_neg, loss, loss_list = layer.forwardforward(h_pos, h_neg) + plt.plot(range(len(loss_list)), loss_list) + plt.title(f"Loss over training on layer {idx+1}") + plt.show() + else: + x = layer(x) + return {"FinalLoss": loss} + + +""" +## Convert MNIST `NumPy` arrays to `tf.data.Dataset` + +We now perform some preliminary processing on the `NumPy` arrays and then convert them +into the `tf.data.Dataset` format which allows for optimized loading. +""" + +x_train = x_train.astype(float) / 255 +x_test = x_test.astype(float) / 255 + +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + +train_dataset = train_dataset.batch(60000) +test_dataset = test_dataset.batch(10000) + +""" +## Fit the network and visualize results + +Having performed all previous set-up, we are now going to run `model.fit()` and run 1 +model epoch, which will perform 2000 epochs on each layer. We get to see the plotted loss +curve as each layer is trained. +""" + +model = FFNetwork(dims=[784, 500, 500]) + +model.compile( + optimizer=keras.optimizers.Adam(learning_rate=0.03), loss="mse", run_eagerly=True +) + +history = model.fit(train_dataset, epochs=1) + +""" +## Perform inference and testing + +Having trained the model to a large extent, we now see how it performs on the test set. +We calculate the Accuracy Score to understand the results closely. +""" + +preds = model.predict(tf.convert_to_tensor(x_test)) + +preds = preds.reshape((preds.shape[0], preds.shape[1])) + +results = accuracy_score(preds, y_test) + +print(f"Accuracy score : {results*100}%") + +""" +## Conclusion: + +This example has hereby demonstrated how the Forward-Forward algorithm works using +TensorFlow and Keras modules. While the investigation results presented by Prof. Hinton +in their paper are currently still limited to smaller models and datasets like MNIST and +Fashion-MNIST, subsequent results on larger models like LLMs are expected in future +papers. + +Through the paper, Prof. Hinton has reported results of 1.36% test error with a +2000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning +that backpropagation takes only 20 epochs to achieve similar performance). Another run of +doubling the learning rate and training for 40 epochs yields a slightly worse error rate +of 1.46% + +The current example does not yield state-of-the-art results. But with proper tuning of +the Learning Rate, model architecture (no. of units in `Dense` layers, kernel +activations, initializations, regularization etc.), the results can be improved +drastically to match the claims of the paper. +""" From 634d63da627d63968947001608db6a623831ad96 Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Sun, 8 Jan 2023 20:55:41 +0530 Subject: [PATCH 2/4] added changes as per Francois's comments Signed-off-by: Suvaditya Mukherjee --- examples/vision/forwardforward.py | 226 ++++++++++++++++++++---------- 1 file changed, 151 insertions(+), 75 deletions(-) diff --git a/examples/vision/forwardforward.py b/examples/vision/forwardforward.py index 534e677f25..33cef9cc4f 100644 --- a/examples/vision/forwardforward.py +++ b/examples/vision/forwardforward.py @@ -1,10 +1,10 @@ """ Title: Using Forward-Forward Algorithm for Image Classification Author: [Suvaditya Mukherjee](https://twitter.com/halcyonrayes) -Date created: 2022/12/21 -Last modified: 2022/12/23 -Description: Training a Dense-layer based model using the Forward-Forward algorithm. - +Date created: 2023/01/08 +Last modified: 2023/01/08 +Description: Training a Dense-layer model using the Forward-Forward algorithm. +Accelerator: GPU """ """ @@ -12,22 +12,27 @@ The following example explores how to use the Forward-Forward algorithm to perform training instead of the traditionally-used method of backpropagation, as proposed by -[Prof. Geoffrey Hinton](https://www.cs.toronto.edu/~hinton/FFA13.pdf) -The concept was inspired by the understanding behind [Boltzmann -Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation involves -calculating loss via a cost function and propagating the error across the network. On the -other hand, the FF Algorithm suggests the analogy of neurons which get "excited" based on -looking at a certain recognized combination of an image and its correct corresponding -label. +Hinton in [The Forward-Forward Algorithm: Some Preliminary +Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf)(2022) + +The concept was inspired by the understanding behind +[Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation +involves calculating the difference between actual and predicted output via a cost +function to adjust network weights. On the other hand, the FF Algorithm suggests the +analogy of neurons which get "excited" based on looking at a certain recognized +combination of an image and its correct corresponding label. + This method takes certain inspiration from the biological learning process that occurs in the cortex. A significant advantage that this method brings is the fact that backpropagation through the network does not need to be performed anymore, and that weight updates are local to the layer itself. + As this is yet still an experimental method, it does not yield state-of-the-art results. But with proper tuning, it is supposed to come close to the same. Through this example, we will examine a process that allows us to implement the Forward-Forward algorithm within the layers themselves, instead of the traditional method of relying on the global loss functions and optimizers. + The process is as follows: - Perform necessary imports - Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) @@ -56,13 +61,11 @@ import tensorflow as tf from tensorflow import keras -from tqdm.notebook import trange, tqdm import numpy as np import matplotlib.pyplot as plt from sklearn.metrics import accuracy_score import random - -tf.config.run_functions_eagerly(True) +from tensorflow.compiler.tf2xla.python import xla """ ## Load dataset and visualize @@ -103,6 +106,7 @@ base `Dense` layer within. Since weight updates will happen within the layer itself, we add an `tf.keras.optimizers.Optimizer` object that is accepted from the user. Here, we use `Adam` as our optimizer with a rather higher learning rate of `0.03`. + Following the algorithm's specifics, we must set a `threshold` parameter that will be used to make the positive-negative decision in each prediction. This is set to a default of 2.0 @@ -118,6 +122,7 @@ loop here with the use of `tf.GradientTape()`, within which we calculate a loss per sample by taking the distance of the prediction from the threshold to understand the error and taking its mean to get a `mean_loss` metric. + With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable base `Dense` layer and apply them using the layer's local optimizer. @@ -127,12 +132,19 @@ """ -class FFDense(tf.keras.layers.Layer): +class FFDense(keras.layers.Layer): + """ + A custom ForwardForward-enabled Dense layer. It has an implementation of the + Forward-Forward network internally for use. + This layer must be used in conjunction with the `FFNetwork` model. + """ + def __init__( self, units, optimizer, - num_epochs=2000, + loss_metric, + num_epochs=50, use_bias=True, kernel_initializer="glorot_uniform", bias_initializer="zeros", @@ -140,7 +152,7 @@ def __init__( bias_regularizer=None, **kwargs, ): - super(FFDense, self).__init__() + super().__init__(**kwargs) self.dense = keras.layers.Dense( units=units, use_bias=use_bias, @@ -148,13 +160,16 @@ def __init__( bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, - **kwargs, ) self.relu = keras.layers.ReLU() self.optimizer = optimizer - self.threshold = 2.0 + self.loss_metric = loss_metric + self.threshold = 1.5 self.num_epochs = num_epochs + # We perform a normalization step before we run the input through the Dense + # layer. + def call(self, x): x_norm = tf.norm(x, ord=2, axis=1, keepdims=True) x_norm = x_norm + 1e-4 @@ -162,9 +177,19 @@ def call(self, x): res = self.dense(x_dir) return self.relu(res) - def forwardforward(self, x_pos, x_neg): - loss_list = [] - for i in trange(self.num_epochs): + # The Forward-Forward algorithm is below. We first perform the Dense-layer + # operation and then get a Mean Square value for all positive and negative + # samples respectively. + # The custom loss function finds the distance between the Mean-squared + # result and the threshold value we set (a hyperparameter) that will define + # whether the prediction is positive or negative in nature. Once the loss is + # calculated, we get a mean across the entire batch combined and perform a + # gradient calculation and optimization step. This does not technically + # qualify as backpropagation since there is no gradient being + # sent to any previous layer and is completely local in nature. + + def forward_forward(self, x_pos, x_neg): + for i in range(self.num_epochs): with tf.GradientTape() as tape: g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1) g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1) @@ -175,15 +200,14 @@ def forwardforward(self, x_pos, x_neg): tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0) ) ) - mean_loss = tf.math.reduce_mean(loss) - loss_list.append(mean_loss.numpy()) - gradients = tape.gradient(mean_loss, self.trainable_weights) + mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32) + self.loss_metric.update_state([mean_loss]) + gradients = tape.gradient(mean_loss, self.dense.trainable_weights) self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights)) return ( tf.stop_gradient(self.call(x_pos)), tf.stop_gradient(self.call(x_neg)), - mean_loss, - loss_list, + self.loss_metric.result(), ) @@ -211,34 +235,66 @@ def forwardforward(self, x_pos, x_neg): class FFNetwork(keras.Model): - def __init__(self, dims, layer_optimizer=keras.optimizers.Adam(learning_rate=0.03)): - super().__init__() + """ + A `keras.Model` that supports a `FFDense` network creation. This model + can work for any kind of classification task. It has an internal + implementation with some details specific to the MNIST dataset which can be + changed as per the use-case. + """ + + # Since each layer runs gradient-calculation and optimization locally, each + # layer has its own optimizer that we pass. As a standard choice, we pass + # the `Adam` optimizer with a default learning rate of 0.03 as that was + # found to be the best rate after experimentation. + # Loss is tracked using `loss_var` and `loss_count` variables. + + def __init__( + self, dims, layer_optimizer=keras.optimizers.Adam(learning_rate=0.03), **kwargs + ): + super().__init__(**kwargs) self.layer_optimizer = layer_optimizer - self.mean_loss = keras.metrics.Mean() - self.flatten_layer = keras.layers.Flatten() + self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32) + self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32) self.layer_list = [keras.Input(shape=(dims[0],))] for d in range(len(dims) - 1): - self.layer_list += [FFDense(dims[d + 1], optimizer=self.layer_optimizer)] - - @tf.function() - def overlay_y_on_x(self, X, y): - x_res = X.numpy() - x_npy = X.numpy() - x_res[:, :10] *= 0.0 - if not isinstance(y, int): - y_npy = y.numpy() - x_res[range(x_npy.shape[0]), y.numpy()] = x_npy.max() - else: - x_res[range(x_npy.shape[0]), y] = x_npy.max() - return tf.convert_to_tensor(x_res) - - @tf.function() + self.layer_list += [ + FFDense( + dims[d + 1], + optimizer=self.layer_optimizer, + loss_metric=keras.metrics.Mean(), + ) + ] + + # This function makes a dynamic change to the image wherein the labels are + # put on top of the original image (for this example, as MNIST has 10 + # unique labels, we take the top-left corner's first 10 pixels). This + # function returns the original data tensor with the first 10 pixels being + # a pixel-based one-hot representation of the labels. + + @tf.function(reduce_retracing=True) + def overlay_y_on_x(self, data): + X_sample, y_sample = data + max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True) + max_sample = tf.cast(max_sample, dtype=tf.float64) + X_zeros = tf.zeros([10], dtype=tf.float64) + X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample]) + X_sample = xla.dynamic_update_slice(X_sample, X_update, [0]) + return X_sample, y_sample + + # A custom `predict_one_sample` performs predictions by passing the images + # through the network, measures the results produced by each layer (i.e. + # how high/low the output values are with respect to the set threshold for + # each label) and then simply finding the label with the highest values. + # In such a case, the images are tested for their 'goodness' with all + # labels. + + @tf.function(reduce_retracing=True) def predict_one_sample(self, x): goodness_per_label = [] - x = tf.expand_dims(x, axis=0) + x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]]) for label in range(10): - h = self.overlay_y_on_x(x, label) - h = self.flatten_layer(h) + h, label = self.overlay_y_on_x(data=(x, label)) + h = tf.reshape(h, [-1, tf.shape(h)[0]]) goodness = [] for layer_idx in range(1, len(self.layer_list)): layer = self.layer_list[layer_idx] @@ -248,39 +304,49 @@ def predict_one_sample(self, x): tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1) ] goodness_per_label = tf.concat(goodness_per_label, 1) - return tf.argmax(goodness_per_label, 1) + return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64) def predict(self, data): x = data preds = list() - for idx in trange(x.shape[0]): - sample = x[idx] - result = self.predict_one_sample(sample) - preds.append(result) + preds = tf.map_fn(fn=self.predict_one_sample, elems=x) return np.asarray(preds, dtype=int) + # This custom `train_step` function overrides the internal `train_step` + # implementation. We take all the input image tensors, flatten them and + # subsequently produce positive and negative samples on the images. + # A positive sample is an image that has the right label encoded on it with + # the `overlay_y_on_x` function. A negative sample is an image that has an + # erroneous label present on it. + # With the samples ready, we pass them through each `FFLayer` and perform + # the Forward-Forward computation on it. The returned loss is the final + # loss value over all the layers. + + @tf.function(jit_compile=True) def train_step(self, data): x, y = data - x = self.flatten_layer(x) - perm_array = tf.range(start=0, limit=x.get_shape()[0], delta=1) - x_pos = self.overlay_y_on_x(x, y) - y_numpy = y.numpy() - random_y_tensor = y_numpy[tf.random.shuffle(perm_array)] - x_neg = self.overlay_y_on_x(x, tf.convert_to_tensor(random_y_tensor)) + + # Flatten op + x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]]) + + x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y)) + + random_y = tf.random.shuffle(y) + x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y)) + h_pos, h_neg = x_pos, x_neg + for idx, layer in enumerate(self.layers): - if idx == 0: - print("Input layer : No training") - continue - print(f"Training layer {idx+1} now : ") if isinstance(layer, FFDense): - h_pos, h_neg, loss, loss_list = layer.forwardforward(h_pos, h_neg) - plt.plot(range(len(loss_list)), loss_list) - plt.title(f"Loss over training on layer {idx+1}") - plt.show() + print(f"Training layer {idx+1} now : ") + h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg) + self.loss_var.assign_add(loss) + self.loss_count.assign_add(1.0) else: + print(f"Passing layer {idx+1} now : ") x = layer(x) - return {"FinalLoss": loss} + mean_res = tf.math.divide(self.loss_var, self.loss_count) + return {"FinalLoss": mean_res} """ @@ -292,6 +358,8 @@ def train_step(self, data): x_train = x_train.astype(float) / 255 x_test = x_test.astype(float) / 255 +y_train = y_train.astype(int) +y_test = y_test.astype(int) train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) @@ -310,16 +378,20 @@ def train_step(self, data): model = FFNetwork(dims=[784, 500, 500]) model.compile( - optimizer=keras.optimizers.Adam(learning_rate=0.03), loss="mse", run_eagerly=True + optimizer=keras.optimizers.Adam(learning_rate=0.03), + loss="mse", + jit_compile=True, + metrics=[keras.metrics.Mean()], ) -history = model.fit(train_dataset, epochs=1) +epochs = 250 +history = model.fit(train_dataset, epochs=epochs) """ ## Perform inference and testing -Having trained the model to a large extent, we now see how it performs on the test set. -We calculate the Accuracy Score to understand the results closely. +Having trained the model to a large extent, we now see how it performs on the +test set. We calculate the Accuracy Score to understand the results closely. """ preds = model.predict(tf.convert_to_tensor(x_test)) @@ -328,7 +400,11 @@ def train_step(self, data): results = accuracy_score(preds, y_test) -print(f"Accuracy score : {results*100}%") +print(f"Test Accuracy score : {results*100}%") + +plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"]) +plt.title("Loss over training") +plt.show() """ ## Conclusion: @@ -339,7 +415,7 @@ def train_step(self, data): Fashion-MNIST, subsequent results on larger models like LLMs are expected in future papers. -Through the paper, Prof. Hinton has reported results of 1.36% test error with a +Through the paper, Prof. Hinton has reported results of 1.36% test accuracy error with a 2000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning that backpropagation takes only 20 epochs to achieve similar performance). Another run of doubling the learning rate and training for 40 epochs yields a slightly worse error rate From 5b59cedb8ab1829cbc77246ec0db08fac6036d15 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Chollet?= Date: Sun, 8 Jan 2023 16:28:47 -0800 Subject: [PATCH 3/4] Copyedits --- examples/vision/forwardforward.py | 47 ++++++++++++++++--------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/examples/vision/forwardforward.py b/examples/vision/forwardforward.py index 33cef9cc4f..51b97fa745 100644 --- a/examples/vision/forwardforward.py +++ b/examples/vision/forwardforward.py @@ -12,8 +12,9 @@ The following example explores how to use the Forward-Forward algorithm to perform training instead of the traditionally-used method of backpropagation, as proposed by -Hinton in [The Forward-Forward Algorithm: Some Preliminary -Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf)(2022) +Hinton in +[The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf) +(2022). The concept was inspired by the understanding behind [Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation @@ -33,7 +34,8 @@ Forward-Forward algorithm within the layers themselves, instead of the traditional method of relying on the global loss functions and optimizers. -The process is as follows: +The tutorial is structured as follows: + - Perform necessary imports - Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) - Visualize Random samples from the MNIST dataset @@ -44,15 +46,14 @@ - Convert MNIST from `NumPy` arrays to `tf.data.Dataset` - Fit the network - Visualize results -- Perform inference testing +- Perform inference on test samples As this example requires the customization of certain core functions with -`tf.keras.layers.Layer` and `tf.keras.models.Model`, refer to the following resources for -a primer on how to do so -- [Customizing what happens in -`model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit) -- [Making new Layers and Models via -subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models) +`keras.layers.Layer` and `keras.models.Model`, refer to the following resources for +a primer on how to do so: + +- [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit) +- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models) """ """ @@ -68,14 +69,14 @@ from tensorflow.compiler.tf2xla.python import xla """ -## Load dataset and visualize +## Load the dataset and visualize the data We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset in the form of `NumPy` arrays. We then arrange it in the form of the train and test splits. Following loading the dataset, we select 4 random samples from within the training set -and visualize them using `matplotlib.pyplot` +and visualize them using `matplotlib.pyplot`. """ (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() @@ -100,20 +101,20 @@ plt.show() """ -## Define `FFDense` Custom Layer +## Define `FFDense` custom layer -In this custom layer, we have a base `tf.keras.layers.Dense` object which acts as the +In this custom layer, we have a base `keras.layers.Dense` object which acts as the base `Dense` layer within. Since weight updates will happen within the layer itself, we -add an `tf.keras.optimizers.Optimizer` object that is accepted from the user. Here, we +add an `keras.optimizers.Optimizer` object that is accepted from the user. Here, we use `Adam` as our optimizer with a rather higher learning rate of `0.03`. Following the algorithm's specifics, we must set a `threshold` parameter that will be used to make the positive-negative decision in each prediction. This is set to a default -of 2.0 +of 2.0. As the epochs are localized to the layer itself, we also set a `num_epochs` parameter -(default at 2000). +(defaults to 2000). -We override the `call` function in order to perform a normalization over the complete +We override the `call` method in order to perform a normalization over the complete input space followed by running it through the base `Dense` layer as would happen in a normal `Dense` layer call. @@ -215,7 +216,7 @@ def forward_forward(self, x_pos, x_neg): ## Define the `FFNetwork` Custom Model With our custom layer defined, we also need to override the `train_step` method and -define a custom `tf.keras.models.Model` that works with our `FFDense` layer. +define a custom `keras.models.Model` that works with our `FFDense` layer. For this algorithm, we must 'embed' the labels onto the original image. To do so, we exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We @@ -407,10 +408,10 @@ def train_step(self, data): plt.show() """ -## Conclusion: +## Conclusion This example has hereby demonstrated how the Forward-Forward algorithm works using -TensorFlow and Keras modules. While the investigation results presented by Prof. Hinton +the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton in their paper are currently still limited to smaller models and datasets like MNIST and Fashion-MNIST, subsequent results on larger models like LLMs are expected in future papers. @@ -422,7 +423,7 @@ def train_step(self, data): of 1.46% The current example does not yield state-of-the-art results. But with proper tuning of -the Learning Rate, model architecture (no. of units in `Dense` layers, kernel +the Learning Rate, model architecture (number of units in `Dense` layers, kernel activations, initializations, regularization etc.), the results can be improved -drastically to match the claims of the paper. +to match the claims of the paper. """ From 0dbbfff29edec38700ae9ae12c3fa9eb298bc3ae Mon Sep 17 00:00:00 2001 From: Suvaditya Mukherjee Date: Mon, 9 Jan 2023 08:29:31 +0530 Subject: [PATCH 4/4] added small final edits and generated files Signed-off-by: Suvaditya Mukherjee --- examples/vision/forwardforward.py | 6 +- .../forwardforward/forwardforward_15_1.png | Bin 0 -> 8867 bytes .../img/forwardforward/forwardforward_5_1.png | Bin 0 -> 12003 bytes examples/vision/ipynb/forwardforward.ipynb | 576 +++++++++++ examples/vision/md/forwardforward.md | 973 ++++++++++++++++++ 5 files changed, 1552 insertions(+), 3 deletions(-) create mode 100644 examples/vision/img/forwardforward/forwardforward_15_1.png create mode 100644 examples/vision/img/forwardforward/forwardforward_5_1.png create mode 100644 examples/vision/ipynb/forwardforward.ipynb create mode 100644 examples/vision/md/forwardforward.md diff --git a/examples/vision/forwardforward.py b/examples/vision/forwardforward.py index 51b97fa745..10041b36b0 100644 --- a/examples/vision/forwardforward.py +++ b/examples/vision/forwardforward.py @@ -112,7 +112,7 @@ used to make the positive-negative decision in each prediction. This is set to a default of 2.0. As the epochs are localized to the layer itself, we also set a `num_epochs` parameter -(defaults to 2000). +(defaults to 50). We override the `call` method in order to perform a normalization over the complete input space followed by running it through the base `Dense` layer as would happen in a @@ -371,8 +371,8 @@ def train_step(self, data): """ ## Fit the network and visualize results -Having performed all previous set-up, we are now going to run `model.fit()` and run 1 -model epoch, which will perform 2000 epochs on each layer. We get to see the plotted loss +Having performed all previous set-up, we are now going to run `model.fit()` and run 250 +model epochs, which will perform 50*250 epochs on each layer. We get to see the plotted loss curve as each layer is trained. """ diff --git a/examples/vision/img/forwardforward/forwardforward_15_1.png b/examples/vision/img/forwardforward/forwardforward_15_1.png new file mode 100644 index 0000000000000000000000000000000000000000..0e95f8e1b8734c10a4b195bfdfe1c09c5979a773 GIT binary patch literal 8867 zcmZvC2|Uwr_`h2SO^!;e7NHz5M{zf_}K0(JEvA{Ejm!73B0|OV|(U-ArEO-wnyo%JhkG$uJK%(q@ zoEY5gkzQ_|NVmt2&iOg{_&)aZfXXV!DoCGmK_b1b$;tiCds$B(XSpQs>qrKMa}4@$ zZBzgB#j(Kj(WmI8l^@Pa^x?qrZyx0w950TEy_m^?yuUM(bLXL9wy0cXlm@SA(uGCP zfP++V!JtF2ZpSArL*6PKQlP0MCmS>NjgIem4t?cs;iL4p2G7j{)ey30i{tP^Wr7-h zqCOMXrjD4U?{<;r(rh<1nzj!-W@%Wg-m0M?k`V&o19Kulv2X5Vy~4hD!DI~=6XT~s zAm_8Sgg_h|taW-}DZ?ibo4@_569=4r@md=tkcbny@J5a+=yaYE zGFCsuvel~5Xc{Tm<200FS%pV;l6hFGOYxR|S4v08!H*D&2tP78Ly1sfH||Vc$ihFx z5A#5YPr~T9l;LV}(J*-vKRnkdEweT12UU~HcYlc&Q6YS%YzLCv9C~OL2tVyxBduBC zom$;Ly_5+X%`LYPL+*$W7lfw@fw}o_48!3Qgt#_1t+PnR-%4anvpZ$DEvyYJPnaVw zT-F|GK=>^lACZ*@S~gAGL->{A_jivA<%)q=>&W@tjLD3J7&DSMY3}(m6a#xGr=94s zC8=}aEJc~}2qnid&aG&)d@#WPA>E}IP;QXG#6Z|Gnv#@KrNY>Rq+f&*PxznMV!6XM z%3dV0U(PYgWW~!*+`Y39QMW75++s$vq0n$_-E+)&q(>C~?zSS5ct($p&d{uBUp(9n626iObSGfTSC%8-5y7G{_ydOJg63aDyk#Ta+7t1zpl2c%-bRzScWPk51Cd9v z&G@|AjU_HI|G^MDr8~&u$Lmp5FpI9^oXE)Jw#e`w)I|@9CpNqTVaK(c1_7;!BUkvW#nAY%3 zBz5jW!OsaICL-*7TG~aIHKlVZ)Q(cjq59YEln=ZNo@@^y4+3EZNMX#2bRKCw4pnF% zzyIKv^S)4+Skh3B&h=KZtTKTRhF&dY^k4~2q1HKW&buhbF5``nuswyG-JL@V)S>$$ z!?@b4^FjdH^t^e6)zcGmYtO5%A9ZzgJ-M+0f#}cAB-?s;)1y{q#ie_I(hsGj#sh7y zuwvrS5tUwb?po~s7ovckBztt=*RNw_A$)xLLe!1uwynO9#wa7pa}b~McU}p%H zP@?J(^bPec9Nr}fR8kV}z3jT^sg~$UkHWsN?PJ}^DR;?g@+(3`B|#t`#OImzznrN9 z0C8Fjs9kG0pCIIT85pKjuTG2ozG-UV(L&_o6noZ=@X|rdDeFhVtanW=uK9N?JQ)J< zY1R9s1s3(@siuq9@q=QaH}A6T#H|J~1QVNNniH|ubk%$TD#9-szj%ef1||BYZy(wr zgrAQXL0xenA4(q*-oJ24WE~}zU}8aUnY7?6`5@|%adTO56|s286B0kUkM1`Scb}ka zxr3Ou4&3k7Nwx z``gFB?4n@kB9}wPgh*3)e!)jc_Rtm)G?`2E753f7cj|YMre@wJsEc(0u~*7X zrk&M|oA#<87#H*k#3m%KV;axJwY=exgAt=)mt#*-?-ti)_w7U0P)>yrdTsm2D*l904 z2P!0P@;ZpRcA@rK4G_E7NHR5Ciff^i_4nmIe+3tN`to8AV|4B-J%ghbk};w(0q?B= zYxr=PAERl@E);tLDwFlUXR)61b@ryhIyz#BFE8~x1zx=?UuFsT1WZ46 zKrln=N3+h+m?x=RkY7U9@VgIPm>TL7&rsQa1)Q{s#OghIaj{285XV-h{5neS3FqxB zp;^bf>B%3IkRW}_Yz}1TpE#@H>{oi)6`-U|@G)n^rK$L)r&(XbfaDRp@Brf2(09C2^%#JG%rs!*DbFS|# z9KXax7#``1KeFXntM^PISyk$p=;EUa8vUF|D6y67mty%GKc=dkmPU37e5%&Dan!3q{u4XIJgI@S1bBLGMAgt(_16kZo>A8N!Lbo?kxvz8hJpJtO_?!v z>L1~ES0tb&PDhiRFUq!MQmok(MS&Sa89mPS&rXpVXya@Y+U*q?-<_m{8RoK5GuL@Y z+=F~Tc9?znZeEuDM{l{)-4Wes!=;@IsVn2{WDg}m0JPB4!z07;>%I3fIJSV?)?OX4 zuU}kqUD(Rb`m1?yEr;R-aPHGIIx+Aq@rRkI>6iIjV{_r6Y?UX*LLhzlevb9$TDQq- zf9xImNj}ko%(UN9k{t`Q7zVI*S;Bz3>)o z$|{5t1aT01M5H*&igFl2t#y?a0DQfWj|0h_jebI z1q7m?a$T4@N0ws&3QGRFKt~mCnmql48_?yOF3ci1EwV@YBJi7Mef#Bt`~2@Ml2?u) z2!wx{&9lBHc=MF^&-_ZD0weIq^)dIy&!0hA8b`2AZoI?3v!}5h5L}xJ%3??n7q1}&j~ca69s ziB4XcVbZvu5nX6bp8Bw;dwO6VsUrrzE8D|3zLEb7xyijUDJr6L3GzEO4io%!l@oAd z{&!md_YaXeedMdHp82uDCj1g&LH}x9Db`XUxRTu>@JcDeZ>PwISiDl&a4Ti=(^m2s z+ojEo~xl)rExM_`TW9*^2p~BA9MZ-ehI;K(PT~oTkyz*DxH2|2leOuzc-D~v5 zE~09f+~|r>nreu}k3B(bAj~I5)3=xd$3sk%+5*__SLaF{fns`h!~|4;%c!lEkt-;_lD0Ge);rpc0{3f};1gKusS=+2w-j12Z5*t= z?@YCvx^!58sq`oW~>4p1R91h5x0I5C)CZ)pA(dib6cVbjwc6AU=u{0z13x_VQ zT=w>B2)V7)a-j>SESg^T3EO_!?YkC=ejdn$^mDZLQ;42Qgh28_c-C+M*{u-o&ZNm2 zRQv)+|HSXI`oYMiA zDTpgSOZ?;Y#LwC4$dF! z|0w+zjNQxE!g*^ymG^`dgzf%V$|r_M{;SK!x^t@97Om{w@LYYpu=QUYXTVQ&C@PVf z6`3<#CJ_P^lVlahk$0k|EcKXjW!RWtyjccon%75!}A#p7H)cBgkw;hrl-IJxk z^|zfL@y9|ln9zJLRq=Z`SM%l7MrF~lvXDPQSghC^CN$@g?Y+(;^PDuU_I-ClSWjn~ zoTrs}{XyyqstLvT{KAExgNRZWst&yCKIi)5_cGmcF%zm01y$ggwlrRey0f5IZ>gSw z`P^hV`=__%GY^{uc`CuiNL2{LGfOKoe_e6vaI`-4H)w8CS4BZRCtC=lN|5X^Zk{X; zqSdab+x!ZA9kMe+gC}2r_yBT~+G$YYsl6_J$FOZf$=uxZkkiD zSTd|ONI%DrFP1+~3w9gPdKlDgz=JwFR~k|ty;p?vXmm8m=>XooJD+?ND~d#(x5%(= zwub%u?67eSnu{fCum+@0JLdaNu**%ZI?V*l2BtmR`uc;)PQ!Ufp*d1J4@CzZ<-FbjShSk&EL{>;Il83qq@S#>xzx5?1hqY z* zduxQiichBSNh6G=Zjw3wa*J_eVqtgjv}bVLTUemu1#9&4NLfSpbvQ;dg0mdDHOnqL zDL0Phfei;?|27^-19=pmb<$DJg^tO~JDnNSrqH-n;K@eG-?gC>!bX}Y%X(p#PsTc` z{2%xlU$KjZ`a%m|*+%|G_D8sidL-LO?$Cek&#FSDy1^(*%2{{dRRm!wQa zxmIF(*(LI~=|Yv_*);I|mfo&lk9zvE16riT)8^wKeC+=ng%yn1GTSokOy0LOsM*}R zlyR`0SP!7r;d;lI)mr6-Ist%Fr%1gxyit}M%Ohj3|G5;uFnlrNlhOSG?;@Ez_^QHo z->w|xH_kc=zY-fVMrsrNrR88~d?a@zRJ|-1*ZM=p@w)OL=oKP~ae<`- za%NgM6y`%OK3NgMw&jZ|&0IkL{qv{3F}Uv{pZ|;C3Bc=3S*gGXZP-LnPH@eD&8;ux!>+Gkvwqd(4SzTHQEtO0&w(=IN zwkp>vm=n-vj4#f<+!LQbTq_(a4%(l$uCJk=nqjrvTX~r5)gk^$va{O&`3$09$WJ{_ z?;Cg6o2p|QT5`L1`9id?g$$*|^TSO2OCXKD2=UI>G6oT2{8oCnqJs`q*ylw~<$HK@ zfP)QXex$2pl9BFO0vyb&0a5c;HQP~t=xfD7n=8LQ*<9izZ0{=#j+$C@_CW*){ghXhU+r2}=SBkHHu;Os<746Py0yag?jeC=I&slV)-Hb*rgob&+8u;9 z2?co(0AQ3fWkm^_%S@N&&(F&^19p`+P4U z)p&q4H5l8!ppFAELKK&B_J$Bz>;#{{%#q#>*#i2!){5)c& z;&U#c#N{r5#Cl&H@yoM9Ph4&~#z^*v%L1}FB1*m%-rk!WyjaEZL#sGydiiKzAbeAi zlerL=jHejrd^Q9)znl*Q+|mB87wSIqu;2nh(s!rRo34C*n3T9SMEL%pQ0BbEJqhlU zkwU;6>M+?LWt!*QcDP`Mr;Gj6fnWH|^rF{_e*}}4XY-t_fVgx}=_{vyY=4$kvg(g9 z*q88zHhlk5}toceBP@&26%1_y!qIX&y z+%$WdNz2C=fYhX>rs_rA*SsQ2tLyHT8Or5Rx3n;_&gqTK_El#WD0vNS?>*IT?l&se z?SwO6rxE2J^HT-~LFR+}1mNFl*aC}%v=1Uhy636iMSW~K87W&LGc)Bv}8&d{fMX7pf!!2C+PTz2)`n3KLIrXRm!=o(X;@p40 zbF_H7U2g_}KyeaaEwZ}v<@-m7xvb>m=toF&n<4cvw{_fI8*{c~S(&N6JST~KCDaDt zafn(PPX%G)j)dHT#|(RuhX133y`1#;>%mTDpjQTK$j|cn9Gpn~O9C&4tr~npElSf- zj~=+9U?aS!Py1dHBRr_`#XAyBqr+scD>m_&iCs~^pkz76XHJYjHsv^6&vc2iymH0t zx(-HGzqIIFT1T?8@Am5I$$mgzho4sL!rMJH2AXjyB012X%bdN@r5%^K?hB44b-!W@ zHWSm;mebeA$Yv+RM`!-RsDo`9Mq0{#O9IafPt(<}1Us_fnm?`C_=#J}XEq{!DAEWr3Cu{tH#x zsPIYD#Ff%q)SHH4jRn_5drw2-6MGHnjWbsk8PT2 zY{e9A0C}7FoB$!9z6_?>j;XTQ1gnXT(#x9`uSUygs8<)hoO<3OQ7PBaol~cX12_>u z!8wH{$%v17JZ4{OO?7-E<=QV=6K%fp1B4mjTwI#In9AogLHPKU?y<*}xmcX1m& zSu$t#9`gDv4^-u@^G2tyr>-kH*w`xrJ&5!!QO0<1LSu9M4OnMZ z`IR|1sar0pWDH=pfE~H*WU{Uadl<55pC^9VCtMzCutosj`wWU3oRdA_9kdK@$8qnA z=+^lobfFE*EuQ}FHs%3}6fWQQlHY^(z3rE4BzNk?jD>yT0caKDz@Ps5%#wf3$=k!M zx5j<+ZCB&@8032Iz@VFOwdVtlDz08SdLsM=l%wN>m2Yulm^bij%^R0 zo_2P=^DqPUn3(+=1T+HJV=chCiktU8WiUPnV7oe`q2+1}8^Iy_erag6tu%8?6kd^d z4ooF#O(C~Q*#|21twRhLMBfIXxb=LOS})D%tANa@)CbsV$6r7`7Ir?x&2F_aIJEg_ z0Jbm+E*$-^n_;ny{qd5>I{@B*9s7mjEibK%sLH^^cU42pQ+4R)&IoS;MPp=Z?95&pyHUXSkQym_Q?qV|NLAN_TTD*v2fMZZ;RraEHr4($F_(xvrz{SS%q2Y}ueuCHb7`*0x-b8HbQ~I?8m%IM) zf2y$>H7s2 zWelbxqxe9v_f+0$O)yBS5X2EltfR+^qTTHf9S#{bybDpbkpQP8axW&Excde_29W3P zn0f1V&7nm0_b7TL$sOURbmdqDBRHcoM+ZI~O>R@GU9k!(vxJs;hReyZKbX94q@edz z^Tw0ShiD^i=;GQV6c}~E@iWstGR3li_(QxdrZyo89X_y7y_0?hk=tG82plK>#~BR7 zl6n6CF7$@s!%t_u&fvoC59euUqwn%a5a?O1qxFrWGn1pEf%CCQjOJkNaLZ@%8HX-; zTkx&RN`%a`;Sx%tSntye9%x_>IhDMSnDTgPvp1(vRGQw2sJiH#*geXEZX6-=uyq9F z;;n_COC?#|YUG8n$PYW~DVEXrVkp5KQ8h?DkWLZ(<)opX>5Q@@2KJ3&ML*OD-#1t< zib)Gij4W|Y+WB=X0iYzDC|#75x1?k~>$%69E1Y)R%im4WEG5ht499DZ`^)!kw0-JE zF^6axjU+Z7uDIFhFNeT1deJG`nT4TB%~RLyoO`N!WdG{#s26rz8y@N@P*(|TA3LVb zuO5sQgK?x1OW9+h@#-dEPnO0RlP0#P&Y~wCm!dN{cGP3Hd>@+tHxM?C?g9v7bVnN| z#^_Ak@I$t^C?eR8KEZ+`hAzU5J4F61M1n8b>8%?^bG&v~VprDy!Lvsu+wgQ<-~~Jd zFFOv$u*P@qS5fMFoW1R4HERnOZ|m{AmnGc@b1ti6^?Fg#f928PmR3XG1gRgUeNeVH z*84&-6$zIgIR~tx9=xZk-hQdc*2t00TExb~9FqCd2N8i6V3=U`pvRZ`2m&)O%5;8e zDe*8X^H1$?d=e3SoR$^}KC!?36uDmD>@|^4!|B`!M;~7-N6~X>nhUfGBi@et5mt`peIcO0LrIiQ_SUmr;rIHEw z2NRfcGqmP76}w%VG0^&&w+Tlr4|mD*R3^AVA2La^+ob;80j?vQUX~;Ip>PeMK`zM~ zaCF9o(_3nBGxTRbkVvmgD-qKzUYB9{XX=k8$#RfX+6$y)#wZk zPNQ~rY3DeEhbWkUBWWvMo@Jxe3{N9VU@_=}>MY~5=dgv3DVv#GJN3*p%Do%iW!=dq zMT~}zyB+uV@$#9aao?P!_0Y9-Tj1(MoO`IZPj8#pen->wkVr22Fh+&f|iO!5p<&L zOIx=(BoKjstP!bTgn)nvJE+7UOGF3|Ls;G?*w)+L+gtzVz3+KBCx?^FoSA2y`K_Pt zFX<}>TZMJo)?qLh1^oWKhcTFyc<^!fKo0zJ?JPAJ{97IN1>Wfc@E7}m?|0z)+OPM! zg<&x3H_1LL+WRT<;78N&eXik0g8jpzj)wYSNJqoJJ{cT-^2D)^BK<Fw6r zt^3jO@bItA_4VH!*9#5}&`;AyPQqY5!r=FQ?sO_g*dKGrFM_N;Ke6ighGT5pho4`5 z3ID_X!=2Au&!4aOBE2a<{#%St;Rgwfn9+K{{ZZb+S{Lgj=YWP6pXTH zv*MMsE0PDZn~=!EPy`Guw9}+pU7*}UkPQu$FP}R#N|Ua19^97W9plVRL@NWPLmFr? z$1JNjl`Fb7#&;897HHX+E{6c^ZYy>jp7Jq!7#%I+_2jpd>c~|fKVNg=dOC8wN zanvG zuEc7wjgawlz6mm)ikcslE`571B`{^U(+}@ZfNDE&m#Akil}UFmen;EABgytQ+Fvjl z4{=FQL_AV}jGJxyIQ4E^L2tPx`5+Qk@_Jm@`DHq4lA7nHmU%Ydq)t1bw~HXPp)=39 z!`=>4cd$inBOzzh8mn81uuHi0F3j4i{O8zwhtS5=DNR?%L&m*x9PBx)#F5UdDi~#8 zy_>G-Z?+*9^vY{Wjv_H72C*%{^o54crn}2V))MuGU0`7?c1d!%s9-b!Uzj;f&aiYr zB7;(Iqm87QJMM_K?=zKFLEH9igM*UVP8Y6GL7*7RqVnU*dFl>R0dA)TA2)#wYBNk7 ztquvTg1Cm)O#GZ7waaG%{FRHQ@7~K|rU^Pb(7^n*5bw+UE4tO_yB0;M>P8D zB&1WTLX*7;-(olQjMtO+!_Z0#8uRb0ymWq|ZgUZ^hgLc>c%xc+cLA@576W|2qVNwc z0KDQ3I);U=MkUp6M44rKiroDubt&*2$%_NPZ5q#iaTrqbgqFU7%AL3;sDbNST(KvuX7#;d2LUdw>nNS6i788$ ztg+D330j3YDZgVa=KDSSS7KiN4Z&c{zr|usZ~XuLQ)G6;s)M}kgKru%#Vm2ngqmf8 zy-EMyT}3}C*Boc~=f};e(99%!sywV$HC|koRW(r?!;r(2ozwX>+%=x}3MxR|sg{@c z&#;P2-p2gBM=%9-8&;+5P7-5lN|})pt5a4PtoPsg_{9rj%;}V?iPzKkQ|R0QD7Rhb zC-ew9T8UN$ogH48QK}h`0d}kWYHeSnFDz_9=aR54PFxP1Sx}~ToKCkZ*2AMH{dFvc zTFRH%!Eo7+I(RzM!Zv>bA9@tosnz!3y+H8tZaBeBF`-USs$(%Qkuz;V9c11CCbeB% zvQ%=`#@7(J!2JGZ2lCxN=Yrp* z2b~2b_Ta;+ftkX(EG0+Fkj>_qHYWYy9W-}LqQ02_6l@D20LIs0*|rKq_s%QLD!g5jU;nbuOX zG#k8Rd@8x1s5YA6pR-22)qIRrmlbldr}P9vARSy<+}rYCipdM)SsmqH~j*-Ft z>?ag#RAlBsFE#U)bk_Wfq?~}m18*8kHu5|oyZsVyQG;>f)s-v!M>^vdI@9}Dl!duE zpX8FS_@wfFV^>jSXH5`8&;u8#(Yk@5$-@qolRwvG?TDGE5feG;G2?fhi^flS;6W7F zx&aSG1VQ~rAd?^_X#%bM5v|3|Rdu69*7+TdCkQi!EHhE5yzI@nEB@OnC^u|d#1&1N~qa zx3Fe@nZqkG67TpgyV&NXaCevvLGH(vDzRgT?``wbZmxlencj>r_g9YPBotkR_)bgm z*v!=i5&l~Zj#s!Kdic=8h>>l+ex9qX#rG2waa97&X8xI?JRI2{(YaY1IyZ7lk!LBc zXjP+;xdmR#=Uxb3Q=iNxa%MWcZCph)6LG=p1(U&8(Nn_#>4116&+^UFgajOIt$){N zF|#H4R<0uL_$h6>R~i8d*E!iHioEJ^KmF~riWh0_a&W_BT6wh}LqH0S={?_Ca-!pA zUDgj~?9^=mYfn3BOO_DT!%&nHm+iJym*oKs9+9M}bYO=Lc-NQ`|`uc1yO4@w8h*X7deFL3nV&vv^= zJcWDJr@+4BTt!1Ri$e)G7cVk!4FM;+@x)r*w=S$XZUS!4`h8k#*5VjTuVP6U94~!p z*x>fbR7h$}ACWV$_-bM!FS(;Na93BXATyxw5joQM-(>bb*1I6a9V%cxh7c$2Xzt02 zl0VtAE7`%dWQIUYnGnmtXds9?@z)wv!jUQoFBjV)X+;@*{Q%w7Dc_D>ts>DlRelIKO98~V94DhrCQBy zrF8@U*A6ZGVU2cC)vL-nAJx!@I}&gpn@WisdF$6_>atQ}`8yME`EyTh$-#*pF+!$} zwOrHCaWEW8}L+CfqxlA8^P6%7?gsWrPTR~S&v+qgZ> z!@Ro;qvan54ynlPISdO?teF0VK1L}(78($cl&_(+hLu-pI*;vxA`U>FE>Qi%-Bqqu zztl&ed)hU$$poYx>AT7oE9;!3(_y;W0Kp}+TpytE^{ZzZqp<_&psA0pZ2E)4CZqY^#TM5;tUq>*9Dlu{LB&MsMkf z4v?kbCf1UarN~o_pPx*?O~edJ9P@kIz~E{c;w0cT4ke zx&kkDX!~HsUW`>|)D~?0OK%y+5QbzFU{oXtOw}WY$?D9Qj& zmp>%nT2D0?X@wRVwOWRK+<0_qIf4CU8C!o0vB=DjqB?OAb?pF0sd47RCcq}@(mjV5U@sWo(V zq**UVLtxgBRRwD84Q($N&4waIk4p!jh<#8TJWVbdDPM13En4rI1oEfV6~2i(+Vhwy$YnqMTASwU`;g54A_QvQIEw<&d@E7hKFy4@s5m6rJ=_Sn>N5V?R4yxh++?T*A~ zHO{YezJI>=H{S8P8uvFUa>tYPbc>eofj(uD$O+t48aer({(>F|my!frr$-nEyk!5E ze5#{5&Fu{{-cH({UJHVL#{axuy89_)OMO~|y=>c|Wabk(W5Ns|w^?=ApP_(}L+$++F&diEX?C_-sP4gDU0~C24@e9B1xDJ zU3{4;2kX9+vp5IxJz2@DrC!lSb|)I?uM=$ql98n!H)5xBo806+usbOdgJ5kII&7>i z4}tx}DVc42TW+V+U2=jP3(-~=ZUZ{!59ONoP`-rv-Js4?YcI2f$Q2C%lodLLElniQ z%5!)6!-o(l~R2iPWx|?r~mgxL+WumBx%e)8$Pbs66GX17^@6V>s0`=<(d zJq!3nQ+5ZUnszoIkToc+<@HqZdY;jl_1cg1??oE1PQTcy15WoTHXtrmGiST8C+MU= zx>X;c7vU@6Q!>1=F-<;{6-~Gn$8rdutBTvq**1Ea1_SC2!3X8U#$cM>KAB^AjR89e zAM_JCX4~uF82SaImUdmU0++OT=``7gGPx7<4movM%GFQTB;d9ip9JBu`3`V7;RQTe zGS<#6Uu9q>bF%+fB+452cYW49G>|^S+u3t8!<#WgI7z8M{=HlLzIbX3Du}I(q4t$= zJEbXQ^vy7R&{sG^8yv}UW;tfwHyq}aiF78#Hl(>+y#+nja0-3&e6LOD6gozL&$ZId z2upc$_B^lFx!gh~Z=%{A7T==WvSW4NE>LPff%_IsHk)JV*^`p;PQsXpTa49MR3(qV z5(Uxw^JQ0c=tRX99by1*nhV^#L1U6(C-dRP=HDTjl2IQI8N%%U3IdoaqU2S8`@d~o zbYuE@K3>Q>1VyF18X7AYy;?FJ0z)%=bS@Sm(a>OL?(J*BFtUbz+eF_$Bb+3`k^Q^E zGmkcn`)rv2-MuGO27eqvZ4sE7^O(?kca)QS8x67#Jq45Wlm}2$S(|LYCexYuDkSNC zL`r}IQsg9Fbmm}|Uw*3V^+6-VvJ2R7TFs0TcLjdqae5pR&AY_c!Kc_yH8ekqtKQ3M z3~-njV&0$6+GupLvYGKBlY900PNc7$Aa#I+^7zG%*@i}=e&-(lzQ1yNBX6v{Wh5%- z59uN-8a#}?P5{OJjEb?Rf7ev=XIQli{w%$=cBjbZo;%W1FscJZJfdZ*bO7dB^=0{l zL-CR>_-28QZog*N9Tw6Z1vc|*WLT@h9DZs&aC6XCp|3HlWu9Oaw)AE+L)w_GVr(7U zr#^TYJh$D{d2FzTXP~E1&;2n}R69`rPt{sLZ%ZnQS|J{n_?~p8w0yuMRUC3_u!5DE zc^+rQLL{-V5zPxjO&fZxOV+((UL@H=sm9f6v(LKSc<+5y0EB6zL9!O@3LwZ_eto_32-^Bkt5{zk<$3 z0VtYsaOy191lzs4CpLPE*UtPZ>8T`&;3+Je+uEtUGjnM`lmmK{k< zzP>vgb8==f*UeRf{nQ*by^Nj7=D^mNi&pr#Iz(Nz}A6m=o9WerBr56XET(Vtv!xdZWmi?g)h(_yuqXL#=sN1l##j)6$-H%sKKPw+gfXlVbwAdr~Ma}iC z&5JP&Z=#$beH4{X|Ab6J+z~)!tFanEbkbIHJx2X*Rz_tYol_}~IbE^bTK>Td3o0EB zL*#VR@$U(+#dM{Z4Lad$xxJqJ?OFnU09`Dv5aN@8iV)J9A4 z0vm;W2kQfl$zZ(;QWhlxN<&!T8qb&nmGgL-Zs~VD6llA^g8H08HAh-3=gL3*Q&zu{^?$+Y3vcUc zt;H~|ZF(T?+}^n*c|pSPxS9_p!DIfBf_`mzcQ1&Mgu^ zjz$-*l6plh-wzaXPCl3pnV#}H@XqV;QyN)EoNH#NJwB2*b)_y#G4|=I1ROv3YGsz zrwCklZpC4WHU_g1Cz}Q6WLdk0jnT_D@@U>NC!D6(9j>e#)4fHahC*p~sEo9D4 zWAnGiDHx>k8iAH|R7MF34}=aHXQ$$ekA}zWLAYeA)kF)R2ZjpEnkRZ?RN&>Tj0(^g zs0<72Ze3-d;KNz=k0u5D(hp2tTJ#B$-&R6b0OV>k$=pn3Mm?+joq->-kGp zF+KwU_wmyv5;QY{wbkq{M_=uqTEqruNe)7}Is7tgqhkr)xl<(|$Ivyb*|2c~Pc^dZ zsM1R<8J3oVjZbzl_=_Pt&C12Ou|`r7?&l4?A;~5FJ*MzSq`tc$jD~Idu|D7a3nZuP zTKgw64!;&Tv^P=Zcgm7Ww)V_SW(Go* zcDo{+{k~i(RJ#NSP7Zw2PVe}Iy@8kpX2?)OT^6rrZ(w1~Y)@U5323epaLQh%o)Hz^ zS7{2&*~iq1EJnE&DowM6Q8j=V1D>o-QwcC+o79ZhgD$$H^H+x7^YwI_ts)QX-cg1> zaJ=m_mEomU8AbtskK7s~2WJ#M5Tw8N1D&1i3u{e!(ic5OS9<9`qYC7pU)Ikty?`$S zhMbd@VR!+*TQL&IOm$g%6s^L4(HG44m-fiPZ@u58`uS$NS1iB{BbncQFCAs zx6$q8J(2G^eYXD!tp~t(;hgkdp7sb`M2UojjJC!bRML9lvESSEXVbsvB0A)u*5>^C57q59fRUAbim6r(J-P=(eV%vz9x2Sj;d-HN50uoe|d zIh~U62cYkb{PH(w_51GpU3s*j@Ghs_IC#L)HSEbDTYJ`Js0^4U5KF=alaCnP4qJg` z|8P#-+%>Fjjq(N_FeqzT!0xJ+8SFOy`o9d&42Sl0!1#i<(Ihg3$Te)X8xW~#Y5#BmW9V!HbV!^{hwrV&KIG-oSys7}$HBn)EnhZ%NYv+cy6FEoEsev`UxatCU z*7Oa4V~>)S74nRmxe*ll%_9zGe(VKV&GFQb@@kk8^vmSH?QM(0%Uh6DUOPiizyaM@ zS^j_NG7hS-uj>+UxALM#iJUu)>Q&E0GWSWqX-K;oK?W}Z$|6{I8cksf)x)X?s0M&= zE~E93EeBPHm%a;0348g&Ik0e04G^w+HJG&s%LhU3SoZ9<*`Ya4W(H;TGNDD=g*#jj z=mW)-pbl!J!73EgGrStn2FEJ=L`Of^`?UiMh#sFOm$zBmQ1TW7`XNR#0)N7w``4glz9~>d`NW9LaD#Yh7;g zgzWx^G`LnGsPplN;6lLrRL|u$y*IrDHUGhHA7IZg9~}dAfA8o2?I|6Z*8c|x_D@1y zurhI(>2nG0mEVt^#>$uSY6+0z{;8|jMlCaIb^@JQ(qQezZP#AKqwxrw%5$Vl?|Y`8 zHLJQfc`-_}0to$Q(8^EQXQ(7a;*{()mW^3~>4IgZ6kR%^2F+ZQrHEMqkSuPaV_@cH zbHqBvO29a4m)L6sVO=xraemYl^; zi|Ob1VaeFOm82quCtGzyj3Swgu5 z(`PIu0P4OyIMO~5&iT<>zx>zf(2|Y#G5^>_A@koeTjuCO%d{1ks}cqBWFm+Fkwt04 z5&-kxHQ%&3FhJT*fJ~gYGhnWbChc~Ha@`@Z6ZfJ!srOy9YEmyR+s~5O*fr^!IW+nO z0JDTYYGxh-jds%(<>E=gF^YZOGQHd_&3F)lpkO9*DTAL*?YqqH!D<6sQW$OIgMu+H z)d50w)7f%#^Z{P|V5k=?4EVCw8sP(!pLy-`S}^Wd?%zGUbIXW08LL>vYsvBpwt^!5 zfvh6H-gHzOD>;DpP^%5YPRRi9WM+r-6y1PzvwlOip3zCNe}v}9ci3`Z{2c`8=Ll83}S-baXu_5crB0>CM+ zRS{QEH1Tj=@%67WXS>Jz8%pD0ArLjHsN*(G{teS9=v>3g!?PY9v;APSGKXJ*FKpR% zH?@4k#BG&$GRKJJ+}dYTYc5!v^h>7CyR73XDF_3B8XnbvejDGMhh0eYDp}&v)5s&WerL}QT1#-6R($fNIwA`oxEm}asyS* zm$lj+b5(--za;{ralgkq6`x!l+HETif!?Vsk(?KC)vhd2tVYoeLFZ^@Q)A~bfdt%R zS4pZH{jSyAotiLYJd5AQ>yh_wf4CJLz3p~t_A0hYQ}U#G#vKAPv`$DI_2dE2wkT^# z4k8IiUmE`nug9ah?d!{w)JOL\n", + "**Date created:** 2023/01/08
\n", + "**Last modified:** 2023/01/08
\n", + "**Description:** Training a Dense-layer model using the Forward-Forward algorithm." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Introduction\n", + "\n", + "The following example explores how to use the Forward-Forward algorithm to perform\n", + "training instead of the traditionally-used method of backpropagation, as proposed by\n", + "Hinton in\n", + "[The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf)\n", + "(2022).\n", + "\n", + "The concept was inspired by the understanding behind\n", + "[Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation\n", + "involves calculating the difference between actual and predicted output via a cost\n", + "function to adjust network weights. On the other hand, the FF Algorithm suggests the\n", + "analogy of neurons which get \"excited\" based on looking at a certain recognized\n", + "combination of an image and its correct corresponding label.\n", + "\n", + "This method takes certain inspiration from the biological learning process that occurs in\n", + "the cortex. A significant advantage that this method brings is the fact that\n", + "backpropagation through the network does not need to be performed anymore, and that\n", + "weight updates are local to the layer itself.\n", + "\n", + "As this is yet still an experimental method, it does not yield state-of-the-art results.\n", + "But with proper tuning, it is supposed to come close to the same.\n", + "Through this example, we will examine a process that allows us to implement the\n", + "Forward-Forward algorithm within the layers themselves, instead of the traditional method\n", + "of relying on the global loss functions and optimizers.\n", + "\n", + "The tutorial is structured as follows:\n", + "\n", + "- Perform necessary imports\n", + "- Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/)\n", + "- Visualize Random samples from the MNIST dataset\n", + "- Define a `FFDense` Layer to override `call` and implement a custom `forwardforward`\n", + "method which performs weight updates.\n", + "- Define a `FFNetwork` Layer to override `train_step`, `predict` and implement 2 custom\n", + "functions for per-sample prediction and overlaying labels\n", + "- Convert MNIST from `NumPy` arrays to `tf.data.Dataset`\n", + "- Fit the network\n", + "- Visualize results\n", + "- Perform inference on test samples\n", + "\n", + "As this example requires the customization of certain core functions with\n", + "`keras.layers.Layer` and `keras.models.Model`, refer to the following resources for\n", + "a primer on how to do so:\n", + "\n", + "- [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit)\n", + "- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Setup imports" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "import tensorflow as tf\n", + "from tensorflow import keras\n", + "import numpy as np\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.metrics import accuracy_score\n", + "import random\n", + "from tensorflow.compiler.tf2xla.python import xla" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Load the dataset and visualize the data\n", + "\n", + "We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset\n", + "in the form of `NumPy` arrays. We then arrange it in the form of the train and test\n", + "splits.\n", + "\n", + "Following loading the dataset, we select 4 random samples from within the training set\n", + "and visualize them using `matplotlib.pyplot`." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data()\n", + "\n", + "print(\"4 Random Training samples and labels\")\n", + "idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4)\n", + "\n", + "img1 = (x_train[idx1], y_train[idx1])\n", + "img2 = (x_train[idx2], y_train[idx2])\n", + "img3 = (x_train[idx3], y_train[idx3])\n", + "img4 = (x_train[idx4], y_train[idx4])\n", + "\n", + "imgs = [img1, img2, img3, img4]\n", + "\n", + "plt.figure(figsize=(10, 10))\n", + "\n", + "for idx, item in enumerate(imgs):\n", + " image, label = item[0], item[1]\n", + " plt.subplot(2, 2, idx + 1)\n", + " plt.imshow(image, cmap=\"gray\")\n", + " plt.title(f\"Label : {label}\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Define `FFDense` custom layer\n", + "\n", + "In this custom layer, we have a base `keras.layers.Dense` object which acts as the\n", + "base `Dense` layer within. Since weight updates will happen within the layer itself, we\n", + "add an `keras.optimizers.Optimizer` object that is accepted from the user. Here, we\n", + "use `Adam` as our optimizer with a rather higher learning rate of `0.03`.\n", + "\n", + "Following the algorithm's specifics, we must set a `threshold` parameter that will be\n", + "used to make the positive-negative decision in each prediction. This is set to a default\n", + "of 2.0.\n", + "As the epochs are localized to the layer itself, we also set a `num_epochs` parameter\n", + "(defaults to 50).\n", + "\n", + "We override the `call` method in order to perform a normalization over the complete\n", + "input space followed by running it through the base `Dense` layer as would happen in a\n", + "normal `Dense` layer call.\n", + "\n", + "We implement the Forward-Forward algorithm which accepts 2 kinds of input tensors, each\n", + "representing the positive and negative samples respectively. We write a custom training\n", + "loop here with the use of `tf.GradientTape()`, within which we calculate a loss per\n", + "sample by taking the distance of the prediction from the threshold to understand the\n", + "error and taking its mean to get a `mean_loss` metric.\n", + "\n", + "With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable\n", + "base `Dense` layer and apply them using the layer's local optimizer.\n", + "\n", + "Finally, we return the `call` result as the `Dense` results of the positive and negative\n", + "samples while also returning the last `mean_loss` metric and all the loss values over a\n", + "certain all-epoch run." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "class FFDense(keras.layers.Layer):\n", + " \"\"\"\n", + " A custom ForwardForward-enabled Dense layer. It has an implementation of the\n", + " Forward-Forward network internally for use.\n", + " This layer must be used in conjunction with the `FFNetwork` model.\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " units,\n", + " optimizer,\n", + " loss_metric,\n", + " num_epochs=50,\n", + " use_bias=True,\n", + " kernel_initializer=\"glorot_uniform\",\n", + " bias_initializer=\"zeros\",\n", + " kernel_regularizer=None,\n", + " bias_regularizer=None,\n", + " **kwargs,\n", + " ):\n", + " super().__init__(**kwargs)\n", + " self.dense = keras.layers.Dense(\n", + " units=units,\n", + " use_bias=use_bias,\n", + " kernel_initializer=kernel_initializer,\n", + " bias_initializer=bias_initializer,\n", + " kernel_regularizer=kernel_regularizer,\n", + " bias_regularizer=bias_regularizer,\n", + " )\n", + " self.relu = keras.layers.ReLU()\n", + " self.optimizer = optimizer\n", + " self.loss_metric = loss_metric\n", + " self.threshold = 1.5\n", + " self.num_epochs = num_epochs\n", + "\n", + " # We perform a normalization step before we run the input through the Dense\n", + " # layer.\n", + "\n", + " def call(self, x):\n", + " x_norm = tf.norm(x, ord=2, axis=1, keepdims=True)\n", + " x_norm = x_norm + 1e-4\n", + " x_dir = x / x_norm\n", + " res = self.dense(x_dir)\n", + " return self.relu(res)\n", + "\n", + " # The Forward-Forward algorithm is below. We first perform the Dense-layer\n", + " # operation and then get a Mean Square value for all positive and negative\n", + " # samples respectively.\n", + " # The custom loss function finds the distance between the Mean-squared\n", + " # result and the threshold value we set (a hyperparameter) that will define\n", + " # whether the prediction is positive or negative in nature. Once the loss is\n", + " # calculated, we get a mean across the entire batch combined and perform a\n", + " # gradient calculation and optimization step. This does not technically\n", + " # qualify as backpropagation since there is no gradient being\n", + " # sent to any previous layer and is completely local in nature.\n", + "\n", + " def forward_forward(self, x_pos, x_neg):\n", + " for i in range(self.num_epochs):\n", + " with tf.GradientTape() as tape:\n", + " g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1)\n", + " g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1)\n", + "\n", + " loss = tf.math.log(\n", + " 1\n", + " + tf.math.exp(\n", + " tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0)\n", + " )\n", + " )\n", + " mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32)\n", + " self.loss_metric.update_state([mean_loss])\n", + " gradients = tape.gradient(mean_loss, self.dense.trainable_weights)\n", + " self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights))\n", + " return (\n", + " tf.stop_gradient(self.call(x_pos)),\n", + " tf.stop_gradient(self.call(x_neg)),\n", + " self.loss_metric.result(),\n", + " )\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Define the `FFNetwork` Custom Model\n", + "\n", + "With our custom layer defined, we also need to override the `train_step` method and\n", + "define a custom `keras.models.Model` that works with our `FFDense` layer.\n", + "\n", + "For this algorithm, we must 'embed' the labels onto the original image. To do so, we\n", + "exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We\n", + "use that as a label space in order to visually one-hot-encode the labels within the image\n", + "itself. This action is performed by the `overlay_y_on_x` function.\n", + "\n", + "We break down the prediction function with a per-sample prediction function which is then\n", + "called over the entire test set by the overriden `predict()` function. The prediction is\n", + "performed here with the help of measuring the `excitation` of the neurons per layer for\n", + "each image. This is then summed over all layers to calculate a network-wide 'goodness\n", + "score'. The label with the highest 'goodness score' is then chosen as the sample\n", + "prediction.\n", + "\n", + "The `train_step` function is overriden to act as the main controlling loop for running\n", + "training on each layer as per the number of epochs per layer." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "\n", + "class FFNetwork(keras.Model):\n", + " \"\"\"\n", + " A `keras.Model` that supports a `FFDense` network creation. This model\n", + " can work for any kind of classification task. It has an internal\n", + " implementation with some details specific to the MNIST dataset which can be\n", + " changed as per the use-case.\n", + " \"\"\"\n", + "\n", + " # Since each layer runs gradient-calculation and optimization locally, each\n", + " # layer has its own optimizer that we pass. As a standard choice, we pass\n", + " # the `Adam` optimizer with a default learning rate of 0.03 as that was\n", + " # found to be the best rate after experimentation.\n", + " # Loss is tracked using `loss_var` and `loss_count` variables.\n", + "\n", + " def __init__(\n", + " self, dims, layer_optimizer=keras.optimizers.Adam(learning_rate=0.03), **kwargs\n", + " ):\n", + " super().__init__(**kwargs)\n", + " self.layer_optimizer = layer_optimizer\n", + " self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32)\n", + " self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32)\n", + " self.layer_list = [keras.Input(shape=(dims[0],))]\n", + " for d in range(len(dims) - 1):\n", + " self.layer_list += [\n", + " FFDense(\n", + " dims[d + 1],\n", + " optimizer=self.layer_optimizer,\n", + " loss_metric=keras.metrics.Mean(),\n", + " )\n", + " ]\n", + "\n", + " # This function makes a dynamic change to the image wherein the labels are\n", + " # put on top of the original image (for this example, as MNIST has 10\n", + " # unique labels, we take the top-left corner's first 10 pixels). This\n", + " # function returns the original data tensor with the first 10 pixels being\n", + " # a pixel-based one-hot representation of the labels.\n", + "\n", + " @tf.function(reduce_retracing=True)\n", + " def overlay_y_on_x(self, data):\n", + " X_sample, y_sample = data\n", + " max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True)\n", + " max_sample = tf.cast(max_sample, dtype=tf.float64)\n", + " X_zeros = tf.zeros([10], dtype=tf.float64)\n", + " X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample])\n", + " X_sample = xla.dynamic_update_slice(X_sample, X_update, [0])\n", + " return X_sample, y_sample\n", + "\n", + " # A custom `predict_one_sample` performs predictions by passing the images\n", + " # through the network, measures the results produced by each layer (i.e.\n", + " # how high/low the output values are with respect to the set threshold for\n", + " # each label) and then simply finding the label with the highest values.\n", + " # In such a case, the images are tested for their 'goodness' with all\n", + " # labels.\n", + "\n", + " @tf.function(reduce_retracing=True)\n", + " def predict_one_sample(self, x):\n", + " goodness_per_label = []\n", + " x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]])\n", + " for label in range(10):\n", + " h, label = self.overlay_y_on_x(data=(x, label))\n", + " h = tf.reshape(h, [-1, tf.shape(h)[0]])\n", + " goodness = []\n", + " for layer_idx in range(1, len(self.layer_list)):\n", + " layer = self.layer_list[layer_idx]\n", + " h = layer(h)\n", + " goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)]\n", + " goodness_per_label += [\n", + " tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1)\n", + " ]\n", + " goodness_per_label = tf.concat(goodness_per_label, 1)\n", + " return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64)\n", + "\n", + " def predict(self, data):\n", + " x = data\n", + " preds = list()\n", + " preds = tf.map_fn(fn=self.predict_one_sample, elems=x)\n", + " return np.asarray(preds, dtype=int)\n", + "\n", + " # This custom `train_step` function overrides the internal `train_step`\n", + " # implementation. We take all the input image tensors, flatten them and\n", + " # subsequently produce positive and negative samples on the images.\n", + " # A positive sample is an image that has the right label encoded on it with\n", + " # the `overlay_y_on_x` function. A negative sample is an image that has an\n", + " # erroneous label present on it.\n", + " # With the samples ready, we pass them through each `FFLayer` and perform\n", + " # the Forward-Forward computation on it. The returned loss is the final\n", + " # loss value over all the layers.\n", + "\n", + " @tf.function(jit_compile=True)\n", + " def train_step(self, data):\n", + " x, y = data\n", + "\n", + " # Flatten op\n", + " x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]])\n", + "\n", + " x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y))\n", + "\n", + " random_y = tf.random.shuffle(y)\n", + " x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y))\n", + "\n", + " h_pos, h_neg = x_pos, x_neg\n", + "\n", + " for idx, layer in enumerate(self.layers):\n", + " if isinstance(layer, FFDense):\n", + " print(f\"Training layer {idx+1} now : \")\n", + " h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg)\n", + " self.loss_var.assign_add(loss)\n", + " self.loss_count.assign_add(1.0)\n", + " else:\n", + " print(f\"Passing layer {idx+1} now : \")\n", + " x = layer(x)\n", + " mean_res = tf.math.divide(self.loss_var, self.loss_count)\n", + " return {\"FinalLoss\": mean_res}\n", + "" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Convert MNIST `NumPy` arrays to `tf.data.Dataset`\n", + "\n", + "We now perform some preliminary processing on the `NumPy` arrays and then convert them\n", + "into the `tf.data.Dataset` format which allows for optimized loading." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "x_train = x_train.astype(float) / 255\n", + "x_test = x_test.astype(float) / 255\n", + "y_train = y_train.astype(int)\n", + "y_test = y_test.astype(int)\n", + "\n", + "train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))\n", + "test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))\n", + "\n", + "train_dataset = train_dataset.batch(60000)\n", + "test_dataset = test_dataset.batch(10000)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Fit the network and visualize results\n", + "\n", + "Having performed all previous set-up, we are now going to run `model.fit()` and run 250\n", + "model epochs, which will perform 50*250 epochs on each layer. We get to see the plotted loss\n", + "curve as each layer is trained." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "model = FFNetwork(dims=[784, 500, 500])\n", + "\n", + "model.compile(\n", + " optimizer=keras.optimizers.Adam(learning_rate=0.03),\n", + " loss=\"mse\",\n", + " jit_compile=True,\n", + " metrics=[keras.metrics.Mean()],\n", + ")\n", + "\n", + "epochs = 250\n", + "history = model.fit(train_dataset, epochs=epochs)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Perform inference and testing\n", + "\n", + "Having trained the model to a large extent, we now see how it performs on the\n", + "test set. We calculate the Accuracy Score to understand the results closely." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab_type": "code" + }, + "outputs": [], + "source": [ + "preds = model.predict(tf.convert_to_tensor(x_test))\n", + "\n", + "preds = preds.reshape((preds.shape[0], preds.shape[1]))\n", + "\n", + "results = accuracy_score(preds, y_test)\n", + "\n", + "print(f\"Test Accuracy score : {results*100}%\")\n", + "\n", + "plt.plot(range(len(history.history[\"FinalLoss\"])), history.history[\"FinalLoss\"])\n", + "plt.title(\"Loss over training\")\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text" + }, + "source": [ + "## Conclusion\n", + "\n", + "This example has hereby demonstrated how the Forward-Forward algorithm works using\n", + "the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton\n", + "in their paper are currently still limited to smaller models and datasets like MNIST and\n", + "Fashion-MNIST, subsequent results on larger models like LLMs are expected in future\n", + "papers.\n", + "\n", + "Through the paper, Prof. Hinton has reported results of 1.36% test accuracy error with a\n", + "2000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning\n", + "that backpropagation takes only 20 epochs to achieve similar performance). Another run of\n", + "doubling the learning rate and training for 40 epochs yields a slightly worse error rate\n", + "of 1.46%\n", + "\n", + "The current example does not yield state-of-the-art results. But with proper tuning of\n", + "the Learning Rate, model architecture (number of units in `Dense` layers, kernel\n", + "activations, initializations, regularization etc.), the results can be improved\n", + "to match the claims of the paper." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "forwardforward", + "private_outputs": false, + "provenance": [], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.0" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} \ No newline at end of file diff --git a/examples/vision/md/forwardforward.md b/examples/vision/md/forwardforward.md new file mode 100644 index 0000000000..75eb37f093 --- /dev/null +++ b/examples/vision/md/forwardforward.md @@ -0,0 +1,973 @@ +# Using Forward-Forward Algorithm for Image Classification + +**Author:** [Suvaditya Mukherjee](https://twitter.com/halcyonrayes)
+**Date created:** 2023/01/08
+**Last modified:** 2023/01/08
+**Description:** Training a Dense-layer model using the Forward-Forward algorithm. + + + [**View in Colab**](https://colab.research.google.com/github/keras-team/keras-io/blob/master/examples/vision/ipynb/forwardforward.ipynb) [**GitHub source**](https://github.com/keras-team/keras-io/blob/master/examples/vision/forwardforward.py) + + + +--- +## Introduction + +The following example explores how to use the Forward-Forward algorithm to perform +training instead of the traditionally-used method of backpropagation, as proposed by +Hinton in +[The Forward-Forward Algorithm: Some Preliminary Investigations](https://www.cs.toronto.edu/~hinton/FFA13.pdf) +(2022). + +The concept was inspired by the understanding behind +[Boltzmann Machines](http://www.cs.toronto.edu/~fritz/absps/dbm.pdf). Backpropagation +involves calculating the difference between actual and predicted output via a cost +function to adjust network weights. On the other hand, the FF Algorithm suggests the +analogy of neurons which get "excited" based on looking at a certain recognized +combination of an image and its correct corresponding label. + +This method takes certain inspiration from the biological learning process that occurs in +the cortex. A significant advantage that this method brings is the fact that +backpropagation through the network does not need to be performed anymore, and that +weight updates are local to the layer itself. + +As this is yet still an experimental method, it does not yield state-of-the-art results. +But with proper tuning, it is supposed to come close to the same. +Through this example, we will examine a process that allows us to implement the +Forward-Forward algorithm within the layers themselves, instead of the traditional method +of relying on the global loss functions and optimizers. + +The tutorial is structured as follows: + +- Perform necessary imports +- Load the [MNIST dataset](http://yann.lecun.com/exdb/mnist/) +- Visualize Random samples from the MNIST dataset +- Define a `FFDense` Layer to override `call` and implement a custom `forwardforward` +method which performs weight updates. +- Define a `FFNetwork` Layer to override `train_step`, `predict` and implement 2 custom +functions for per-sample prediction and overlaying labels +- Convert MNIST from `NumPy` arrays to `tf.data.Dataset` +- Fit the network +- Visualize results +- Perform inference on test samples + +As this example requires the customization of certain core functions with +`keras.layers.Layer` and `keras.models.Model`, refer to the following resources for +a primer on how to do so: + +- [Customizing what happens in `model.fit()`](https://www.tensorflow.org/guide/keras/customizing_what_happens_in_fit) +- [Making new Layers and Models via subclassing](https://www.tensorflow.org/guide/keras/custom_layers_and_models) + +--- +## Setup imports + + +```python +import tensorflow as tf +from tensorflow import keras +import numpy as np +import matplotlib.pyplot as plt +from sklearn.metrics import accuracy_score +import random +from tensorflow.compiler.tf2xla.python import xla +``` + +--- +## Load the dataset and visualize the data + +We use the `keras.datasets.mnist.load_data()` utility to directly pull the MNIST dataset +in the form of `NumPy` arrays. We then arrange it in the form of the train and test +splits. + +Following loading the dataset, we select 4 random samples from within the training set +and visualize them using `matplotlib.pyplot`. + + +```python +(x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() + +print("4 Random Training samples and labels") +idx1, idx2, idx3, idx4 = random.sample(range(0, x_train.shape[0]), 4) + +img1 = (x_train[idx1], y_train[idx1]) +img2 = (x_train[idx2], y_train[idx2]) +img3 = (x_train[idx3], y_train[idx3]) +img4 = (x_train[idx4], y_train[idx4]) + +imgs = [img1, img2, img3, img4] + +plt.figure(figsize=(10, 10)) + +for idx, item in enumerate(imgs): + image, label = item[0], item[1] + plt.subplot(2, 2, idx + 1) + plt.imshow(image, cmap="gray") + plt.title(f"Label : {label}") +plt.show() +``` + +
+``` +Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz +11490434/11490434 [==============================] - 0s 0us/step +4 Random Training samples and labels + +``` +
+![png](/img/examples/vision/forwardforward/forwardforward_5_1.png) + + +--- +## Define `FFDense` custom layer + +In this custom layer, we have a base `keras.layers.Dense` object which acts as the +base `Dense` layer within. Since weight updates will happen within the layer itself, we +add an `keras.optimizers.Optimizer` object that is accepted from the user. Here, we +use `Adam` as our optimizer with a rather higher learning rate of `0.03`. + +Following the algorithm's specifics, we must set a `threshold` parameter that will be +used to make the positive-negative decision in each prediction. This is set to a default +of 2.0. +As the epochs are localized to the layer itself, we also set a `num_epochs` parameter +(defaults to 50). + +We override the `call` method in order to perform a normalization over the complete +input space followed by running it through the base `Dense` layer as would happen in a +normal `Dense` layer call. + +We implement the Forward-Forward algorithm which accepts 2 kinds of input tensors, each +representing the positive and negative samples respectively. We write a custom training +loop here with the use of `tf.GradientTape()`, within which we calculate a loss per +sample by taking the distance of the prediction from the threshold to understand the +error and taking its mean to get a `mean_loss` metric. + +With the help of `tf.GradientTape()` we calculate the gradient updates for the trainable +base `Dense` layer and apply them using the layer's local optimizer. + +Finally, we return the `call` result as the `Dense` results of the positive and negative +samples while also returning the last `mean_loss` metric and all the loss values over a +certain all-epoch run. + + +```python + +class FFDense(keras.layers.Layer): + """ + A custom ForwardForward-enabled Dense layer. It has an implementation of the + Forward-Forward network internally for use. + This layer must be used in conjunction with the `FFNetwork` model. + """ + + def __init__( + self, + units, + optimizer, + loss_metric, + num_epochs=50, + use_bias=True, + kernel_initializer="glorot_uniform", + bias_initializer="zeros", + kernel_regularizer=None, + bias_regularizer=None, + **kwargs, + ): + super().__init__(**kwargs) + self.dense = keras.layers.Dense( + units=units, + use_bias=use_bias, + kernel_initializer=kernel_initializer, + bias_initializer=bias_initializer, + kernel_regularizer=kernel_regularizer, + bias_regularizer=bias_regularizer, + ) + self.relu = keras.layers.ReLU() + self.optimizer = optimizer + self.loss_metric = loss_metric + self.threshold = 1.5 + self.num_epochs = num_epochs + + # We perform a normalization step before we run the input through the Dense + # layer. + + def call(self, x): + x_norm = tf.norm(x, ord=2, axis=1, keepdims=True) + x_norm = x_norm + 1e-4 + x_dir = x / x_norm + res = self.dense(x_dir) + return self.relu(res) + + # The Forward-Forward algorithm is below. We first perform the Dense-layer + # operation and then get a Mean Square value for all positive and negative + # samples respectively. + # The custom loss function finds the distance between the Mean-squared + # result and the threshold value we set (a hyperparameter) that will define + # whether the prediction is positive or negative in nature. Once the loss is + # calculated, we get a mean across the entire batch combined and perform a + # gradient calculation and optimization step. This does not technically + # qualify as backpropagation since there is no gradient being + # sent to any previous layer and is completely local in nature. + + def forward_forward(self, x_pos, x_neg): + for i in range(self.num_epochs): + with tf.GradientTape() as tape: + g_pos = tf.math.reduce_mean(tf.math.pow(self.call(x_pos), 2), 1) + g_neg = tf.math.reduce_mean(tf.math.pow(self.call(x_neg), 2), 1) + + loss = tf.math.log( + 1 + + tf.math.exp( + tf.concat([-g_pos + self.threshold, g_neg - self.threshold], 0) + ) + ) + mean_loss = tf.cast(tf.math.reduce_mean(loss), tf.float32) + self.loss_metric.update_state([mean_loss]) + gradients = tape.gradient(mean_loss, self.dense.trainable_weights) + self.optimizer.apply_gradients(zip(gradients, self.dense.trainable_weights)) + return ( + tf.stop_gradient(self.call(x_pos)), + tf.stop_gradient(self.call(x_neg)), + self.loss_metric.result(), + ) + +``` + +--- +## Define the `FFNetwork` Custom Model + +With our custom layer defined, we also need to override the `train_step` method and +define a custom `keras.models.Model` that works with our `FFDense` layer. + +For this algorithm, we must 'embed' the labels onto the original image. To do so, we +exploit the structure of MNIST images where the top-left 10 pixels are always zeros. We +use that as a label space in order to visually one-hot-encode the labels within the image +itself. This action is performed by the `overlay_y_on_x` function. + +We break down the prediction function with a per-sample prediction function which is then +called over the entire test set by the overriden `predict()` function. The prediction is +performed here with the help of measuring the `excitation` of the neurons per layer for +each image. This is then summed over all layers to calculate a network-wide 'goodness +score'. The label with the highest 'goodness score' is then chosen as the sample +prediction. + +The `train_step` function is overriden to act as the main controlling loop for running +training on each layer as per the number of epochs per layer. + + +```python + +class FFNetwork(keras.Model): + """ + A `keras.Model` that supports a `FFDense` network creation. This model + can work for any kind of classification task. It has an internal + implementation with some details specific to the MNIST dataset which can be + changed as per the use-case. + """ + + # Since each layer runs gradient-calculation and optimization locally, each + # layer has its own optimizer that we pass. As a standard choice, we pass + # the `Adam` optimizer with a default learning rate of 0.03 as that was + # found to be the best rate after experimentation. + # Loss is tracked using `loss_var` and `loss_count` variables. + + def __init__( + self, dims, layer_optimizer=keras.optimizers.Adam(learning_rate=0.03), **kwargs + ): + super().__init__(**kwargs) + self.layer_optimizer = layer_optimizer + self.loss_var = tf.Variable(0.0, trainable=False, dtype=tf.float32) + self.loss_count = tf.Variable(0.0, trainable=False, dtype=tf.float32) + self.layer_list = [keras.Input(shape=(dims[0],))] + for d in range(len(dims) - 1): + self.layer_list += [ + FFDense( + dims[d + 1], + optimizer=self.layer_optimizer, + loss_metric=keras.metrics.Mean(), + ) + ] + + # This function makes a dynamic change to the image wherein the labels are + # put on top of the original image (for this example, as MNIST has 10 + # unique labels, we take the top-left corner's first 10 pixels). This + # function returns the original data tensor with the first 10 pixels being + # a pixel-based one-hot representation of the labels. + + @tf.function(reduce_retracing=True) + def overlay_y_on_x(self, data): + X_sample, y_sample = data + max_sample = tf.reduce_max(X_sample, axis=0, keepdims=True) + max_sample = tf.cast(max_sample, dtype=tf.float64) + X_zeros = tf.zeros([10], dtype=tf.float64) + X_update = xla.dynamic_update_slice(X_zeros, max_sample, [y_sample]) + X_sample = xla.dynamic_update_slice(X_sample, X_update, [0]) + return X_sample, y_sample + + # A custom `predict_one_sample` performs predictions by passing the images + # through the network, measures the results produced by each layer (i.e. + # how high/low the output values are with respect to the set threshold for + # each label) and then simply finding the label with the highest values. + # In such a case, the images are tested for their 'goodness' with all + # labels. + + @tf.function(reduce_retracing=True) + def predict_one_sample(self, x): + goodness_per_label = [] + x = tf.reshape(x, [tf.shape(x)[0] * tf.shape(x)[1]]) + for label in range(10): + h, label = self.overlay_y_on_x(data=(x, label)) + h = tf.reshape(h, [-1, tf.shape(h)[0]]) + goodness = [] + for layer_idx in range(1, len(self.layer_list)): + layer = self.layer_list[layer_idx] + h = layer(h) + goodness += [tf.math.reduce_mean(tf.math.pow(h, 2), 1)] + goodness_per_label += [ + tf.expand_dims(tf.reduce_sum(goodness, keepdims=True), 1) + ] + goodness_per_label = tf.concat(goodness_per_label, 1) + return tf.cast(tf.argmax(goodness_per_label, 1), tf.float64) + + def predict(self, data): + x = data + preds = list() + preds = tf.map_fn(fn=self.predict_one_sample, elems=x) + return np.asarray(preds, dtype=int) + + # This custom `train_step` function overrides the internal `train_step` + # implementation. We take all the input image tensors, flatten them and + # subsequently produce positive and negative samples on the images. + # A positive sample is an image that has the right label encoded on it with + # the `overlay_y_on_x` function. A negative sample is an image that has an + # erroneous label present on it. + # With the samples ready, we pass them through each `FFLayer` and perform + # the Forward-Forward computation on it. The returned loss is the final + # loss value over all the layers. + + @tf.function(jit_compile=True) + def train_step(self, data): + x, y = data + + # Flatten op + x = tf.reshape(x, [-1, tf.shape(x)[1] * tf.shape(x)[2]]) + + x_pos, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, y)) + + random_y = tf.random.shuffle(y) + x_neg, y = tf.map_fn(fn=self.overlay_y_on_x, elems=(x, random_y)) + + h_pos, h_neg = x_pos, x_neg + + for idx, layer in enumerate(self.layers): + if isinstance(layer, FFDense): + print(f"Training layer {idx+1} now : ") + h_pos, h_neg, loss = layer.forward_forward(h_pos, h_neg) + self.loss_var.assign_add(loss) + self.loss_count.assign_add(1.0) + else: + print(f"Passing layer {idx+1} now : ") + x = layer(x) + mean_res = tf.math.divide(self.loss_var, self.loss_count) + return {"FinalLoss": mean_res} + +``` + +--- +## Convert MNIST `NumPy` arrays to `tf.data.Dataset` + +We now perform some preliminary processing on the `NumPy` arrays and then convert them +into the `tf.data.Dataset` format which allows for optimized loading. + + +```python +x_train = x_train.astype(float) / 255 +x_test = x_test.astype(float) / 255 +y_train = y_train.astype(int) +y_test = y_test.astype(int) + +train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train)) +test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test)) + +train_dataset = train_dataset.batch(60000) +test_dataset = test_dataset.batch(10000) +``` + +--- +## Fit the network and visualize results + +Having performed all previous set-up, we are now going to run `model.fit()` and run 250 +model epochs, which will perform 50*250 epochs on each layer. We get to see the plotted loss +curve as each layer is trained. + + +```python +model = FFNetwork(dims=[784, 500, 500]) + +model.compile( + optimizer=keras.optimizers.Adam(learning_rate=0.03), + loss="mse", + jit_compile=True, + metrics=[keras.metrics.Mean()], +) + +epochs = 250 +history = model.fit(train_dataset, epochs=epochs) +``` + +
+``` +Epoch 1/250 +Training layer 1 now : +Training layer 2 now : +Training layer 1 now : +Training layer 2 now : +1/1 [==============================] - 72s 72s/step - FinalLoss: 0.7279 +Epoch 2/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.7082 +Epoch 3/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.7031 +Epoch 4/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6806 +Epoch 5/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6564 +Epoch 6/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6333 +Epoch 7/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.6126 +Epoch 8/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5946 +Epoch 9/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5786 +Epoch 10/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5644 +Epoch 11/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5518 +Epoch 12/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5405 +Epoch 13/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5301 +Epoch 14/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5207 +Epoch 15/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.5122 +Epoch 16/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.5044 +Epoch 17/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4972 +Epoch 18/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4906 +Epoch 19/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4845 +Epoch 20/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4787 +Epoch 21/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4734 +Epoch 22/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4685 +Epoch 23/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4639 +Epoch 24/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4596 +Epoch 25/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4555 +Epoch 26/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4516 +Epoch 27/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4479 +Epoch 28/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4445 +Epoch 29/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4411 +Epoch 30/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4380 +Epoch 31/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4350 +Epoch 32/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4322 +Epoch 33/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4295 +Epoch 34/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4269 +Epoch 35/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4245 +Epoch 36/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4222 +Epoch 37/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4199 +Epoch 38/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4178 +Epoch 39/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4157 +Epoch 40/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4136 +Epoch 41/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4117 +Epoch 42/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4098 +Epoch 43/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4079 +Epoch 44/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4062 +Epoch 45/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4045 +Epoch 46/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4028 +Epoch 47/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.4012 +Epoch 48/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3996 +Epoch 49/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3982 +Epoch 50/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3967 +Epoch 51/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3952 +Epoch 52/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3938 +Epoch 53/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3925 +Epoch 54/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3912 +Epoch 55/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3899 +Epoch 56/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3886 +Epoch 57/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3874 +Epoch 58/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3862 +Epoch 59/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3851 +Epoch 60/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3840 +Epoch 61/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3829 +Epoch 62/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3818 +Epoch 63/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3807 +Epoch 64/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3797 +Epoch 65/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3787 +Epoch 66/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3777 +Epoch 67/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3767 +Epoch 68/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3758 +Epoch 69/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3748 +Epoch 70/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3739 +Epoch 71/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3730 +Epoch 72/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3721 +Epoch 73/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3712 +Epoch 74/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3704 +Epoch 75/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3695 +Epoch 76/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3688 +Epoch 77/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3680 +Epoch 78/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3671 +Epoch 79/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3664 +Epoch 80/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3656 +Epoch 81/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3648 +Epoch 82/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3641 +Epoch 83/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3634 +Epoch 84/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3627 +Epoch 85/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3620 +Epoch 86/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3613 +Epoch 87/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3606 +Epoch 88/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3599 +Epoch 89/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3593 +Epoch 90/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3586 +Epoch 91/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3580 +Epoch 92/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3574 +Epoch 93/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3568 +Epoch 94/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3561 +Epoch 95/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3555 +Epoch 96/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3549 +Epoch 97/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3544 +Epoch 98/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3538 +Epoch 99/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3532 +Epoch 100/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3526 +Epoch 101/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3521 +Epoch 102/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3515 +Epoch 103/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3510 +Epoch 104/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3505 +Epoch 105/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3499 +Epoch 106/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3494 +Epoch 107/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3489 +Epoch 108/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3484 +Epoch 109/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3478 +Epoch 110/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3474 +Epoch 111/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3468 +Epoch 112/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3464 +Epoch 113/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3459 +Epoch 114/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3454 +Epoch 115/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3450 +Epoch 116/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3445 +Epoch 117/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3440 +Epoch 118/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3436 +Epoch 119/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3432 +Epoch 120/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3427 +Epoch 121/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3423 +Epoch 122/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3419 +Epoch 123/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3414 +Epoch 124/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3410 +Epoch 125/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3406 +Epoch 126/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3402 +Epoch 127/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3398 +Epoch 128/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3394 +Epoch 129/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3390 +Epoch 130/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3386 +Epoch 131/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3382 +Epoch 132/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3378 +Epoch 133/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3375 +Epoch 134/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3371 +Epoch 135/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3368 +Epoch 136/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3364 +Epoch 137/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3360 +Epoch 138/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3357 +Epoch 139/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3353 +Epoch 140/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3350 +Epoch 141/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3346 +Epoch 142/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3343 +Epoch 143/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3339 +Epoch 144/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3336 +Epoch 145/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3333 +Epoch 146/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3329 +Epoch 147/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3326 +Epoch 148/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3323 +Epoch 149/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3320 +Epoch 150/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3317 +Epoch 151/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3313 +Epoch 152/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3310 +Epoch 153/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3307 +Epoch 154/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3304 +Epoch 155/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3302 +Epoch 156/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3299 +Epoch 157/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3296 +Epoch 158/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3293 +Epoch 159/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3290 +Epoch 160/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3287 +Epoch 161/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3284 +Epoch 162/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3281 +Epoch 163/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3279 +Epoch 164/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3276 +Epoch 165/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3273 +Epoch 166/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3270 +Epoch 167/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3268 +Epoch 168/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3265 +Epoch 169/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3262 +Epoch 170/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3260 +Epoch 171/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3257 +Epoch 172/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3255 +Epoch 173/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3252 +Epoch 174/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3250 +Epoch 175/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3247 +Epoch 176/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3244 +Epoch 177/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3242 +Epoch 178/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3240 +Epoch 179/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3237 +Epoch 180/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3235 +Epoch 181/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3232 +Epoch 182/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3230 +Epoch 183/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3228 +Epoch 184/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3225 +Epoch 185/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3223 +Epoch 186/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3221 +Epoch 187/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3219 +Epoch 188/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3216 +Epoch 189/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3214 +Epoch 190/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3212 +Epoch 191/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3210 +Epoch 192/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3208 +Epoch 193/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3205 +Epoch 194/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3203 +Epoch 195/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3201 +Epoch 196/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3199 +Epoch 197/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3197 +Epoch 198/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3195 +Epoch 199/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3193 +Epoch 200/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3191 +Epoch 201/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3189 +Epoch 202/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3187 +Epoch 203/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3185 +Epoch 204/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3183 +Epoch 205/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3181 +Epoch 206/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3179 +Epoch 207/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3177 +Epoch 208/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3175 +Epoch 209/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3174 +Epoch 210/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3172 +Epoch 211/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3170 +Epoch 212/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3168 +Epoch 213/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3166 +Epoch 214/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3165 +Epoch 215/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3163 +Epoch 216/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3161 +Epoch 217/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3159 +Epoch 218/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3157 +Epoch 219/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3155 +Epoch 220/250 +1/1 [==============================] - 5s 5s/step - FinalLoss: 0.3154 +Epoch 221/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3152 +Epoch 222/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3150 +Epoch 223/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3148 +Epoch 224/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3147 +Epoch 225/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3145 +Epoch 226/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3143 +Epoch 227/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3142 +Epoch 228/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3140 +Epoch 229/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3139 +Epoch 230/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3137 +Epoch 231/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3135 +Epoch 232/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3134 +Epoch 233/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3132 +Epoch 234/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3131 +Epoch 235/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3129 +Epoch 236/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3127 +Epoch 237/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3126 +Epoch 238/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3124 +Epoch 239/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3123 +Epoch 240/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3121 +Epoch 241/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3120 +Epoch 242/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3118 +Epoch 243/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3117 +Epoch 244/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3116 +Epoch 245/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3114 +Epoch 246/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3113 +Epoch 247/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3111 +Epoch 248/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3110 +Epoch 249/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3108 +Epoch 250/250 +1/1 [==============================] - 6s 6s/step - FinalLoss: 0.3107 + +``` +
+--- +## Perform inference and testing + +Having trained the model to a large extent, we now see how it performs on the +test set. We calculate the Accuracy Score to understand the results closely. + + +```python +preds = model.predict(tf.convert_to_tensor(x_test)) + +preds = preds.reshape((preds.shape[0], preds.shape[1])) + +results = accuracy_score(preds, y_test) + +print(f"Test Accuracy score : {results*100}%") + +plt.plot(range(len(history.history["FinalLoss"])), history.history["FinalLoss"]) +plt.title("Loss over training") +plt.show() +``` + +
+``` +Test Accuracy score : 97.64% + +``` +
+![png](/img/examples/vision/forwardforward/forwardforward_15_1.png) + + +--- +## Conclusion + +This example has hereby demonstrated how the Forward-Forward algorithm works using +the TensorFlow and Keras packages. While the investigation results presented by Prof. Hinton +in their paper are currently still limited to smaller models and datasets like MNIST and +Fashion-MNIST, subsequent results on larger models like LLMs are expected in future +papers. + +Through the paper, Prof. Hinton has reported results of 1.36% test accuracy error with a +2000-units, 4 hidden-layer, fully-connected network run over 60 epochs (while mentioning +that backpropagation takes only 20 epochs to achieve similar performance). Another run of +doubling the learning rate and training for 40 epochs yields a slightly worse error rate +of 1.46% + +The current example does not yield state-of-the-art results. But with proper tuning of +the Learning Rate, model architecture (number of units in `Dense` layers, kernel +activations, initializations, regularization etc.), the results can be improved +to match the claims of the paper.