import keras
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers import LSTM

import numpy as np
import os
import pickle
import sys

import p3b2 as bmk
import candle_keras as candle


def initialize_parameters():

    # Build benchmark object
    p3b2Bmk = bmk.BenchmarkP3B2(bmk.file_path, 'p3b2_default_model.txt', 'keras',
                                prog='p3b2_baseline',
                                desc='Recurrent (LSTM) model for text generation from clinical reports - Pilot 3 Benchmark 2')

    # Initialize parameters
    gParameters = candle.initialize_parameters(p3b2Bmk)
    # bmk.logger.info('Params: {}'.format(gParameters))

    return gParameters

class LossHistory(keras.callbacks.Callback):
    """Record the training loss at the end of every batch."""

    def on_train_begin(self, logs=None):
        self.losses = []

    def on_batch_end(self, batch, logs=None):
        logs = logs or {}
        self.losses.append(logs.get('loss'))


def sample(preds, temperature=1.0):
    # helper function to sample an index from a probability array:
    # log-scale the probabilities, divide by the temperature, and
    # renormalize with a softmax before drawing a single sample
    preds = np.asarray(preds).astype('float64')
    preds = np.log(preds) / temperature
    exp_preds = np.exp(preds)
    preds = exp_preds / np.sum(exp_preds)
    probas = np.random.multinomial(1, preds, 1)
    return np.argmax(probas)
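
# Illustrative effect of the temperature (values hypothetical, not from
# the benchmark data): for preds = [0.1, 0.2, 0.7], temperature=1.0
# samples from the original distribution, temperature=0.5 sharpens it
# toward index 2, and temperature=2.0 flattens it toward uniform.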


def run(gParameters, data_path):

    kerasDefaults = candle.keras_default_config()

    # hyperparameters and runtime options from the CANDLE config
    rnn_size = gParameters['rnn_size']
    n_layers = gParameters['n_layers']
    learning_rate = gParameters['learning_rate']
    dropout = gParameters['drop']
    recurrent_dropout = gParameters['recurrent_dropout']
    n_epochs = gParameters['epochs']
    data_train = data_path + '/data.pkl'
    verbose = gParameters['verbose']
    savedir = gParameters['output_dir']
    do_sample = gParameters['do_sample']
    temperature = gParameters['temperature']
    primetext = gParameters['primetext']
    length = gParameters['length']

    # load data from pickle, in the exact order it was written by the
    # preprocessing step; pickles written under Python 2 need an explicit
    # encoding when read back under Python 3
    with open(data_train, 'rb') as f:
        if sys.version_info > (3, 0):
            def load():
                return pickle.load(f, encoding='latin1')
        else:
            def load():
                return pickle.load(f)

        classes = load()
        chars = load()
        char_indices = load()
        indices_char = load()

        maxlen = load()
        step = load()

        X_ind = load()
        y_ind = load()
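
    # (Assumed from how these values are used below.) classes: label set
    # from preprocessing; chars / char_indices / indices_char: vocabulary
    # and its two lookup tables; maxlen / step: window size and stride of
    # the preprocessing; X_ind / y_ind: integer-encoded windows and
    # next-character targets.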

    s1, s2 = X_ind.shape
    print('X_ind shape:', X_ind.shape)
    print('y_ind shape:', y_ind.shape)
    print('maxlen:', maxlen)
    print('vocabulary size:', len(chars))

    # expand the integer-coded data into one-hot boolean tensors
    # (np.bool was removed from recent NumPy releases, so plain bool is
    # used here)
    X = np.zeros((s1, s2, len(chars)), dtype=bool)
    y = np.zeros((s1, len(chars)), dtype=bool)

    for i in range(s1):
        for t in range(s2):
            X[i, t, X_ind[i, t]] = 1
        y[i, y_ind[i]] = 1
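
    # For example, with maxlen = 100 and a 64-character vocabulary
    # (numbers illustrative only), each sample becomes a 100 x 64 boolean
    # matrix with exactly one True entry per time step, and the matching
    # row of y marks the single next-character target.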

    # build the model: a stack of LSTM layers feeding a dense output
    # layer over the vocabulary
    if verbose:
        print('Build model...')

    model = Sequential()

    for k in range(n_layers):
        # every layer except the last must return full sequences so it
        # can feed the next LSTM layer
        ret_seq = (k < n_layers - 1)

        if k == 0:
            model.add(LSTM(rnn_size, input_shape=(maxlen, len(chars)),
                           return_sequences=ret_seq,
                           dropout=dropout, recurrent_dropout=recurrent_dropout))
        else:
            model.add(LSTM(rnn_size, dropout=dropout,
                           recurrent_dropout=recurrent_dropout,
                           return_sequences=ret_seq))

    model.add(Dense(len(chars)))
    model.add(Activation(gParameters['activation']))

    optimizer = candle.build_optimizer(gParameters['optimizer'],
                                       gParameters['learning_rate'],
                                       kerasDefaults)

    model.compile(loss=gParameters['loss'], optimizer=optimizer)

    if verbose:
        model.summary()
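
    # candle.build_optimizer is expected to map the optimizer name and
    # learning rate from the config onto the corresponding Keras
    # optimizer, taking any remaining settings from kerasDefaults.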


    for iteration in range(1, n_epochs + 1):
        if verbose:
            print()
            print('-' * 50)
            print('Iteration', iteration)

        # train one epoch at a time so that a checkpoint (and optionally
        # a text sample) can be written after every epoch
        history = LossHistory()
        model.fit(X, y, batch_size=100, epochs=1, callbacks=[history])

        loss = history.losses[-1]
        if verbose:
            print(loss)
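
        # note: the batch size is hardcoded to 100 here rather than read
        # from gParameters, so it cannot be tuned from the config file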

        dirname = savedir
        if len(dirname) > 0 and not dirname.endswith('/'):
            dirname = dirname + '/'

        if not os.path.exists(dirname):
            os.makedirs(dirname)

        # serialize the architecture to JSON and the weights to HDF5,
        # embedding the epoch number and loss in the file names
        checkpoint_base = dirname + 'model_' + str(iteration) + '_' + '{:f}'.format(loss)
        model_json = model.to_json()
        with open(checkpoint_base + '.json', 'w') as json_file:
            json_file.write(model_json)

        model.save_weights(checkpoint_base + '.h5')

        if verbose:
            print('Checkpoint saved.')

        if do_sample:
            # note: this sampling path requires Python 3 -- open() does
            # not accept an encoding keyword under Python 2
            outtext = open(dirname + 'example_' + str(iteration) + '_' +
                           '{:f}'.format(loss) + '.txt', 'w', encoding='utf-8')

            diversity = temperature

            outtext.write('----- diversity:' + str(diversity) + '\n')

            generated = ''
            seedstr = primetext

            outtext.write('----- Generating with seed: "' + seedstr + '"\n')

            # prime the sliding window: start from a window of blanks,
            # then shift the seed text in one character at a time
            sentence = ' ' * maxlen

            generated += sentence
            outtext.write(generated)

            for c in seedstr:
                sentence = sentence[1:] + c
                generated += c
                outtext.write(c)
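
            # no predictions are needed while priming: the network is
            # stateless between predict() calls, so the sliding window
            # `sentence` is the only context carried into generation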

            # generate `length` characters, one-hot encoding the current
            # window and sliding it forward by each sampled character
            for i in range(length):
                x = np.zeros((1, maxlen, len(chars)))
                for t, char in enumerate(sentence):
                    x[0, t, char_indices[char]] = 1.

                preds = model.predict(x, verbose=verbose)[0]
                next_index = sample(preds, diversity)
                next_char = indices_char[next_index]

                generated += next_char
                sentence = sentence[1:] + next_char

            outtext.write(generated + '\n')

            outtext.close()


if __name__ == "__main__":

    gParameters = initialize_parameters()

    # download the training archive (if not already cached) and unpack it
    origin = gParameters['data_url']
    train_data = gParameters['train_data']
    data_loc = candle.fetch_file(origin + train_data, untar=True, md5_hash=None,
                                 subdir='Pilot3')

    print('Data downloaded and stored at: ' + data_loc)
    data_path = os.path.dirname(data_loc)
    print(data_path)

    run(gParameters, data_path)
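
# Typical invocation (illustrative; the flag names follow the keys in
# p3b2_default_model.txt and the script file name is assumed):
#   python p3b2_baseline_keras2.py --epochs 10 --rnn_size 256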