|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "metadata": {}, |
| 6 | + "source": [ |
| 7 | + "# Recurrent Neural Network in TensorFlow\n", |
| 8 | + "\n", |
| 9 | + "Credits: Forked from [TensorFlow-Examples](https://github.com/aymericdamien/TensorFlow-Examples) by Aymeric Damien\n", |
| 10 | + "\n", |
| 11 | + "## Setup\n", |
| 12 | + "\n", |
| 13 | + "Refer to the [setup instructions](http://nbviewer.ipython.org/github/donnemartin/data-science-ipython-notebooks/blob/master/deep-learning/tensor-flow-examples/Setup_TensorFlow.md)" |
| 14 | + ] |
| 15 | + }, |
| 16 | + { |
| 17 | + "cell_type": "code", |
| 18 | + "execution_count": 2, |
| 19 | + "metadata": { |
| 20 | + "collapsed": false |
| 21 | + }, |
| 22 | + "outputs": [ |
| 23 | + { |
| 24 | + "name": "stdout", |
| 25 | + "output_type": "stream", |
| 26 | + "text": [ |
| 27 | + "Extracting /tmp/data/train-images-idx3-ubyte.gz\n", |
| 28 | + "Extracting /tmp/data/train-labels-idx1-ubyte.gz\n", |
| 29 | + "Extracting /tmp/data/t10k-images-idx3-ubyte.gz\n", |
| 30 | + "Extracting /tmp/data/t10k-labels-idx1-ubyte.gz\n" |
| 31 | + ] |
| 32 | + } |
| 33 | + ], |
| 34 | + "source": [ |
| 35 | + "# Import MINST data\n", |
| 36 | + "import input_data\n", |
| 37 | + "mnist = input_data.read_data_sets(\"/tmp/data/\", one_hot=True)\n", |
| 38 | + "\n", |
| 39 | + "import tensorflow as tf\n", |
| 40 | + "from tensorflow.models.rnn import rnn, rnn_cell\n", |
| 41 | + "import numpy as np" |
| 42 | + ] |
| 43 | + }, |
| 44 | + { |
| 45 | + "cell_type": "code", |
| 46 | + "execution_count": 3, |
| 47 | + "metadata": { |
| 48 | + "collapsed": true |
| 49 | + }, |
| 50 | + "outputs": [], |
| 51 | + "source": [ |
| 52 | + "'''\n", |
| 53 | + "To classify images using a reccurent neural network, we consider every image row as a sequence of pixels.\n", |
| 54 | + "Because MNIST image shape is 28*28px, we will then handle 28 sequences of 28 steps for every sample.\n", |
| 55 | + "'''\n", |
| 56 | + "\n", |
| 57 | + "# Parameters\n", |
| 58 | + "learning_rate = 0.001\n", |
| 59 | + "training_iters = 100000\n", |
| 60 | + "batch_size = 128\n", |
| 61 | + "display_step = 10\n", |
| 62 | + "\n", |
| 63 | + "# Network Parameters\n", |
| 64 | + "n_input = 28 # MNIST data input (img shape: 28*28)\n", |
| 65 | + "n_steps = 28 # timesteps\n", |
| 66 | + "n_hidden = 128 # hidden layer num of features\n", |
| 67 | + "n_classes = 10 # MNIST total classes (0-9 digits)" |
| 68 | + ] |
| 69 | + }, |
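+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "A quick shape check (an illustration, not part of the model): `mnist.train.next_batch` returns flattened 784-pixel images, and reshaping is how a batch becomes `n_steps` rows of `n_input` pixels each. The all-zeros array below is a hypothetical stand-in for a real batch."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Illustrative numpy sketch: a flattened 784-pixel batch reshaped into\n",
+    "# 28 time steps of 28 pixels each (zero-filled stand-in for real data).\n",
+    "import numpy as np\n",
+    "flat_batch = np.zeros((batch_size, 784))   # shape as returned by next_batch\n",
+    "seq_batch = flat_batch.reshape((batch_size, n_steps, n_input))\n",
+    "print(seq_batch.shape)   # (128, 28, 28)"
+   ]
+  },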
| 70 | + { |
| 71 | + "cell_type": "code", |
| 72 | + "execution_count": 4, |
| 73 | + "metadata": { |
| 74 | + "collapsed": true |
| 75 | + }, |
| 76 | + "outputs": [], |
| 77 | + "source": [ |
| 78 | + "# tf Graph input\n", |
| 79 | + "x = tf.placeholder(\"float\", [None, n_steps, n_input])\n", |
| 80 | + "istate = tf.placeholder(\"float\", [None, 2*n_hidden]) #state & cell => 2x n_hidden\n", |
| 81 | + "y = tf.placeholder(\"float\", [None, n_classes])\n", |
| 82 | + "\n", |
| 83 | + "# Define weights\n", |
| 84 | + "weights = {\n", |
| 85 | + " 'hidden': tf.Variable(tf.random_normal([n_input, n_hidden])), # Hidden layer weights\n", |
| 86 | + " 'out': tf.Variable(tf.random_normal([n_hidden, n_classes]))\n", |
| 87 | + "}\n", |
| 88 | + "biases = {\n", |
| 89 | + " 'hidden': tf.Variable(tf.random_normal([n_hidden])),\n", |
| 90 | + " 'out': tf.Variable(tf.random_normal([n_classes]))\n", |
| 91 | + "}" |
| 92 | + ] |
| 93 | + }, |
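+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Why `istate` has width `2*n_hidden`: in this TensorFlow version, `BasicLSTMCell` packs its cell state `c` and hidden state `h` side by side in one tensor. A minimal sketch of that layout, assuming the `[c, h]` concatenation order this API uses:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Sketch of the packed LSTM initial state (assumed [c, h] order).\n",
+    "zero_state = np.zeros((batch_size, 2*n_hidden))   # (128, 256)\n",
+    "c0 = zero_state[:, :n_hidden]   # cell state, (128, 128)\n",
+    "h0 = zero_state[:, n_hidden:]   # hidden state, (128, 128)\n",
+    "print(zero_state.shape, c0.shape, h0.shape)"
+   ]
+  },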
| 94 | + { |
| 95 | + "cell_type": "code", |
| 96 | + "execution_count": 5, |
| 97 | + "metadata": { |
| 98 | + "collapsed": true |
| 99 | + }, |
| 100 | + "outputs": [], |
| 101 | + "source": [ |
| 102 | + "def RNN(_X, _istate, _weights, _biases):\n", |
| 103 | + "\n", |
| 104 | + " # input shape: (batch_size, n_steps, n_input)\n", |
| 105 | + " _X = tf.transpose(_X, [1, 0, 2]) # permute n_steps and batch_size\n", |
| 106 | + " # Reshape to prepare input to hidden activation\n", |
| 107 | + " _X = tf.reshape(_X, [-1, n_input]) # (n_steps*batch_size, n_input)\n", |
| 108 | + " # Linear activation\n", |
| 109 | + " _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']\n", |
| 110 | + "\n", |
| 111 | + " # Define a lstm cell with tensorflow\n", |
| 112 | + " lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)\n", |
| 113 | + " # Split data because rnn cell needs a list of inputs for the RNN inner loop\n", |
| 114 | + " _X = tf.split(0, n_steps, _X) # n_steps * (batch_size, n_hidden)\n", |
| 115 | + "\n", |
| 116 | + " # Get lstm cell output\n", |
| 117 | + " outputs, states = rnn.rnn(lstm_cell, _X, initial_state=_istate)\n", |
| 118 | + "\n", |
| 119 | + " # Linear activation\n", |
| 120 | + " # Get inner loop last output\n", |
| 121 | + " return tf.matmul(outputs[-1], _weights['out']) + _biases['out']" |
| 122 | + ] |
| 123 | + }, |
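+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The transpose/reshape/split sequence in `RNN` is easiest to follow as a shape flow. The numpy sketch below mirrors it (illustration only; it skips the input-to-hidden matmul, so the split pieces keep width `n_input` rather than `n_hidden`)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Shape flow of RNN(), mirrored in numpy (hidden projection omitted).\n",
+    "batch = np.zeros((batch_size, n_steps, n_input))   # (128, 28, 28)\n",
+    "t = batch.transpose((1, 0, 2))                     # (28, 128, 28): time-major\n",
+    "r = t.reshape((-1, n_input))                       # (3584, 28): all steps stacked\n",
+    "steps = np.split(r, n_steps, axis=0)               # 28 arrays, each (128, 28)\n",
+    "print(len(steps), steps[0].shape)"
+   ]
+  },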
| 124 | + { |
| 125 | + "cell_type": "code", |
| 126 | + "execution_count": 6, |
| 127 | + "metadata": { |
| 128 | + "collapsed": false |
| 129 | + }, |
| 130 | + "outputs": [], |
| 131 | + "source": [ |
| 132 | + "pred = RNN(x, istate, weights, biases)\n", |
| 133 | + "\n", |
| 134 | + "# Define loss and optimizer\n", |
| 135 | + "cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y)) # Softmax loss\n", |
| 136 | + "optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Adam Optimizer\n", |
| 137 | + "\n", |
| 138 | + "# Evaluate model\n", |
| 139 | + "correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))\n", |
| 140 | + "accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.types.float32))" |
| 141 | + ] |
| 142 | + }, |
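+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "`softmax_cross_entropy_with_logits` takes the unscaled outputs of `RNN` and fuses the softmax with the cross-entropy in one numerically stable op. A minimal single-example numpy sketch of the same quantity (hypothetical logits, class 0 as the true label):"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Single-example softmax cross-entropy computed by hand (illustration).\n",
+    "def softmax_xent(logits, one_hot):\n",
+    "    z = logits - logits.max()                 # stabilize the exponentials\n",
+    "    log_probs = z - np.log(np.exp(z).sum())   # log softmax\n",
+    "    return -(one_hot * log_probs).sum()\n",
+    "print(softmax_xent(np.array([2.0, 1.0, 0.1]), np.array([1.0, 0.0, 0.0])))   # ~0.417"
+   ]
+  },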
| 143 | + { |
| 144 | + "cell_type": "code", |
| 145 | + "execution_count": 7, |
| 146 | + "metadata": { |
| 147 | + "collapsed": false |
| 148 | + }, |
| 149 | + "outputs": [ |
| 150 | + { |
| 151 | + "name": "stdout", |
| 152 | + "output_type": "stream", |
| 153 | + "text": [ |
| 154 | + "Iter 1280, Minibatch Loss= 1.888242, Training Accuracy= 0.39844\n", |
| 155 | + "Iter 2560, Minibatch Loss= 1.519879, Training Accuracy= 0.47656\n", |
| 156 | + "Iter 3840, Minibatch Loss= 1.238005, Training Accuracy= 0.63281\n", |
| 157 | + "Iter 5120, Minibatch Loss= 0.933760, Training Accuracy= 0.71875\n", |
| 158 | + "Iter 6400, Minibatch Loss= 0.832130, Training Accuracy= 0.73438\n", |
| 159 | + "Iter 7680, Minibatch Loss= 0.979760, Training Accuracy= 0.70312\n", |
| 160 | + "Iter 8960, Minibatch Loss= 0.821921, Training Accuracy= 0.71875\n", |
| 161 | + "Iter 10240, Minibatch Loss= 0.710566, Training Accuracy= 0.79688\n", |
| 162 | + "Iter 11520, Minibatch Loss= 0.578501, Training Accuracy= 0.82812\n", |
| 163 | + "Iter 12800, Minibatch Loss= 0.765049, Training Accuracy= 0.75000\n", |
| 164 | + "Iter 14080, Minibatch Loss= 0.582995, Training Accuracy= 0.78125\n", |
| 165 | + "Iter 15360, Minibatch Loss= 0.575092, Training Accuracy= 0.79688\n", |
| 166 | + "Iter 16640, Minibatch Loss= 0.701214, Training Accuracy= 0.75781\n", |
| 167 | + "Iter 17920, Minibatch Loss= 0.561972, Training Accuracy= 0.78125\n", |
| 168 | + "Iter 19200, Minibatch Loss= 0.394480, Training Accuracy= 0.85938\n", |
| 169 | + "Iter 20480, Minibatch Loss= 0.356244, Training Accuracy= 0.91406\n", |
| 170 | + "Iter 21760, Minibatch Loss= 0.632163, Training Accuracy= 0.78125\n", |
| 171 | + "Iter 23040, Minibatch Loss= 0.269334, Training Accuracy= 0.90625\n", |
| 172 | + "Iter 24320, Minibatch Loss= 0.485007, Training Accuracy= 0.86719\n", |
| 173 | + "Iter 25600, Minibatch Loss= 0.569704, Training Accuracy= 0.78906\n", |
| 174 | + "Iter 26880, Minibatch Loss= 0.267697, Training Accuracy= 0.92188\n", |
| 175 | + "Iter 28160, Minibatch Loss= 0.381177, Training Accuracy= 0.90625\n", |
| 176 | + "Iter 29440, Minibatch Loss= 0.350800, Training Accuracy= 0.87500\n", |
| 177 | + "Iter 30720, Minibatch Loss= 0.356782, Training Accuracy= 0.90625\n", |
| 178 | + "Iter 32000, Minibatch Loss= 0.322511, Training Accuracy= 0.89062\n", |
| 179 | + "Iter 33280, Minibatch Loss= 0.309195, Training Accuracy= 0.90625\n", |
| 180 | + "Iter 34560, Minibatch Loss= 0.535408, Training Accuracy= 0.83594\n", |
| 181 | + "Iter 35840, Minibatch Loss= 0.281643, Training Accuracy= 0.92969\n", |
| 182 | + "Iter 37120, Minibatch Loss= 0.290962, Training Accuracy= 0.89844\n", |
| 183 | + "Iter 38400, Minibatch Loss= 0.204718, Training Accuracy= 0.93750\n", |
| 184 | + "Iter 39680, Minibatch Loss= 0.205882, Training Accuracy= 0.92969\n", |
| 185 | + "Iter 40960, Minibatch Loss= 0.481441, Training Accuracy= 0.84375\n", |
| 186 | + "Iter 42240, Minibatch Loss= 0.348245, Training Accuracy= 0.89844\n", |
| 187 | + "Iter 43520, Minibatch Loss= 0.274692, Training Accuracy= 0.90625\n", |
| 188 | + "Iter 44800, Minibatch Loss= 0.171815, Training Accuracy= 0.94531\n", |
| 189 | + "Iter 46080, Minibatch Loss= 0.171035, Training Accuracy= 0.93750\n", |
| 190 | + "Iter 47360, Minibatch Loss= 0.235800, Training Accuracy= 0.89844\n", |
| 191 | + "Iter 48640, Minibatch Loss= 0.235974, Training Accuracy= 0.93750\n", |
| 192 | + "Iter 49920, Minibatch Loss= 0.207323, Training Accuracy= 0.92188\n", |
| 193 | + "Iter 51200, Minibatch Loss= 0.212989, Training Accuracy= 0.91406\n", |
| 194 | + "Iter 52480, Minibatch Loss= 0.151774, Training Accuracy= 0.95312\n", |
| 195 | + "Iter 53760, Minibatch Loss= 0.090070, Training Accuracy= 0.96875\n", |
| 196 | + "Iter 55040, Minibatch Loss= 0.264714, Training Accuracy= 0.92969\n", |
| 197 | + "Iter 56320, Minibatch Loss= 0.235086, Training Accuracy= 0.92969\n", |
| 198 | + "Iter 57600, Minibatch Loss= 0.160302, Training Accuracy= 0.95312\n", |
| 199 | + "Iter 58880, Minibatch Loss= 0.106515, Training Accuracy= 0.96875\n", |
| 200 | + "Iter 60160, Minibatch Loss= 0.236039, Training Accuracy= 0.94531\n", |
| 201 | + "Iter 61440, Minibatch Loss= 0.279540, Training Accuracy= 0.90625\n", |
| 202 | + "Iter 62720, Minibatch Loss= 0.173585, Training Accuracy= 0.93750\n", |
| 203 | + "Iter 64000, Minibatch Loss= 0.191009, Training Accuracy= 0.92188\n", |
| 204 | + "Iter 65280, Minibatch Loss= 0.210331, Training Accuracy= 0.89844\n", |
| 205 | + "Iter 66560, Minibatch Loss= 0.223444, Training Accuracy= 0.94531\n", |
| 206 | + "Iter 67840, Minibatch Loss= 0.278210, Training Accuracy= 0.91406\n", |
| 207 | + "Iter 69120, Minibatch Loss= 0.174290, Training Accuracy= 0.95312\n", |
| 208 | + "Iter 70400, Minibatch Loss= 0.188701, Training Accuracy= 0.94531\n", |
| 209 | + "Iter 71680, Minibatch Loss= 0.210277, Training Accuracy= 0.94531\n", |
| 210 | + "Iter 72960, Minibatch Loss= 0.249951, Training Accuracy= 0.95312\n", |
| 211 | + "Iter 74240, Minibatch Loss= 0.209853, Training Accuracy= 0.92188\n", |
| 212 | + "Iter 75520, Minibatch Loss= 0.049742, Training Accuracy= 0.99219\n", |
| 213 | + "Iter 76800, Minibatch Loss= 0.250095, Training Accuracy= 0.92969\n", |
| 214 | + "Iter 78080, Minibatch Loss= 0.133853, Training Accuracy= 0.95312\n", |
| 215 | + "Iter 79360, Minibatch Loss= 0.110206, Training Accuracy= 0.97656\n", |
| 216 | + "Iter 80640, Minibatch Loss= 0.141906, Training Accuracy= 0.93750\n", |
| 217 | + "Iter 81920, Minibatch Loss= 0.126872, Training Accuracy= 0.94531\n", |
| 218 | + "Iter 83200, Minibatch Loss= 0.138925, Training Accuracy= 0.95312\n", |
| 219 | + "Iter 84480, Minibatch Loss= 0.128652, Training Accuracy= 0.96094\n", |
| 220 | + "Iter 85760, Minibatch Loss= 0.099837, Training Accuracy= 0.96094\n", |
| 221 | + "Iter 87040, Minibatch Loss= 0.119000, Training Accuracy= 0.95312\n", |
| 222 | + "Iter 88320, Minibatch Loss= 0.179807, Training Accuracy= 0.95312\n", |
| 223 | + "Iter 89600, Minibatch Loss= 0.141792, Training Accuracy= 0.96094\n", |
| 224 | + "Iter 90880, Minibatch Loss= 0.142424, Training Accuracy= 0.96094\n", |
| 225 | + "Iter 92160, Minibatch Loss= 0.159564, Training Accuracy= 0.96094\n", |
| 226 | + "Iter 93440, Minibatch Loss= 0.111984, Training Accuracy= 0.95312\n", |
| 227 | + "Iter 94720, Minibatch Loss= 0.238978, Training Accuracy= 0.92969\n", |
| 228 | + "Iter 96000, Minibatch Loss= 0.068002, Training Accuracy= 0.97656\n", |
| 229 | + "Iter 97280, Minibatch Loss= 0.191819, Training Accuracy= 0.94531\n", |
| 230 | + "Iter 98560, Minibatch Loss= 0.081197, Training Accuracy= 0.99219\n", |
| 231 | + "Iter 99840, Minibatch Loss= 0.206797, Training Accuracy= 0.95312\n", |
| 232 | + "Optimization Finished!\n", |
| 233 | + "Testing Accuracy: 0.941406\n" |
| 234 | + ] |
| 235 | + } |
| 236 | + ], |
| 237 | + "source": [ |
| 238 | + "# Initializing the variables\n", |
| 239 | + "init = tf.initialize_all_variables()\n", |
| 240 | + "\n", |
| 241 | + "# Launch the graph\n", |
| 242 | + "with tf.Session() as sess:\n", |
| 243 | + " sess.run(init)\n", |
| 244 | + " step = 1\n", |
| 245 | + " # Keep training until reach max iterations\n", |
| 246 | + " while step * batch_size < training_iters:\n", |
| 247 | + " batch_xs, batch_ys = mnist.train.next_batch(batch_size)\n", |
| 248 | + " # Reshape data to get 28 seq of 28 elements\n", |
| 249 | + " batch_xs = batch_xs.reshape((batch_size, n_steps, n_input))\n", |
| 250 | + " # Fit training using batch data\n", |
| 251 | + " sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys,\n", |
| 252 | + " istate: np.zeros((batch_size, 2*n_hidden))})\n", |
| 253 | + " if step % display_step == 0:\n", |
| 254 | + " # Calculate batch accuracy\n", |
| 255 | + " acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys,\n", |
| 256 | + " istate: np.zeros((batch_size, 2*n_hidden))})\n", |
| 257 | + " # Calculate batch loss\n", |
| 258 | + " loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys,\n", |
| 259 | + " istate: np.zeros((batch_size, 2*n_hidden))})\n", |
| 260 | + " print \"Iter \" + str(step*batch_size) + \", Minibatch Loss= \" + \"{:.6f}\".format(loss) + \\\n", |
| 261 | + " \", Training Accuracy= \" + \"{:.5f}\".format(acc)\n", |
| 262 | + " step += 1\n", |
| 263 | + " print \"Optimization Finished!\"\n", |
| 264 | + " # Calculate accuracy for 256 mnist test images\n", |
| 265 | + " test_len = 256\n", |
| 266 | + " test_data = mnist.test.images[:test_len].reshape((-1, n_steps, n_input))\n", |
| 267 | + " test_label = mnist.test.labels[:test_len]\n", |
| 268 | + " print \"Testing Accuracy:\", sess.run(accuracy, feed_dict={x: test_data, y: test_label,\n", |
| 269 | + " istate: np.zeros((test_len, 2*n_hidden))})" |
| 270 | + ] |
| 271 | + } |
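+  ,
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note that the test accuracy above is computed on only the first 256 test images to keep the feed small. Below is a hedged sketch of a full 10k-image evaluation, chunked to bound memory; `full_test_accuracy` is a hypothetical helper that would be called with the still-open training session (e.g. just before the `with` block above exits), not something that was run here."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": [
+    "# Hypothetical helper: full test-set accuracy as a weighted average of\n",
+    "# per-chunk accuracies (call inside an open session that ran training).\n",
+    "def full_test_accuracy(sess, chunk=256):\n",
+    "    images, labels = mnist.test.images, mnist.test.labels\n",
+    "    total, correct = 0, 0.0\n",
+    "    for start in range(0, len(images), chunk):\n",
+    "        xs = images[start:start+chunk].reshape((-1, n_steps, n_input))\n",
+    "        ys = labels[start:start+chunk]\n",
+    "        n = len(ys)\n",
+    "        acc = sess.run(accuracy, feed_dict={x: xs, y: ys,\n",
+    "                                            istate: np.zeros((n, 2*n_hidden))})\n",
+    "        correct += acc * n\n",
+    "        total += n\n",
+    "    return correct / total"
+   ]
+  }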
| 272 | + ], |
| 273 | + "metadata": { |
| 274 | + "kernelspec": { |
| 275 | + "display_name": "Python 3", |
| 276 | + "language": "python", |
| 277 | + "name": "python3" |
| 278 | + }, |
| 279 | + "language_info": { |
| 280 | + "codemirror_mode": { |
| 281 | + "name": "ipython", |
| 282 | + "version": 3 |
| 283 | + }, |
| 284 | + "file_extension": ".py", |
| 285 | + "mimetype": "text/x-python", |
| 286 | + "name": "python", |
| 287 | + "nbconvert_exporter": "python", |
| 288 | + "pygments_lexer": "ipython3", |
| 289 | + "version": "3.4.3" |
| 290 | + } |
| 291 | + }, |
| 292 | + "nbformat": 4, |
| 293 | + "nbformat_minor": 0 |
| 294 | +} |