Skip to content

Commit c4577ac

Browse files
committed
add test-horovod.py
1 parent 61b7f92 commit c4577ac

File tree

1 file changed

+137
-0
lines changed

1 file changed

+137
-0
lines changed

tests/test-horovod.py

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import argparse
2+
3+
import keras
4+
from keras.datasets import mnist
5+
from keras.models import Sequential
6+
from keras.layers import Dense, Dropout, Flatten
7+
from keras.layers import Conv2D, MaxPooling2D
8+
from keras.preprocessing.image import ImageDataGenerator
9+
from keras import backend as K
10+
import tensorflow as tf
11+
import horovod.keras as hvd
12+
13+
# Training settings.
# Command-line options for this Keras + Horovod MNIST run. The description
# previously advertised a PyTorch example, which was misleading in --help.
parser = argparse.ArgumentParser(description='Keras MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=24, metavar='N',
                    help='number of epochs to train (default: 24)')
# Base learning rate; the script multiplies it by hvd.size() before use.
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                    help='learning rate (default: 1.0)')

# Parsed once at start-up; the rest of the script reads args.batch_size,
# args.epochs and args.lr.
args = parser.parse_args()
23+
24+
# Horovod: initialize Horovod before any rank/size query is made.
hvd.init()

# Horovod: pin GPU to be used to process local rank (one GPU per process).
# The session config enables allow_growth and restricts the visible devices
# to the one whose index equals this process's local rank; that session is
# then installed as the Keras backend session.
config = tf.ConfigProto()
gpu_options = config.gpu_options
gpu_options.allow_growth = True
gpu_options.visible_device_list = str(hvd.local_rank())
K.set_session(tf.Session(config=config))
32+
33+
# Per-step batch size comes from the command line; the label space has a
# fixed number of classes.
batch_size, num_classes = args.batch_size, 10

# Enough epochs to demonstrate learning rate warmup and the reduction of
# learning rate when training plateaus.
epochs = args.epochs

# Input image dimensions (square images).
img_rows = img_cols = 28

# The data, shuffled and split between train and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Number of whole batches available in the train and test sets; any
# trailing partial batch is dropped by the floor division.
train_batches, test_batches = (
    len(split) // batch_size for split in (x_train, x_test)
)
49+
50+
# Decide where the single channel axis goes based on the backend's image
# data format, then reshape both splits to (samples,) + input_shape.
if K.image_data_format() == 'channels_first':
    input_shape = (1, img_rows, img_cols)
else:
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.reshape((x_train.shape[0],) + input_shape)
x_test = x_test.reshape((x_test.shape[0],) + input_shape)

# Convert the images to float32 and divide by 255.
x_train = x_train.astype('float32')
x_train /= 255
x_test = x_test.astype('float32')
x_test /= 255

# Report the resulting training tensor shape and the sample counts.
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')
66+
67+
# Convert class vectors to binary class matrices (one column per class).
y_train, y_test = (
    keras.utils.to_categorical(labels, num_classes)
    for labels in (y_train, y_test)
)

# Network definition: two convolutions, max pooling and dropout, followed by
# a dense hidden layer with dropout and a softmax output over num_classes.
# The layers are passed to Sequential in the order they are applied.
model = Sequential([
    Conv2D(32, kernel_size=(3, 3),
           activation='relu',
           input_shape=input_shape),
    Conv2D(64, (3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Dropout(0.25),
    Flatten(),
    Dense(128, activation='relu'),
    Dropout(0.5),
    Dense(num_classes, activation='softmax'),
])
82+
83+
# Horovod: adjust learning rate based on number of GPUs. The base rate from
# the command line is multiplied by the number of Horovod processes.
scaled_lr = args.lr * hvd.size()

# Horovod: add Horovod Distributed Optimizer by wrapping the Adadelta
# optimizer that was built with the scaled learning rate.
opt = hvd.DistributedOptimizer(keras.optimizers.Adadelta(lr=scaled_lr))

# Compile with categorical cross-entropy loss and track accuracy.
model.compile(
    optimizer=opt,
    loss=keras.losses.categorical_crossentropy,
    metrics=['accuracy'],
)
93+
94+
# Callbacks, listed in the order they are handed to Keras.
callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other
    # processes, so every worker starts from the same weights whether
    # training begins from random initialization or a restored checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: this callback must come before ReduceLROnPlateau, TensorBoard
    # or any other metrics-based callback in the list.
    hvd.callbacks.MetricAverageCallback(),

    # Horovod: learning rate warmup over the first 5 epochs, configured
    # with the scaled learning rate.
    hvd.callbacks.LearningRateWarmupCallback(
        warmup_epochs=5, initial_lr=scaled_lr, verbose=1),

    # Reduce the learning rate if training plateaus.
    keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1),
]

# Horovod: save checkpoints only on worker 0 to prevent other workers from
# corrupting them.
if hvd.rank() == 0:
    callbacks += [keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5')]
115+
116+
# Set up ImageDataGenerators: the training generator applies rotation,
# shift, shear and zoom augmentation; the test generator applies none.
train_gen = ImageDataGenerator(
    rotation_range=8,
    width_shift_range=0.08,
    shear_range=0.3,
    height_shift_range=0.08,
    zoom_range=0.08,
)
test_gen = ImageDataGenerator()

# Train the model.
# Horovod: the training will randomly sample 1 / N batches of training data
# and 3 / N batches of validation data on every worker, where N is the number
# of workers. Over-sampling of validation data helps to increase probability
# that every validation example will be evaluated.
train_flow = train_gen.flow(x_train, y_train, batch_size=batch_size)
train_steps = train_batches // hvd.size()
validation_flow = test_gen.flow(x_test, y_test, batch_size=batch_size)
validation_steps = 3 * test_batches // hvd.size()

model.fit_generator(
    train_flow,
    steps_per_epoch=train_steps,
    epochs=epochs,
    verbose=1,
    callbacks=callbacks,
    validation_data=validation_flow,
    validation_steps=validation_steps,
)
133+
134+
# Evaluate the model on the full test set (not the sampled validation
# batches used during training) and print the first two reported values:
# loss, then accuracy.
score = model.evaluate(x_test, y_test, verbose=0)
for label, index in (('Test loss:', 0), ('Test accuracy:', 1)):
    print(label, score[index])

0 commit comments

Comments
 (0)