import keras
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers import LSTM
from keras.optimizers import RMSprop
import numpy as np
import os

import pickle


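# Keras callback that records the loss after every training batch, so the
# most recent batch loss is available once fit() returns.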
class LossHistory( keras.callbacks.Callback ):
    def on_train_begin( self, logs= {} ):
        self.losses = []

    def on_batch_end( self, batch, logs= {} ):
        self.losses.append( logs.get( 'loss' ) )


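# Width of the single LSTM layer.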
rnn_size = 256


# load data from pickle
f = open( 'data.pkl', 'rb' )

classes = pickle.load( f )
chars = pickle.load( f )
char_indices = pickle.load( f )
indices_char = pickle.load( f )

maxlen = pickle.load( f )
step = pickle.load( f )

X_ind = pickle.load( f )
y_ind = pickle.load( f )

f.close()

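# Expand the integer-encoded data into one-hot arrays: X has shape
# ( samples, window length, vocab size ), y has shape ( samples, vocab size ).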
[ s1, s2 ] = X_ind.shape

X = np.zeros( ( s1, s2, len( chars ) ), dtype=bool )
y = np.zeros( ( s1, len( chars ) ), dtype=bool )

for i in range( s1 ):
    for t in range( s2 ):
        X[ i, t, X_ind[ i, t ] ] = 1
    y[ i, y_ind[ i ] ] = 1

# build the model: a single LSTM
print( 'Build model...' )
model = Sequential()
model.add( LSTM( rnn_size, input_shape=( maxlen, len( chars ) ) ) )
model.add( Dense( len( chars ) ) )
model.add( Activation( 'softmax' ) )

optimizer = RMSprop( learning_rate= 0.001 )
model.compile( loss= 'categorical_crossentropy', optimizer= optimizer )


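# Temperature sampling: rescale the log-probabilities by 1/temperature and
# renormalize before drawing one index. Low temperatures make the picks
# conservative; high temperatures make them more varied.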
def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)

# train the model; after each epoch, checkpoint it and write generated text,
# stopping early once the loss has not improved for five epochs in a row
min_loss = 1e15
loss_count = 0

for iteration in range(1, 100):
    print()
    print('-' * 50)
    print('Iteration', iteration)

    history = LossHistory()
    model.fit( X, y, batch_size= 100, epochs= 1, callbacks= [ history ] )

    loss = history.losses[ -1 ]
    print( loss )

    if loss < min_loss:
        min_loss = loss
        loss_count = 0
    else:
        loss_count += 1
        if loss_count > 4:
            break

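    # Checkpoints go into a rnn_size/maxlen directory, named by iteration
    # number and the last batch loss.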
    dirname = str( rnn_size ) + "/" + str( maxlen )
    if not os.path.exists( dirname ):
        os.makedirs( dirname )

    # common filename suffix: iteration number plus the rounded loss
    tag = str( iteration ) + "." + str( round( loss, 6 ) )

    # serialize model to JSON
    model_json = model.to_json()
    with open( dirname + "/model_" + tag + ".json", "w" ) as json_file:
        json_file.write( model_json )
    # serialize weights to HDF5
    model.save_weights( dirname + "/model_" + tag + ".h5" )
    print( "Checkpoint saved." )

    outtext = open( dirname + "/example_" + tag + ".txt", "w" )

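    # Generate sample text at several temperatures from the current model.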
    for diversity in [0.2, 0.5, 1.0, 1.2]:
        outtext.write('----- diversity:' + str( diversity ) + "\n" )

        generated = ''
        seedstr = "Diagnosis"
        outtext.write('----- Generating with seed: "' + seedstr + '"' + "\n" )

        # start from a blank context; assumes ' ' is in char_indices
        sentence = " " * maxlen

        generated += sentence
        outtext.write( generated )

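        # Prime the context window with the seed string, one character at a time.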
        for c in seedstr:
            # no sampling is needed while seeding; the characters are
            # simply shifted into the context window
            sentence = sentence[1:] + c

            generated += c

            outtext.write( c )


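        # Generate 400 characters: one-hot encode the current window, predict
        # the next-character distribution, sample an index, and slide forward.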
        for i in range( 400 ):
            x = np.zeros( ( 1, maxlen, len( chars ) ) )
            for t, char in enumerate(sentence):
                x[ 0, t, char_indices[ char ] ] = 1.

            preds = model.predict(x, verbose=0)[0]
            next_index = sample(preds, diversity)
            next_char = indices_char[next_index]

            generated += next_char
            sentence = sentence[1:] + next_char

            outtext.write(next_char)

        outtext.write( "\n" )

    outtext.close()