|
| 1 | +import keras |
| 2 | +import numpy as np |
| 3 | +from keras import backend as K |
| 4 | +from keras.datasets import mnist |
| 5 | +from keras.models import Sequential |
| 6 | +from keras.layers import Dense, Dropout, Activation, Flatten, Conv2D, MaxPooling2D |
| 7 | +from keras.regularizers import l2 |
| 8 | +from keras.wrappers.scikit_learn import KerasClassifier |
| 9 | +from modAL.models import ActiveLearner |
| 10 | + |
def create_keras_model():
    """Build and compile the convolutional classifier used by the learner.

    The network stacks two 4x4 convolutions with 32 filters each, a 2x2
    max-pooling step, dropout, a 128-unit dense layer, a second dropout and
    a 10-way softmax output.  It is compiled with categorical cross-entropy
    loss, the Adam optimizer and accuracy as the reported metric.

    Returns:
        The compiled ``Sequential`` model.
    """
    # Layers are instantiated in the same order in which they are stacked.
    feature_extractor = [
        Conv2D(32, (4, 4), activation='relu'),
        Conv2D(32, (4, 4), activation='relu'),
        MaxPooling2D(pool_size=(2, 2)),
        Dropout(0.25),
    ]
    classifier_head = [
        Flatten(),
        Dense(128, activation='relu'),
        Dropout(0.5),
        Dense(10, activation='softmax'),
    ]

    network = Sequential()
    for layer in feature_extractor + classifier_head:
        network.add(layer)

    network.compile(
        loss='categorical_crossentropy',
        optimizer='adam',
        metrics=["accuracy"],
    )
    return network
| 23 | + |
| 24 | + |
# wrap the model factory in the Keras scikit-learn classifier wrapper
classifier = KerasClassifier(build_fn=create_keras_model)
| 27 | + |
# read training data
(X_train, y_train), (X_test, y_test) = mnist.load_data()


# assemble initial data: draw two distinct random examples of every digit so
# the learner starts from a class-balanced labelled set
# NOTE: the builtin ``int`` is used as dtype because the ``np.int`` alias was
# deprecated in NumPy 1.20 and removed in NumPy 1.24.
initial_idx = np.array([], dtype=int)
for i in range(10):
    idx = np.random.choice(np.where(y_train == i)[0], size=2, replace=False)
    initial_idx = np.concatenate((initial_idx, idx))

# Preprocessing: add a trailing channel axis, divide the pixel values by 255
# and one-hot encode the labels.  The sample count is read from the arrays
# themselves rather than being hard-coded.
X_train = X_train.reshape(len(X_train), 28, 28, 1).astype('float32') / 255.
X_test = X_test.reshape(len(X_test), 28, 28, 1).astype('float32') / 255.
y_train = keras.utils.to_categorical(y_train, 10)
y_test = keras.utils.to_categorical(y_test, 10)

X_initial = X_train[initial_idx]
y_initial = y_train[initial_idx]

# remove the initial data from the pool of unlabelled examples
X_pool = np.delete(X_train, initial_idx, axis=0)
y_pool = np.delete(y_train, initial_idx, axis=0)
| 50 | + |
| 51 | +""" |
| 52 | +Query Strategy |
| 53 | +""" |
| 54 | + |
def max_entropy(learner, X, n_instances=1, T=100, subset_size=2000):
    """Select the pool instances with the highest predictive entropy.

    A random subset of the pool is scored with ``T`` forward passes that are
    run with the Keras learning-phase flag set (so that layers such as
    Dropout behave as during training).  The class probabilities are averaged
    over the passes and the entropy of that mean distribution is used as the
    acquisition score.

    Args:
        learner: Learner whose ``estimator.model`` is the Keras model that is
            evaluated.
        X: Pool of unlabelled instances, indexed along its first axis.
        n_instances: Number of instances to select.
        T: Number of forward passes to average over.
        subset_size: Maximum number of pool instances scored per query.  It
            is clamped to the pool size, so a pool with fewer rows than this
            no longer makes the sampling without replacement fail.

    Returns:
        Tuple ``(query_idx, X[query_idx])`` where ``query_idx`` holds indices
        into ``X``, ordered from highest to lowest entropy.
    """
    pool_size = X.shape[0]
    # Sampling without replacement needs size <= population, hence the clamp.
    random_subset = np.random.choice(
        pool_size, min(subset_size, pool_size), replace=False)

    model = learner.estimator.model
    MC_output = K.function([model.layers[0].input, K.learning_phase()],
                           [model.layers[-1].output])
    learning_phase = True

    # The subset is the same for every pass, so it is gathered only once.
    X_subset = X[random_subset]
    MC_samples = [MC_output([X_subset, learning_phase])[0] for _ in range(T)]
    MC_samples = np.array(MC_samples)  # [#samples x batch size x #classes]

    expected_p = np.mean(MC_samples, axis=0)
    # The small constant keeps log() finite for zero probabilities.
    acquisition = - np.sum(expected_p * np.log(expected_p + 1e-10), axis=-1)  # [batch size]

    # Negating the scores makes argsort put the largest entropies first.
    idx = (-acquisition).argsort()[:n_instances]
    query_idx = random_subset[idx]
    return query_idx, X[query_idx]
| 67 | + |
def uniform(learner, X, n_instances=1):
    """Pick ``n_instances`` distinct pool instances uniformly at random.

    Args:
        learner: Not used by this strategy; it is accepted so that the
            signature matches the other query strategy in this file.
        X: Pool of unlabelled instances, indexed along its first axis.
        n_instances: Number of instances to draw without replacement.

    Returns:
        Tuple ``(query_idx, X[query_idx])`` where ``query_idx`` holds the
        drawn indices into ``X``.
    """
    candidates = np.arange(len(X))
    chosen = np.random.choice(candidates, size=n_instances, replace=False)
    return chosen, X[chosen]
| 71 | + |
| 72 | +""" |
| 73 | +Training the ActiveLearner |
| 74 | +""" |
| 75 | + |
# build the ActiveLearner from the initial labelled set; new instances are
# chosen with the max_entropy query strategy
learner = ActiveLearner(
    estimator=classifier,
    query_strategy=max_entropy,
    X_training=X_initial,
    y_training=y_initial,
    verbose=0,
)
| 84 | + |
# the active learning loop: query a batch, teach it to the learner, drop it
# from the pool and record the accuracy on the test set
n_queries = 100
perf_hist = [learner.score(X_test, y_test, verbose=0)]
for index in range(n_queries):
    query_idx, query_instance = learner.query(X_pool, n_instances=10)
    learner.teach(
        X_pool[query_idx],
        y_pool[query_idx],
        epochs=50,
        batch_size=128,
        verbose=0,
    )
    # the queried instances are labelled now, so they leave the pool
    X_pool, y_pool = (
        np.delete(X_pool, query_idx, axis=0),
        np.delete(y_pool, query_idx, axis=0),
    )
    model_accuracy = learner.score(X_test, y_test, verbose=0)
    perf_hist.append(model_accuracy)
    print('Accuracy after query {n}: {acc:0.4f}'.format(n=index + 1, acc=model_accuracy))
0 commit comments