import argparse

import keras
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.preprocessing.image import ImageDataGenerator
from keras import backend as K
import tensorflow as tf
import horovod.keras as hvd

# Training settings
parser = argparse.ArgumentParser(description='Keras MNIST Example')
parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                    help='input batch size for training (default: 128)')
parser.add_argument('--epochs', type=int, default=24, metavar='N',
                    help='number of epochs to train (default: 24)')
parser.add_argument('--lr', type=float, default=1.0, metavar='LR',
                    help='learning rate (default: 1.0)')

args = parser.parse_args()

# Horovod: initialize Horovod.
hvd.init()

# Horovod: pin each process to a single GPU, selected by its local rank
# (one GPU per process).
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.visible_device_list = str(hvd.local_rank())
K.set_session(tf.Session(config=config))

batch_size = args.batch_size
num_classes = 10

# Enough epochs to demonstrate learning rate warmup and the reduction of the
# learning rate when training plateaus.
epochs = args.epochs

# Input image dimensions
img_rows, img_cols = 28, 28

# The data, shuffled and split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Determine how many batches there are in the train and test sets
train_batches = len(x_train) // batch_size
test_batches = len(x_test) // batch_size

if K.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices
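# (e.g. label 3 becomes the one-hot vector [0, 0, 0, 1, 0, 0, 0, 0, 0, 0])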
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

model = Sequential()
model.add(Conv2D(32, kernel_size=(3, 3),
                 activation='relu',
                 input_shape=input_shape))
model.add(Conv2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(Dropout(0.25))
model.add(Flatten())
model.add(Dense(128, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(num_classes, activation='softmax'))

# Horovod: adjust the learning rate based on the number of GPUs.
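# With N workers the effective batch size is N times larger, so the learning
# rate is scaled linearly with hvd.size() (the linear scaling rule); the
# warmup callback below eases training into this larger rate.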
scaled_lr = args.lr * hvd.size()
opt = keras.optimizers.Adadelta(lr=scaled_lr)

# Horovod: add Horovod Distributed Optimizer.
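# The wrapper averages gradients across all workers with an allreduce before
# each update, so every worker applies the same averaged gradient.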
opt = hvd.DistributedOptimizer(opt)

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=opt,
              metrics=['accuracy'])

callbacks = [
    # Horovod: broadcast initial variable states from rank 0 to all other processes.
    # This is necessary to ensure consistent initialization of all workers when
    # training is started with random weights or restored from a checkpoint.
    hvd.callbacks.BroadcastGlobalVariablesCallback(0),

    # Horovod: average metrics among workers at the end of every epoch.
    #
    # Note: this callback must come before ReduceLROnPlateau, TensorBoard,
    # or other metrics-based callbacks in the list.
    hvd.callbacks.MetricAverageCallback(),

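    # Horovod: ramp the learning rate up to `scaled_lr` over the first
    # 5 warmup epochs, which helps stabilize early training when the scaled
    # rate would otherwise be too aggressive.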
    hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, initial_lr=scaled_lr, verbose=1),

    # Reduce the learning rate if training plateaus.
    keras.callbacks.ReduceLROnPlateau(patience=10, verbose=1),
]

# Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them.
if hvd.rank() == 0:
    callbacks.append(keras.callbacks.ModelCheckpoint('./checkpoint-{epoch}.h5'))

# Set up ImageDataGenerators to do data augmentation for the training images.
train_gen = ImageDataGenerator(rotation_range=8, width_shift_range=0.08, shear_range=0.3,
                               height_shift_range=0.08, zoom_range=0.08)
test_gen = ImageDataGenerator()

# Train the model.
# Horovod: the training will randomly sample 1 / N batches of training data and
# 3 / N batches of validation data on every worker, where N is the number of workers.
# Over-sampling the validation data increases the probability that every
# validation example will be evaluated.
model.fit_generator(train_gen.flow(x_train, y_train, batch_size=batch_size),
                    steps_per_epoch=train_batches // hvd.size(),
                    callbacks=callbacks,
                    epochs=epochs,
                    verbose=1,
                    validation_data=test_gen.flow(x_test, y_test, batch_size=batch_size),
                    validation_steps=3 * test_batches // hvd.size())

# Evaluate the model on the full data set.
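# Note: every worker runs this evaluation and prints its own copy of the
# results; in practice you might guard the prints with `if hvd.rank() == 0:`
# to avoid duplicated output.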
score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
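
# A typical launch (a sketch; the filename and process count are assumptions
# that depend on your setup) runs one process per GPU via Horovod's launcher,
# e.g. four GPUs on one machine:
#
#   horovodrun -np 4 python keras_mnist_advanced.py --epochs 24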