char_rnn_train.py — character-level RNN training script
(forked from frsong/tf-examples; 100 lines, 77 loc, 3.07 KB)
"""
Simple char-rnn based on
https://github.com/sherjilozair/char-rnn-tensorflow
Original article:
http://karpathy.github.io/2015/05/21/rnn-effectiveness/
"""
import os
import pickle
import numpy as np
import tensorflow as tf
from char_rnn_model import Model
from char_rnn_reader import Reader
# Command-line flags (TensorFlow 1.x tf.app.flags API).
FLAGS = tf.app.flags.FLAGS
# Directory containing the training text corpus.
tf.app.flags.DEFINE_string('data_dir', 'datasets/tinyshakespeare',
                           "data directory")
# Where checkpoints and the chars/vocab pickle are written.
tf.app.flags.DEFINE_string('save_dir', 'save/char-rnn', "save directory")
# Where TensorBoard event files are written.
tf.app.flags.DEFINE_string('log_dir', 'logs/char-rnn', "log directory")
# Number of sequences per training batch.
tf.app.flags.DEFINE_integer('batch_size', 50, "batch size")
# Number of full passes over the dataset.
tf.app.flags.DEFINE_integer('num_epochs', 50, "number of epochs to train")
# Number of characters per training sequence (BPTT length).
tf.app.flags.DEFINE_integer('seq_length', 50, "sequence length")
def train():
    """Train the character-level RNN and checkpoint it after each epoch.

    Loads text batches via Reader, builds the Model, and runs
    FLAGS.num_epochs passes over the data, carrying the recurrent state
    across batches within an epoch. Per batch, a merged summary is
    written for TensorBoard; per epoch, the mean loss is printed and a
    checkpoint is saved under FLAGS.save_dir. The chars/vocab mapping is
    pickled alongside the checkpoints so sampling code can reconstruct
    the same index <-> character mapping.
    """
    # Load data
    data = Reader(FLAGS.data_dir, FLAGS.batch_size, FLAGS.seq_length)
    vocab_size = len(data.chars)

    # Setup directories. exist_ok avoids the check-then-create race of
    # the `if not os.path.isdir(...): os.makedirs(...)` idiom.
    os.makedirs(FLAGS.save_dir, exist_ok=True)

    # Persist the character set and vocabulary for later sampling.
    filename = os.path.join(FLAGS.save_dir, 'chars_vocab.pkl')
    with open(filename, 'wb') as f:
        pickle.dump([data.chars, data.vocab], f)

    # Model (training=True presumably enables the optimizer/dropout
    # path inside Model — confirm against char_rnn_model).
    model = Model(vocab_size, training=True)

    # Saver over all graph variables, used for per-epoch checkpoints.
    saver = tf.train.Saver(tf.global_variables())

    # Single op that evaluates every summary registered by the model.
    summary_op = tf.summary.merge_all()

    # Print list of trainable variables and the total parameter count.
    print("")
    print("Variables")
    print("---------")
    variables = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
    num_params = 0
    for v in variables:
        num_params += np.prod(v.get_shape().as_list())
        print(v.name, v.get_shape())
    print("=> Total number of parameters =", num_params)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        # For TensorBoard
        writer = tf.summary.FileWriter(FLAGS.log_dir, sess.graph)

        # Minimize the loss function
        for epoch in range(FLAGS.num_epochs):
            # Reset the recurrent state at the start of each epoch ...
            state = sess.run(model.initial_state)
            current_loss = 0
            for b in range(data.num_batches):
                batch = data.next_batch()
                feed_dict = {model.x: batch[0], model.y: batch[1]}
                # ... but thread each batch's final state into the next
                # batch, so the epoch is seen as one long sequence.
                for layer, (c, h) in enumerate(model.initial_state):
                    feed_dict[c] = state[layer].c
                    feed_dict[h] = state[layer].h
                fetches = [model.train_op, model.loss, model.final_state,
                           summary_op]
                _, loss, state, summary = sess.run(fetches, feed_dict)
                current_loss += loss

                # Add to summary; global step = total batches seen.
                writer.add_summary(summary, epoch*data.num_batches + b)

            # Progress report: mean loss over the epoch's batches.
            print("After {} epochs, loss = {}"
                  .format(epoch+1, current_loss/data.num_batches))

            # Save a checkpoint tagged with the 1-based epoch number.
            ckpt_path = os.path.join(FLAGS.save_dir, 'model.ckpt')
            saver.save(sess, ckpt_path, global_step=epoch+1)
#///////////////////////////////////////////////////////////////////////////////
def main(_):
    """Entry point handed to tf.app.run(); simply dispatches to train()."""
    train()
# When run as a script, tf.app.run() parses the flags and calls main().
if __name__ == '__main__':
    tf.app.run()