# mnist_cnn_train.py — forked from frsong/tf-examples (90 lines, 71 loc, 2.8 KB)
"""
MNIST handwritten digit classification with a convolutional neural network.
"""
import os
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from mnist_cnn_model import Model
def train(save_dir='save/mnist', log_dir='logs/mnist'):
    """Train the MNIST CNN, checkpointing after every epoch.

    Args:
        save_dir: Directory where model checkpoints are written.
        log_dir: Directory where TensorBoard event files are written.
    """
    # Download/load the MNIST dataset with one-hot labels.
    dataset = input_data.read_data_sets('datasets/mnist', one_hot=True)

    # Make sure the checkpoint directory exists.
    if not os.path.isdir(save_dir):
        os.makedirs(save_dir)

    # Hyperparameters.
    learning_rate = 1e-4
    num_epochs = 20
    batch_size = 50

    # Build the graph, then gather everything it registered for
    # saving and for summary logging.
    model = Model(learning_rate)
    saver = tf.train.Saver(tf.global_variables())
    summary_op = tf.summary.merge_all()

    # Seed NumPy so the mini-batch sequence is reproducible.
    np.random.seed(0)

    # Report each trainable variable and the total parameter count.
    print("")
    print("Variables")
    print("---------")
    total_params = 0
    for var in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
        shape = var.get_shape()
        total_params += np.prod(shape.as_list())
        print(var.name, shape)
    print("=> Total number of parameters =", total_params)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # TensorBoard writer (also records the graph definition).
        writer = tf.summary.FileWriter(log_dir, sess.graph)

        ckpt_path = os.path.join(save_dir, 'model.ckpt')
        num_batches_per_epoch = dataset.train.num_examples // batch_size

        for epoch in range(num_epochs):
            # One pass over the training set, one mini-batch at a time.
            for batch_idx in range(num_batches_per_epoch):
                images, labels = dataset.train.next_batch(batch_size)
                step_feed = {model.x: images,
                             model.y: labels,
                             model.keep_prob: 0.5}
                _, summary = sess.run([model.train_op, summary_op], step_feed)
                # Global step index keeps the summary timeline monotonic.
                writer.add_summary(summary,
                                   epoch*num_batches_per_epoch + batch_idx)

            # Validation accuracy with dropout disabled (keep_prob = 1).
            val_feed = {model.x: dataset.validation.images,
                        model.y: dataset.validation.labels,
                        model.keep_prob: 1.0}
            accuracy = sess.run(model.accuracy_op, val_feed)
            print("After {} epochs, validation accuracy = {}"
                  .format(epoch+1, accuracy))

            # Checkpoint at the end of every epoch.
            saver.save(sess, ckpt_path, global_step=epoch+1)

        # Final accuracy on the held-out test set.
        test_feed = {model.x: dataset.test.images,
                     model.y: dataset.test.labels,
                     model.keep_prob: 1.0}
        accuracy = sess.run(model.accuracy_op, test_feed)
        print("Test accuracy =", accuracy)
#///////////////////////////////////////////////////////////////////////////////
if __name__ == '__main__':
    # Run training with the default checkpoint and log directories.
    train()