From 11001446d1782fd7a0f4e648a133b24a09b9b235 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Thu, 30 Sep 2021 14:05:12 -0500 Subject: [PATCH 01/12] add counterfactual code --- Pilot1/NT3/nt3_cf/README.md | 15 + .../nt3_cf/abstention/abstain_functions.py | 203 ++++++ Pilot1/NT3/nt3_cf/abstention/make_csv.py | 41 ++ .../abstention/nt3_abstention_keras2_cf.py | 378 ++++++++++ .../nt3_cf/abstention/nt3_baseline_keras2.py | 318 +++++++++ .../nt3_cf/abstention/run_abstention_sweep.sh | 5 + Pilot1/NT3/nt3_cf/analyze.ipynb | 655 ++++++++++++++++++ Pilot1/NT3/nt3_cf/analyze.py | 33 + Pilot1/NT3/nt3_cf/cf_nb.py | 70 ++ Pilot1/NT3/nt3_cf/cf_script.py | 65 ++ Pilot1/NT3/nt3_cf/environment.yml | 264 +++++++ Pilot1/NT3/nt3_cf/gen_clusters.py | 109 +++ Pilot1/NT3/nt3_cf/inject_noise.py | 81 +++ Pilot1/NT3/nt3_cf/nt3.ipynb | 426 ++++++++++++ Pilot1/NT3/nt3_cf/test_cf_accuracy.py | 66 ++ Pilot1/NT3/nt3_cf/threshold.py | 70 ++ 16 files changed, 2799 insertions(+) create mode 100644 Pilot1/NT3/nt3_cf/README.md create mode 100644 Pilot1/NT3/nt3_cf/abstention/abstain_functions.py create mode 100644 Pilot1/NT3/nt3_cf/abstention/make_csv.py create mode 100644 Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py create mode 100644 Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py create mode 100755 Pilot1/NT3/nt3_cf/abstention/run_abstention_sweep.sh create mode 100644 Pilot1/NT3/nt3_cf/analyze.ipynb create mode 100644 Pilot1/NT3/nt3_cf/analyze.py create mode 100644 Pilot1/NT3/nt3_cf/cf_nb.py create mode 100644 Pilot1/NT3/nt3_cf/cf_script.py create mode 100644 Pilot1/NT3/nt3_cf/environment.yml create mode 100644 Pilot1/NT3/nt3_cf/gen_clusters.py create mode 100644 Pilot1/NT3/nt3_cf/inject_noise.py create mode 100644 Pilot1/NT3/nt3_cf/nt3.ipynb create mode 100644 Pilot1/NT3/nt3_cf/test_cf_accuracy.py create mode 100644 Pilot1/NT3/nt3_cf/threshold.py diff --git a/Pilot1/NT3/nt3_cf/README.md b/Pilot1/NT3/nt3_cf/README.md new file mode 100644 index 00000000..3866b586 --- /dev/null +++ 
b/Pilot1/NT3/nt3_cf/README.md @@ -0,0 +1,15 @@ +NT3 with counterfactuals: +Code to generate counterfactual examples given an input model and dataset in pkl format. \ +Clusters and thresholds counterfactuals, injects noise into dataset \ +Workflow: +1) Generate counterfactuals using cf_nb.py +2) Create threshold pickle files using threshold.py (provide a threshold value between 0 and 1, see --help) +3) Cluster threshold files using gen_clusters.py +4) Inject noise into dataset using inject_noise.py (provide a scale value to modify the amplitude of the noise, see --help) + +Abstention with counterfactuals: +Code located in abstention/ +Workflow: +1) Run abstention model with nt3_abstention_keras2_cf.py, pass in a pickle file with X (with noise), y (this is the output of 4) above) +2) For a sweep use run_abstention_sweep.sh +3) To collect metrics (abstention, cluster abstention) run make_csv.py \ No newline at end of file diff --git a/Pilot1/NT3/nt3_cf/abstention/abstain_functions.py b/Pilot1/NT3/nt3_cf/abstention/abstain_functions.py new file mode 100644 index 00000000..8ee5c84b --- /dev/null +++ b/Pilot1/NT3/nt3_cf/abstention/abstain_functions.py @@ -0,0 +1,203 @@ +from tensorflow.keras import backend as K + +abs_definitions = [ + {'name': 'add_class', + 'nargs': '+', + 'type': int, + 'help': 'flag to add abstention (per task)'}, + {'name': 'alpha', + 'nargs': '+', + 'type': float, + 'help': 'abstention penalty coefficient (per task)'}, + {'name': 'min_acc', + 'nargs': '+', + 'type': float, + 'help': 'minimum accuracy required (per task)'}, + {'name': 'max_abs', + 'nargs': '+', + 'type': float, + 'help': 'maximum abstention fraction allowed (per task)'}, + {'name': 'alpha_scale_factor', + 'nargs': '+', + 'type': float, + 'help': 'scaling factor for modifying alpha (per task)'}, + {'name': 'init_abs_epoch', + 'action': 'store', + 'type': int, + 'help': 'number of epochs to skip before modifying alpha'}, + {'name': 'n_iters', + 'action': 'store', + 'type': int, + 
def adjust_alpha(gParameters, X_test, truths_test, labels_val, model, alpha, add_index):
    """Rescale the per-task abstention penalty ``alpha`` from validation results.

    For every task index ``k`` in ``gParameters['task_list']`` the model's
    predictions on ``X_test`` are compared with ``truths_test[:, k]``.  The
    accuracy on non-abstained samples and the abstained fraction are compared
    with ``min_acc[k]`` / ``max_abs[k]`` and ``alpha[k]`` is multiplied by a
    scale factor clamped to ``[alpha_scale_factor[k], 1 / alpha_scale_factor[k]]``.

    Args:
        gParameters: parameter dict; reads 'task_names', 'task_list',
            'max_abs', 'min_acc', 'alpha_scale_factor', 'abs_gain',
            'acc_gain' and 'output_dir'.
        X_test: features passed to ``model.evaluate`` / ``model.predict``.
        truths_test: 2-D array indexed as ``truths_test[:, k]`` (integer class
            labels per task).
        labels_val: labels in the form ``model.evaluate`` expects.
        model: object exposing ``evaluate`` and ``predict``; ``predict`` must
            return something indexable per task (``pred[k]``).
        alpha: backend variable(s) with ``alpha.shape[0]`` entries, updated in
            place through ``K.set_value``.
        add_index: per-task class count *including* the abstention class, so
            ``add_index[k] - 1`` is the abstention class index.

    Returns:
        ``(ret, alpha)`` where ``ret`` holds one ``[truth, pred]`` pair per
        processed task and ``alpha`` is the (updated) input variable.
    """
    task_names = gParameters['task_names']
    task_list = gParameters['task_list']

    # abstaining-classifier targets and controller gains
    max_abs = gParameters['max_abs']
    min_acc = gParameters['min_acc']
    alpha_scale_factor = gParameters['alpha_scale_factor']
    abs_gain = gParameters['abs_gain']
    acc_gain = gParameters['acc_gain']

    # Kept for its progress/log output; the returned loss is not used.
    model.evaluate(X_test, labels_val)
    pred = model.predict(X_test)

    ret = []
    accs = []
    abst = []

    for k in range(alpha.shape[0]):
        if k not in task_list:
            # Placeholder stats for tasks that are not being tuned.  The
            # original appended both values to ``accs``, leaving ``abst``
            # short and the two lists misaligned with the alpha vector.
            accs.append(1.0)
            abst.append(0.0)
            continue

        truth_test = truths_test[:, k]
        alpha_k = K.eval(alpha[k])
        pred_classes = pred[k].argmax(axis=-1)

        # Plain array arithmetic: the inputs are already host-side arrays, so
        # there is no need to build backend ops just to count matches.
        matches = (pred_classes == truth_test)
        true = int(matches.sum())
        false = int(matches.size - true)  # includes abstained samples
        abstain = int((pred_classes == add_index[k] - 1).sum())

        print(true, false, abstain)

        total = false + true
        tot_pred = total - abstain
        abs_frac = abstain / total if total > 0 else 0.0
        abs_acc = true / tot_pred if tot_pred > 0 else 0.0

        scale_k = alpha_scale_factor[k]
        min_scale = scale_k
        max_scale = 1. / scale_k

        # Only penalise accuracy below target and abstention above target.
        acc_error = min(abs_acc - min_acc[k], 0.0)
        abs_error = max(abs_frac - max_abs[k], 0.0)
        new_scale = 1.0 + acc_gain * acc_error + abs_gain * abs_error

        # threshold to avoid huge swings
        new_scale = max(min(new_scale, max_scale), min_scale)

        print('Scaling factor: ', new_scale)
        K.set_value(alpha[k], new_scale * alpha_k)

        print_abs_stats(task_names[k], new_scale * alpha_k, true, false, abstain, max_abs[k])

        # One fresh pair per task (the original reused a single list, so every
        # entry of ``ret`` aliased the same ever-growing object).
        ret.append([truth_test, pred])

        accs.append(abs_acc)
        abst.append(abs_frac)

    # NOTE(review): plain concatenation only lands inside the output directory
    # when 'output_dir' ends with a path separator - confirm before changing.
    write_abs_stats(gParameters['output_dir'] + 'abs_stats.csv', alpha, accs, abst)

    return ret, alpha
- abs_pred) + + return cost + return loss + + +def print_abs_stats( + task_name, + alpha, + num_true, + num_false, + num_abstain, + max_abs): + + # Compute interesting values + total = num_true + num_false + tot_pred = total - num_abstain + abs_frac = num_abstain / total + abs_acc = 1.0 + if tot_pred > 0: + abs_acc = num_true / tot_pred + + print(' task, alpha, true, false, abstain, total, tot_pred, abs_frac, max_abs, abs_acc') + print('{:>12s}, {:10.5e}, {:8d}, {:8d}, {:8d}, {:8d}, {:8d}, {:10.5f}, {:10.5f}, {:10.5f}' + .format(task_name, alpha, + num_true, num_false - num_abstain, num_abstain, total, + tot_pred, abs_frac, max_abs, abs_acc)) + + +def write_abs_stats(stats_file, alphas, accs, abst): + + # Open file for appending + abs_file = open(stats_file, 'a') + + # we write all the results + for k in range((alphas.shape[0])): + abs_file.write("%10.5e," % K.get_value(alphas[k])) + for k in range((alphas.shape[0])): + abs_file.write("%10.5e," % accs[k]) + for k in range((alphas.shape[0])): + abs_file.write("%10.5e," % abst[k]) + abs_file.write("\n") diff --git a/Pilot1/NT3/nt3_cf/abstention/make_csv.py b/Pilot1/NT3/nt3_cf/abstention/make_csv.py new file mode 100644 index 00000000..6ee2d98a --- /dev/null +++ b/Pilot1/NT3/nt3_cf/abstention/make_csv.py @@ -0,0 +1,41 @@ +import pandas as pd +import pickle +import argparse +import glob, os +from pathlib import Path + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-f",type=str, help="Run folder") + parser.add_argument("-c1", type=str, help="cluster 1 name") + parser.add_argument("-c2", type=str, help="cluster 2 name") + args = parser.parse_args() + return args + +def main(): + args = get_args() + l1 = [] + l2 = [] + runs = glob.glob(args.f+"/EXP000/*/") + print(runs) + for r in runs: + global_data = pd.read_csv(r+"training.log") + val_abs = global_data['val_abstention'].iloc[-1] + val_abs_acc = global_data['val_abstention_acc'].iloc[-1] + cluster_data = 
pickle.load(open(r+"cluster_trace.pkl", "rb")) + polluted_abs = cluster_data['Abs polluted'] + val_abs_cluster = cluster_data['Abs val cluster'] + val_abs_acc_cluster = cluster_data['Abs val acc'] + ratio = float(r[-4:-1]) + if args.c1 in r: + l1.append([ratio, val_abs, val_abs_acc, val_abs_cluster, val_abs_acc_cluster, polluted_abs]) + elif args.c2 in r: + l2.append([ratio, val_abs, val_abs_acc, val_abs_cluster, val_abs_acc_cluster, polluted_abs]) + + df1 = pd.DataFrame(l1, columns=['Noise Fraction', 'Val Abs', 'Val Abs Acc', 'Val Abs Cluster', 'Val Abs Acc Cluster', 'Polluted Abs']) + df2 = pd.DataFrame(l2, columns=['Noise Fraction', 'Val Abs', 'Val Abs Acc', 'Val Abs Cluster', 'Val Abs Acc Cluster', 'Polluted Abs']) + print(df1) + df1.to_csv("cluster_1.csv") + df2.to_csv("cluster_2.csv") +if __name__ == "__main__": + main() diff --git a/Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py b/Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py new file mode 100644 index 00000000..8ec8c5d8 --- /dev/null +++ b/Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py @@ -0,0 +1,378 @@ +from __future__ import print_function +import pandas as pd +import numpy as np +import os +import tensorflow +from tensorflow.keras import backend as K +os.environ["CUDA_VISIBLE_DEVICES"]="2" +from tensorflow.keras.layers import Dense, Dropout, Activation, Conv1D, MaxPooling1D, Flatten, LocallyConnected1D +from tensorflow.keras.models import Sequential, model_from_json, model_from_yaml +from tensorflow.keras.utils import to_categorical +from tensorflow.keras.callbacks import CSVLogger, ReduceLROnPlateau + +from sklearn.preprocessing import MaxAbsScaler +from abstain_functions import abs_definitions + +import nt3 as bmk +import candle +import pickle +additional_definitions = abs_definitions + +required = bmk.required + + +class BenchmarkNT3Abs(candle.Benchmark): + def set_locals(self): + """Functionality to set variables specific for the benchmark + - required: set of required 
def load_data(path, gParameters):
    """Load a pickled counterfactual-noise dataset and split it 80/20.

    The pickle is expected to hold an indexable object with, in order:
    features ``X``, labels ``y``, the polluted sample indices and the cluster
    sample indices.  The first 80% of the rows (in stored order, no shuffling)
    become the training split and the remainder the test split.  The index
    lists are returned untouched, so they refer to rows of the full dataset.

    Args:
        path: path of the pickle file to read.
        gParameters: benchmark parameters (unused; kept for interface
            compatibility).

    Returns:
        ``(X_train, Y_train, X_test, Y_test, polluted_inds, cluster_inds)``.
    """
    print("Loading data...")
    # SECURITY: pickle can execute arbitrary code while loading - only use
    # files produced by this workflow, never untrusted input.
    with open(path, 'rb') as data_file:
        data = pickle.load(data_file)

    X = data[0]
    y = data[1]
    polluted_inds = data[2]
    cluster_inds = data[3]

    # Single split point shared by features and labels.
    n_train = int(0.8 * X.shape[0])
    X_train, X_test = X[:n_train], X[n_train:]
    Y_train, Y_test = y[:n_train], y[n_train:]
    print('done')

    print('df_train shape:', X_train.shape)
    print('df_test shape:', X_test.shape)

    return X_train, Y_train, X_test, Y_test, polluted_inds, cluster_inds
the to_categorical + Y_train = np.argmax(Y_train, axis=1) + Y_test = np.argmax(Y_test, axis=1) + Y_train, Y_test = candle.modify_labels(gParameters['classes'] + 1, Y_train, Y_test) + # print(Y_test) + + print('X_train shape:', X_train.shape) + print('X_test shape:', X_test.shape) + + print('Y_train shape:', Y_train.shape) + print('Y_test shape:', Y_test.shape) + + x_train_len = X_train.shape[1] + + # this reshaping is critical for the Conv1D to work + + #X_train = np.expand_dims(X_train, axis=2) + #X_test = np.expand_dims(X_test, axis=2) + + print('X_train shape:', X_train.shape) + print('X_test shape:', X_test.shape) + + model = Sequential() + + layer_list = list(range(0, len(gParameters['conv']), 3)) + for _, i in enumerate(layer_list): + filters = gParameters['conv'][i] + filter_len = gParameters['conv'][i + 1] + stride = gParameters['conv'][i + 2] + print(int(i / 3), filters, filter_len, stride) + if gParameters['pool']: + pool_list = gParameters['pool'] + if type(pool_list) != list: + pool_list = list(pool_list) + + if filters <= 0 or filter_len <= 0 or stride <= 0: + break + if 'locally_connected' in gParameters: + model.add(LocallyConnected1D(filters, filter_len, strides=stride, padding='valid', input_shape=(x_train_len, 1))) + else: + # input layer + if i == 0: + model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid', input_shape=(x_train_len, 1))) + else: + model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid')) + model.add(Activation(gParameters['activation'])) + if gParameters['pool']: + model.add(MaxPooling1D(pool_size=pool_list[int(i / 3)])) + + model.add(Flatten()) + + for layer in gParameters['dense']: + if layer: + model.add(Dense(layer)) + model.add(Activation(gParameters['activation'])) + if gParameters['dropout']: + model.add(Dropout(gParameters['dropout'])) + model.add(Dense(gParameters['classes'])) + model.add(Activation(gParameters['out_activation'])) + + # modify the model 
for abstention + model = candle.add_model_output(model, mode='abstain', num_add=1, activation=gParameters['out_activation']) + +# Reference case +# model.add(Conv1D(filters=128, kernel_size=20, strides=1, padding='valid', input_shape=(P, 1))) +# model.add(Activation('relu')) +# model.add(MaxPooling1D(pool_size=1)) +# model.add(Conv1D(filters=128, kernel_size=10, strides=1, padding='valid')) +# model.add(Activation('relu')) +# model.add(MaxPooling1D(pool_size=10)) +# model.add(Flatten()) +# model.add(Dense(200)) +# model.add(Activation('relu')) +# model.add(Dropout(0.1)) +# model.add(Dense(20)) +# model.add(Activation('relu')) +# model.add(Dropout(0.1)) +# model.add(Dense(CLASSES)) +# model.add(Activation('softmax')) + + kerasDefaults = candle.keras_default_config() + + # Define optimizer + optimizer = candle.build_optimizer(gParameters['optimizer'], + gParameters['learning_rate'], + kerasDefaults) + + model.summary() + + # Configure abstention model + nb_classes = gParameters['classes'] + mask = np.zeros(nb_classes + 1) + mask[nb_classes] = 1.0 + print("Mask is ", mask) + alpha0 = gParameters['alpha'] + if isinstance(gParameters['max_abs'], list): + max_abs = gParameters['max_abs'][0] + else: + max_abs = gParameters['max_abs'] + + print("Initializing abstention callback with: \n") + print("alpha0 ", alpha0) + print("alpha_scale_factor ", gParameters['alpha_scale_factor']) + print("min_abs_acc ", gParameters['min_acc']) + print("max_abs_frac ", max_abs) + print("acc_gain ", gParameters['acc_gain']) + print("abs_gain ", gParameters['abs_gain']) + + abstention_cbk = candle.AbstentionAdapt_Callback(acc_monitor='val_abstention_acc', + abs_monitor='val_abstention', + init_abs_epoch=gParameters['init_abs_epoch'], + alpha0=alpha0, + alpha_scale_factor=gParameters['alpha_scale_factor'], + min_abs_acc=gParameters['min_acc'], + max_abs_frac=max_abs, + acc_gain=gParameters['acc_gain'], + abs_gain=gParameters['abs_gain']) + + 
model.compile(loss=candle.abstention_loss(abstention_cbk.alpha, mask), + optimizer=optimizer, + metrics=[candle.abstention_acc_metric(nb_classes), + # candle.acc_class_i_metric(1), + # candle.abstention_acc_class_i_metric(nb_classes, 1), + candle.abstention_metric(nb_classes)]) + + # model.compile(loss=abs_loss, + # optimizer=optimizer, + # metrics=abs_acc) + + # model.compile(loss=gParameters['loss'], + # optimizer=optimizer, + # metrics=[gParameters['metrics']]) + + output_dir = gParameters['output_dir'] + + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # calculate trainable and non-trainable params + gParameters.update(candle.compute_trainable_params(model)) + + # set up a bunch of callbacks to do work during model training.. + model_name = gParameters['model_name'] + # path = '{}/{}.autosave.model.h5'.format(output_dir, model_name) + # checkpointer = ModelCheckpoint(filepath=path, verbose=1, save_weights_only=False, save_best_only=True) + print(output_dir) + csv_logger = CSVLogger("{}/training.log".format(output_dir)) + reduce_lr = ReduceLROnPlateau(monitor='val_loss', + factor=0.1, patience=10, verbose=1, mode='auto', + epsilon=0.0001, cooldown=0, min_lr=0) + + candleRemoteMonitor = candle.CandleRemoteMonitor(params=gParameters) + timeoutMonitor = candle.TerminateOnTimeOut(gParameters['timeout']) + + # n_iters = 1 + + # val_labels = {"activation_5": Y_test} + # for epoch in range(gParameters['epochs']): + # print('Iteration = ', epoch) + history = model.fit(X_train, Y_train, + batch_size=gParameters['batch_size'], + epochs=gParameters['epochs'], + # initial_epoch=epoch, + # epochs=epoch + n_iters, + verbose=1, + validation_data=(X_test, Y_test), + # callbacks=[csv_logger, reduce_lr, candleRemoteMonitor, timeoutMonitor]) # , abstention_cbk]) + callbacks=[csv_logger, reduce_lr, candleRemoteMonitor, timeoutMonitor, abstention_cbk]) + + # ret, alpha = adjust_alpha(gParameters, X_test, Y_test, val_labels, model, alpha, [nb_classes+1]) + + score = 
model.evaluate(X_test, Y_test, verbose=0) + + if len(polluted_inds) > 0: + y_pred = model.predict(X_test) + abstain_inds = [] + for i in range(y_pred.shape[0]): + if np.argmax(y_pred[i]) == nb_classes: + abstain_inds.append(i) + + # Cluster indices and polluted indices are wrt to entire train + test dataset + # whereas y_pred only contains test dataset so add offset for correct indexing + offset_testset = Y_train.shape[0] + abstain_inds=[i+offset_testset for i in abstain_inds] + + polluted_percentage = c = np.sum([el in polluted_inds for el in abstain_inds])/np.max([len(abstain_inds),1]) + print("Percentage of abstained samples that were polluted {}".format(polluted_percentage)) + + cluster_inds_test = list(filter(lambda cluster_inds: cluster_inds >= offset_testset, cluster_inds)) + cluster_inds_test_abstain = [el in abstain_inds for el in cluster_inds_test] + cluster_percentage = c = np.sum(cluster_inds_test_abstain)/len(cluster_inds_test) + print("Percentage of cluster (in test set) that was abstained {}".format(cluster_percentage)) + + unabstain_inds = [] + for i in range(y_pred.shape[0]): + if np.argmax(y_pred[i]) != nb_classes and (i+offset_testset in cluster_inds_test): + unabstain_inds.append(i) + # Make sure number of unabstained indices in cluster test set plus number of abstainsed indices in cluster test set + # equals number of indices in cluster in the test set + assert(len(unabstain_inds)+np.sum(cluster_inds_test_abstain) == len(cluster_inds_test)) + score_cluster = 1 if len(unabstain_inds)==0 else model.evaluate(X_test[unabstain_inds], Y_test[unabstain_inds])[1] + print("Accuracy of unabastained cluster {}".format(score_cluster)) + + pickle.dump({'Abs polluted': polluted_percentage, 'Abs val cluster': cluster_percentage, 'Abs val acc': score_cluster}, open("{}/cluster_trace.pkl".format(output_dir), "wb")) + + alpha_trace = open(output_dir + "/alpha_trace", "w+") + for alpha in abstention_cbk.alphavalues: + alpha_trace.write(str(alpha) + '\n') + 
alpha_trace.close() + + if False: + print('Test score:', score[0]) + print('Test accuracy:', score[1]) + # serialize model to JSON + model_json = model.to_json() + with open("{}/{}.model.json".format(output_dir, model_name), "w") as json_file: + json_file.write(model_json) + + # serialize model to YAML + model_yaml = model.to_yaml() + with open("{}/{}.model.yaml".format(output_dir, model_name), "w") as yaml_file: + yaml_file.write(model_yaml) + + # serialize weights to HDF5 + model.save_weights("{}/{}.weights.h5".format(output_dir, model_name)) + print("Saved model to disk") + + # load json and create model + json_file = open('{}/{}.model.json'.format(output_dir, model_name), 'r') + loaded_model_json = json_file.read() + json_file.close() + loaded_model_json = model_from_json(loaded_model_json) + + # load yaml and create model + yaml_file = open('{}/{}.model.yaml'.format(output_dir, model_name), 'r') + loaded_model_yaml = yaml_file.read() + yaml_file.close() + loaded_model_yaml = model_from_yaml(loaded_model_yaml) + + # load weights into new model + loaded_model_json.load_weights('{}/{}.weights.h5'.format(output_dir, model_name)) + print("Loaded json model from disk") + + # evaluate json loaded model on test data + loaded_model_json.compile(loss=gParameters['loss'], + optimizer=gParameters['optimizer'], + metrics=[gParameters['metrics']]) + score_json = loaded_model_json.evaluate(X_test, Y_test, verbose=0) + + print('json Test score:', score_json[0]) + print('json Test accuracy:', score_json[1]) + + print("json %s: %.2f%%" % (loaded_model_json.metrics_names[1], score_json[1] * 100)) + + # load weights into new model + loaded_model_yaml.load_weights('{}/{}.weights.h5'.format(output_dir, model_name)) + print("Loaded yaml model from disk") + + # evaluate loaded model on test data + loaded_model_yaml.compile(loss=gParameters['loss'], + optimizer=gParameters['optimizer'], + metrics=[gParameters['metrics']]) + score_yaml = loaded_model_yaml.evaluate(X_test, Y_test, 
verbose=0) + + print('yaml Test score:', score_yaml[0]) + print('yaml Test accuracy:', score_yaml[1]) + + print("yaml %s: %.2f%%" % (loaded_model_yaml.metrics_names[1], score_yaml[1] * 100)) + + return history + + +def main(): + gParameters = initialize_parameters() + run(gParameters) + + +if __name__ == '__main__': + main() + try: + K.clear_session() + except AttributeError: # theano does not have this function + pass diff --git a/Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py b/Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py new file mode 100644 index 00000000..8d1227f5 --- /dev/null +++ b/Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py @@ -0,0 +1,318 @@ +from __future__ import print_function + +import pandas as pd +import numpy as np +import os +import pickle + +from tensorflow.keras import backend as K + +from tensorflow.keras.layers import Dense, Dropout, Activation, Conv1D, MaxPooling1D, Flatten, LocallyConnected1D +from tensorflow.keras.models import Sequential, model_from_json, model_from_yaml +from tensorflow.keras.utils import to_categorical +from tensorflow.keras.callbacks import CSVLogger, ReduceLROnPlateau + +from sklearn.preprocessing import MaxAbsScaler + +import nt3 as bmk +import candle + + +def initialize_parameters(default_model='nt3_default_model.txt'): + + # Build benchmark object + nt3Bmk = bmk.BenchmarkNT3( + bmk.file_path, + default_model, + 'keras', + prog='nt3_baseline', + desc='1D CNN to classify RNA sequence data in normal or tumor classes') + + # Initialize parameters + gParameters = candle.finalize_parameters(nt3Bmk) + + return gParameters + +def load_data_pickle(path, gParameters): + # Rewrite this function to handle pickle files instead + print("Loading data...") + data = pickle.load(open(path, 'rb')) + X=data[0] + y=data[1] + polluted_inds = data[2] + cluster_inds = data[3] + size = X.shape[0] + X_train = X[0:(int)(0.8*size)] + X_test = X[(int)(0.8*size):] + Y_train = y[0:(int)(0.8*size)] + Y_test = y[(int)(0.8*size):] 
def run(gParameters):
    """Train the NT3 baseline 1D-CNN on a pickled dataset and save model + data.

    The dataset is read with ``load_data_pickle`` from
    ``gParameters['train_data']``.  The network is assembled from the 'conv'
    (flat list of ``filters, kernel_len, stride`` triples), 'pool', 'dense',
    'dropout', 'activation', 'out_activation' and 'classes' parameters,
    trained with the CANDLE checkpoint/monitoring callbacks, and finally the
    trained model and the exact arrays used for training are written to
    ``output_dir``.

    Args:
        gParameters: CANDLE parameter dict.

    Returns:
        The Keras ``History`` object returned by ``model.fit``.
    """
    output_dir = gParameters['output_dir']
    model_name = gParameters['model_name']

    initial_epoch = 0

    # 'train_data' is expected to point at a pickle file in this workflow.
    X_train, Y_train, X_test, Y_test = load_data_pickle(gParameters['train_data'], gParameters)

    print('X_train shape:', X_train.shape)
    print('X_test shape:', X_test.shape)

    print('Y_train shape:', Y_train.shape)
    print('Y_test shape:', Y_test.shape)

    x_train_len = X_train.shape[1]

    # this reshaping is critical for the Conv1D to work
    X_train = np.expand_dims(X_train, axis=2)
    X_test = np.expand_dims(X_test, axis=2)

    print('X_train shape:', X_train.shape)
    print('X_test shape:', X_test.shape)

    conv = gParameters['conv']
    n_conv_layers = len(range(0, len(conv), 3))

    # Normalise 'pool' once.  The original called list() on any non-list
    # value, which raises TypeError for a scalar pool size; a scalar is now
    # applied to every convolutional layer.
    pool = gParameters['pool']
    pool_list = None
    if pool:
        if isinstance(pool, (list, tuple)):
            pool_list = list(pool)
        else:
            pool_list = [pool] * n_conv_layers

    model = Sequential()

    for i in range(0, len(conv), 3):
        filters = conv[i]
        filter_len = conv[i + 1]
        stride = conv[i + 2]
        layer_idx = i // 3
        print(layer_idx, filters, filter_len, stride)

        # A non-positive entry terminates the convolutional stack.
        if filters <= 0 or filter_len <= 0 or stride <= 0:
            break
        if 'locally_connected' in gParameters:
            model.add(LocallyConnected1D(filters, filter_len, strides=stride, padding='valid',
                                         input_shape=(x_train_len, 1)))
        elif i == 0:
            # input layer
            model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid',
                             input_shape=(x_train_len, 1)))
        else:
            model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid'))
        model.add(Activation(gParameters['activation']))
        if pool_list is not None:
            model.add(MaxPooling1D(pool_size=pool_list[layer_idx]))

    model.add(Flatten())

    for layer in gParameters['dense']:
        if layer:
            model.add(Dense(layer))
            model.add(Activation(gParameters['activation']))
            if gParameters['dropout']:
                model.add(Dropout(gParameters['dropout']))
    model.add(Dense(gParameters['classes']))
    model.add(Activation(gParameters['out_activation']))

    # Resume from a checkpoint when CANDLE finds one.
    J = candle.restart(gParameters, model)
    if J is not None:
        initial_epoch = J['epoch']
        best_metric_last = J['best_metric_last']
        gParameters['ckpt_best_metric_last'] = best_metric_last
        print('initial_epoch: %i' % initial_epoch)

    ckpt = candle.CandleCheckpointCallback(gParameters,
                                           verbose=False)

    kerasDefaults = candle.keras_default_config()

    # Define optimizer
    optimizer = candle.build_optimizer(gParameters['optimizer'],
                                       gParameters['learning_rate'],
                                       kerasDefaults)

    model.summary()
    model.compile(loss=gParameters['loss'],
                  optimizer=optimizer,
                  metrics=[gParameters['metrics']])

    if not os.path.exists(output_dir):
        os.makedirs(output_dir)

    # calculate trainable and non-trainable params
    gParameters.update(candle.compute_trainable_params(model))

    # set up a bunch of callbacks to do work during model training..
    csv_logger = CSVLogger('{}/training.log'.format(output_dir))
    # 'min_delta' is the current name of the deprecated 'epsilon' argument.
    reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=1,
                                  mode='auto', min_delta=0.0001, cooldown=0, min_lr=0)
    candleRemoteMonitor = candle.CandleRemoteMonitor(params=gParameters)
    timeoutMonitor = candle.TerminateOnTimeOut(gParameters['timeout'])

    history = model.fit(X_train, Y_train,
                        batch_size=gParameters['batch_size'],
                        epochs=gParameters['epochs'],
                        initial_epoch=initial_epoch,
                        verbose=1,
                        validation_data=(X_test, Y_test),
                        callbacks=[csv_logger, reduce_lr, candleRemoteMonitor, timeoutMonitor,
                                   ckpt])

    # The former `if False:` JSON/YAML round-trip block was unreachable and has
    # been dropped, together with the evaluate() call that only fed it.

    # Save the trained model.  The original called model.save(path) before
    # `path` was assigned (its only earlier definition was commented out),
    # which raises UnboundLocalError.
    model_path = '{}/{}.autosave.model.h5'.format(output_dir, model_name)
    model.save(model_path)

    # Save the exact arrays used for training/validation.
    # NOTE(review): this is a pickle despite the '.h5' suffix; the name is kept
    # so existing consumers keep working - confirm before renaming.
    data_path = '{}/{}.autosave.data.h5'.format(output_dir, model_name)
    with open(data_path, "wb") as data_file:
        pickle.dump([X_train, Y_train, X_test, Y_test], data_file)
    print(data_path)
    return history
cf_sweep_0902/EXP000/RUN000/training.log ${filename}_training_0902.log +done diff --git a/Pilot1/NT3/nt3_cf/analyze.ipynb b/Pilot1/NT3/nt3_cf/analyze.ipynb new file mode 100644 index 00000000..99d570c1 --- /dev/null +++ b/Pilot1/NT3/nt3_cf/analyze.ipynb @@ -0,0 +1,655 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pickle\n", + "import os\n", + "import matplotlib.pyplot as plt\n", + "from sklearn.cluster import KMeans\n", + "from sklearn.decomposition import PCA\n", + "from sklearn.manifold import TSNE" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "parent = '/Users/shah38/Desktop/xai-geom/nt3/'\n", + "# directory = parent + 'pickle_summit_35'\n", + "# counterfactuals = []\n", + "# count = 0\n", + "# for filename in os.listdir(directory):\n", + "# if filename.startswith(\"save\"):\n", + "# count+=1\n", + "# d = pickle.load(open(os.path.join(directory, filename), 'rb'))\n", + "# counterfactuals.append(d)" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "test = [item for sublist in counterfactuals for item in sublist]" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1160\n" + ] + } + ], + "source": [ + "print(len(test))" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "29\n" + ] + } + ], + "source": [ + "print(count)" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "ename": "FileNotFoundError", + "evalue": "[Errno 2] No such file or directory: 'nt3.autosave.data.pkl'", + "output_type": "error", + 
"traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mFileNotFoundError\u001B[0m Traceback (most recent call last)", + "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mdata\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0mpickle\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mopen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"nt3.autosave.data.pkl\"\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m'rb'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdata\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'nt3.autosave.data.pkl'" + ] + } + ], + "source": [ + "data = pickle.load(open(\"nt3.autosave.data.pkl\",'rb'))\n", + "print(data[0].shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "pickle.dump(test, open(\"complete_save.pkl\", 'wb'))" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "# Plot # of positive/negative indices per threshold value\n", + "# Pick a threshold\n", + "# Find genes that overlap the most\n", + "num_pos = []\n", + "num_neg = []\n", + "#threshold_values = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]\n", + "threshold_values = [0.9]#, 0.925, 0.95, 0.975, 1.0]\n", + "for t in threshold_values:\n", + " thresholds = pickle.load(open('{}threshold_{}.pkl'.format(parent,t), 'rb'))\n", + " pos = thresholds['positive threshold indices']\n", + " 
num_pos.append([pos[i][0].shape[0] for i in range(len(pos))])\n", + " neg = thresholds['negative threshold indices']\n", + " num_neg.append([neg[i][0].shape[0] for i in range(len(neg))])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "pos = np.arange(len(num_pos)) + 1\n", + "total = [np.array(num_pos[i]) + np.array(num_neg[i]) for i in range(len(num_pos)) ]\n", + "bp = ax.boxplot(total, sym='k+', positions=pos)\n", + "\n", + "ax.set_xlabel('threshold value')\n", + "ax.set_ylabel('# indices')\n", + "#ax.set_xticks(np.arange(0,1.1,0.1))\n", + "plt.setp(bp['whiskers'], color='k', linestyle='-')\n", + "plt.setp(bp['fliers'], markersize=3.0)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "pos = np.arange(len(num_pos)) + 1\n", + "bp = ax.boxplot(num_pos, sym='k+', positions=pos)\n", + "\n", + "ax.set_xlabel('threshold value')\n", + "ax.set_xticklabels([0.9,0.925,0.95,0.975,1.0])\n", + "ax.set_ylabel('# indices')\n", + "plt.setp(bp['whiskers'], color='k', linestyle='-')\n", + "plt.setp(bp['fliers'], markersize=3.0)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOnElEQVR4nO3dXYyc1X3H8e+vONAqiWIDW8uynS5pLFXcBKwVdZUoakEh4FQ1lRJEVRWLWvINkYjSqnWai6ZSL6BSQ4tUIbkF1URpCMqLsAJt4hqiqBcQlgTMWykLBWHLYCcQkihKWpJ/L+a4mZhd7+w7e/b7kVZznnPOzJz/PuOfn33mLVWFJKkvv7TSC5AkLT7DXZI6ZLhLUocMd0nqkOEuSR1at9ILADj//PNrfHx8pZchSavKww8//J2qGptu7E0R7uPj40xOTq70MiRpVUnywkxjnpaRpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOrfpwH993D+P77lnpZUjSm8qqD3dJ0hsZ7pLUoZHCPcnzSR5L8kiSydZ3bpJDSZ5plxtaf5LckmQqyZEk25eyAEnSG83lyP13quqiqppo2/uAw1W1DTjctgGuBLa1n73ArYu1WEnSaBZyWmYXcKC1DwBXDfXfUQMPAOuTbFrA/UiS5mjUcC/ga0keTrK39W2squOt/RKwsbU3Ay8OXfdo65MkLZNRv6zjfVV1LMmvAoeS/OfwYFVVkprLHbf/JPYCvPOd75zLVSVJsxjpyL2qjrXLE8CXgUuAl0+dbmmXJ9r0Y8DWoatvaX2n3+b+qpqoqomxsWm/JUqSNE+zhnuStyZ5+6k2cDnwOHAQ2N2m7Qbubu2DwLXtVTM7gNeGTt9IkpbBKKdlNgJfTnJq/r9U1b8leQi4K8ke4AXg6jb/XmAnMAX8CLhu0VctSTqjWcO9qp4D3jNN/3eBy6bpL+D6RVmdJGlefIeqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdGjnck5yV5NtJvtK2L0jyYJKpJJ9PcnbrP6dtT7Xx8aVZuiRpJnM5cr8BeGpo+ybg5qp6N/AqsKf17wFebf03t3mSpGU0Urgn2QJ8CPinth3gUuALbcoB4KrW3tW2aeOXtfmSpGUy6pH73wF/BvysbZ8HfK+qXm/bR4HNrb0ZeBGgjb/W5kuSlsms4Z7kd4ETVfXwYt5xkr1JJpNMnjx5cjFvWpLWvFGO3N8L/F6S54E7GZyO+XtgfZJ1bc4W4FhrHwO2ArTxdwDfPf1Gq2p/VU1U1cTY2NiCipAk/aJZw72qPlFVW6pqHLgGuK+q/hC4H/hwm7YbuLu1D7Zt2vh9VVWLumpJ0hkt5HXufw58PMkUg3Pqt7X+24DzWv/HgX0LW6Ikaa7WzT7l56rq68DXW/s54JJp5vwY+MgirE2SNE++Q1WSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA7NGu5JfjnJN5M8muSJJH/V+i9I8mCSqSSfT3J26z+nbU+18fGlLUGSdLpRjtx/AlxaVe8BLgKuSLIDuAm4uareDbwK7Gnz9wCvtv6b2zxJ0jKaNdxr4Idt8y3tp4BLgS+0/gPAVa29q23Txi9LkkVbsSRpViOdc09yVpJHg
BPAIeBZ4HtV9XqbchTY3NqbgRcB2vhrwHmLuWhJ0pmNFO5V9dOqugjYAlwC/MZC7zjJ3iSTSSZPnjy50JuTJA2Z06tlqup7wP3AbwHrk6xrQ1uAY619DNgK0MbfAXx3mtvaX1UTVTUxNjY2z+VLkqYzyqtlxpKsb+1fAT4APMUg5D/cpu0G7m7tg22bNn5fVdViLlqSdGbrZp/CJuBAkrMY/GdwV1V9JcmTwJ1J/hr4NnBbm38b8JkkU8ArwDVLsG5J0hnMGu5VdQS4eJr+5xicfz+9/8fARxZldZKkefEdqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHeom3Mf33bPSS5CkN41uwl2S9HOGuyR1aNZwT7I1yf1JnkzyRJIbWv+5SQ4leaZdbmj9SXJLkqkkR5JsX+oiJEm/aJQj99eBP6mqC4EdwPVJLgT2AYerahtwuG0DXAlsaz97gVsXfdWSpDOaNdyr6nhVfau1fwA8BWwGdgEH2rQDwFWtvQu4owYeANYn2bToK5ckzWhO59yTjAMXAw8CG6vqeBt6CdjY2puBF4eudrT1nX5be5NMJpk8efLkHJctSTqTkcM9yduALwIfq6rvD49VVQE1lzuuqv1VNVFVE2NjY3O5qiRpFiOFe5K3MAj2z1bVl1r3y6dOt7TLE63/GLB16OpbWp8kaZmM8mqZALcBT1XVp4eGDgK7W3s3cPdQ/7XtVTM7gNeGTt9IkpbBuhHmvBf4I+CxJI+0vr8AbgTuSrIHeAG4uo3dC+wEpoAfAdct6oolSbOaNdyr6j+AzDB82TTzC7h+geuSJC2A71CVpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkd6ircx/fds9JLkKQ3ha7CXZI0sG6lF7AUznQE//yNH1rGlUjSyvDIXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVo1nBPcnuSE0keH+o7N8mhJM+0yw2tP0luSTKV5EiS7Uu5eEnS9EY5cv9n4IrT+vYBh6tqG3C4bQNcCWxrP3uBWxdnmZKkuZg13KvqG8Arp3XvAg609gHgqqH+O2rgAWB9kk2LtVhJ0mjme859Y1Udb+2XgI2tvRl4cWje0db3Bkn2JplMMnny5Ml5LkOSNJ0FP6FaVQXUPK63v6omqmpibGxsocuQJA2Zb7i/fOp0S7s80fqPAVuH5m1pfZKkZTTfcD8I7G7t3cDdQ/3XtlfN7ABeGzp9I0laJrN+KmSSzwG/DZyf5Cjwl8CNwF1J9gAvAFe36fcCO4Ep4EfAdUuwZknSLGYN96r6gxmGLptmbgHXL3RRkqSF8R2qktQhw12SOmS4S1KHDHdJ6pDhLkkd6i7cz/Tl2JK0VnQX7pIkw12SumS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA7N+pG/venhTU7P3/ihlV6CpDc5j9wlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR1aknBPckWSp5NMJdm3FPchSZrZon8TU5KzgH8APgAcBR5KcrCqnlzs+1qrevg2qblai98+5X7WQizF1+xdAkxV1XMASe4EdgGGu+ZtLQbdWrQW9/NS/Ye2FOG+GXhxaPso8JunT0qyF9jbNn+Y5Ol53t/5wHfmed3VyprXBmteA3LTgmr+tZkGVuwLsqtqP7B/obeTZLKqJhZhSauGNa8N1rw2LFXNS
/GE6jFg69D2ltYnSVomSxHuDwHbklyQ5GzgGuDgEtyPJGkGi35apqpeT/JR4KvAWcDtVfXEYt/PkAWf2lmFrHltsOa1YUlqTlUtxe1KklaQ71CVpA4Z7pLUoVUd7j1/zEGS55M8luSRJJOt79wkh5I80y43tP4kuaX9Ho4k2b6yqx9NktuTnEjy+FDfnGtMsrvNfybJ7pWoZRQz1PupJMfafn4kyc6hsU+0ep9O8sGh/lXzuE+yNcn9SZ5M8kSSG1p/z/t5ppqXd19X1ar8YfBk7bPAu4CzgUeBC1d6XYtY3/PA+af1/Q2wr7X3ATe19k7gX4EAO4AHV3r9I9b4fmA78Ph8awTOBZ5rlxtae8NK1zaHej8F/Ok0cy9sj+lzgAvaY/2s1fa4BzYB21v77cB/tdp63s8z1bys+3o1H7n//8ccVNX/AKc+5qBnu4ADrX0AuGqo/44aeABYn2TTSixwLqrqG8Arp3XPtcYPAoeq6pWqehU4BFyx9Kufuxnqncku4M6q+klV/TcwxeAxv6oe91V1vKq+1do/AJ5i8C72nvfzTDXPZEn29WoO9+k+5uBMv8DVpoCvJXm4fVQDwMaqOt7aLwEbW7un38Vca+yh9o+2UxC3nzo9QYf1JhkHLgYeZI3s59NqhmXc16s53Hv3vqraDlwJXJ/k/cODNfh7ruvXsa6FGoFbgV8HLgKOA3+7sstZGkneBnwR+FhVfX94rNf9PE3Ny7qvV3O4d/0xB1V1rF2eAL7M4E+0l0+dbmmXJ9r0nn4Xc61xVddeVS9X1U+r6mfAPzLYz9BRvUnewiDkPltVX2rdXe/n6Wpe7n29msO92485SPLWJG8/1QYuBx5nUN+pVwnsBu5u7YPAte2VBjuA14b+5F1t5lrjV4HLk2xof+Ze3vpWhdOeG/l9BvsZBvVek+ScJBcA24Bvssoe90kC3AY8VVWfHhrqdj/PVPOy7+uVfmZ5gc9K72TwTPSzwCdXej2LWNe7GDwz/ijwxKnagPOAw8AzwL8D57b+MPiClGeBx4CJla5hxDo/x+DP0/9lcD5xz3xqBP6YwZNQU8B1K13XHOv9TKvnSPuHu2lo/idbvU8DVw71r5rHPfA+BqdcjgCPtJ+dne/nmWpe1n3txw9IUodW82kZSdIMDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUof8Dhz0GWCovLmwAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(num_pos[0], bins=[0,10,20,30,40,100,500,1000,1500,2000,2500])\n", + "num_pos_9 = np.array(num_pos[0])\n", + "indices_pos_9 = np.argwhere((num_pos_9 <= 20) & (num_pos_9 > 10))" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAR50lEQVR4nO3dbYxcV33H8e+vTghVoU1CtpZrW90Arqq0Uk20TVOBKpoISExVBwlQUFUsGsmtFCRQn3Doi1KpkUJVSIvURjJNiqkoIeJBsUgomBCEeEHChhoTJ02zgFFsmXhLQgChpk3498Ucl8Hsw+zOPrDH3480mnPPOXfmHN/1b2fP3JmbqkKS1JefWu8BSJJWnuEuSR0y3CWpQ4a7JHXIcJekDp2z3gMAuOiii2pycnK9hyFJG8oDDzzwX1U1MVfbT0S4T05OMj09vd7DkKQNJck35mtzWUaSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjq04cN9ct9dTO67a72HIUk/UTZ8uEuSfpzhLkkdMtwlqUMjh3uSTUn+PcnH2/bFSe5LMpPkQ0me0+rPa9szrX1ydYYuSZrPUl65vwV4eGj7ncDNVfVi4EngulZ/HfBkq7+59ZMkraGRwj3JNuDVwD+17QBXAB9uXQ4A17Ty7rZNa7+y9ZckrZFRX7n/HfDnwA/a9guAb1fVM237OLC1lbcCjwG09qda/x+RZG+S6STTs7Ozyxy+JGkui4Z7kt8BTlXVAyv5xFW1v6qmqmpqYmLOq0RJkpZplMvsvRT43SS7gOcCPwv8PXB+knPaq/NtwInW/wSwHTie5Bzg54BvrfjIJUnzWvSVe1XdUFXbqmoSuBb4TFX9HnAv8NrWbQ9wZysfbNu09s9UVa3oqCVJCxrnPPe3AX+cZIbBmvqtrf5W4AWt/o+BfeMNUZK0VKMsy/y/qvos8NlW/hpw2Rx9/ht43QqMTZK0TH5CVZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUoVEukP3cJPcn+XKSo0n+qtW/L8nXkxxut52tPknek2QmyZEkl672JCRJP2qUKzE9DVxRVd9Lci7w+SSfaG1/VlUfPqP/1cCOdvsN4JZ2L0laI6NcILuq6ntt89x2W+iC17uB97f9vgCcn2TL+EOVJI1qpDX3JJuSHAZOAYeq6r7WdGNberk5yXmtbivw2NDux1vdmY+5N8l0kunZ2dkxpiBJOtNI4V5Vz1bVTmAbcFmSXwVuAH4Z+HXgQuBtS3niqtpfVVNVNTUxMbHEYUuSFrKks2Wq6tvAvcBVVXWyLb08DfwzcFnrdgLYPrTbtlYnSVojo5wtM5Hk/Fb+aeAVwH+cXkdPEuAa4MG2y0Hgje2smcuBp6rq5KqMXpI0p1HOltkCHEiyicEvgzuq6uNJPpNkAghwGPij1v9uYBcwA3wfeNPKD1uStJBFw72qjgAvmaP+inn6F3D9+EOTJC2Xn1CVpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVolMvsPTfJ/Um+nORokr9q9RcnuS/JTJIPJXlOqz+vbc+09snVnYIk6UyjvHJ/Griiqn4N2Alc1a6N+k7g5qp6MfAkcF3rfx3wZKu/ufWTJK2hRcO9Br7XNs9ttwKuAD7c6g8wuEg2wO62TWu/sl1EW5K0RkZac0+yKclh4BRwCPgq8O2qeqZ1OQ5sbeWtwGMArf0p4AVzPObeJNNJpmdnZ8ebhSTpR4wU7lX1bFXtBLYBlwG/PO4TV9X+qpqqqqmJiYlxH06SNGRJZ8tU1beBe4HfBM5Pck5r2gacaOUTwHaA1v5zw
LdWZLSSpJGMcrbMRJLzW/mngVcADzMI+de2bnuAO1v5YNumtX+mqmolBy1JWtg5i3dhC3AgySYGvwzuqKqPJ3kIuD3JXwP/Dtza+t8K/EuSGeAJ4NpVGLckaQGLhntVHQFeMkf91xisv59Z/9/A61ZkdJKkZfETqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDo1ymb3tSe5N8lCSo0ne0urfkeREksPttmtonxuSzCR5JMmrVnMCkqQfN8pl9p4B/qSqvpTk+cADSQ61tpur6m+HOye5hMGl9X4F+AXg00l+qaqeXcmBS5Lmt+gr96o6WVVfauXvMrg49tYFdtkN3F5VT1fV14EZ5rgcnyRp9SxpzT3JJIPrqd7Xqt6c5EiS25Jc0Oq2Ao8N7XacOX4ZJNmbZDrJ9Ozs7JIHfqbJfXeN/RiS1IuRwz3J84CPAG+tqu8AtwAvAnYCJ4F3LeWJq2p/VU1V1dTExMRSdpUkLWKkcE9yLoNg/0BVfRSgqh6vqmer6gfAe/nh0ssJYPvQ7ttanSRpjYxytkyAW4GHq+rdQ/Vbhrq9BniwlQ8C1yY5L8nFwA7g/pUbsiRpMaOcLfNS4PeBryQ53OreDrwhyU6ggGPAHwJU1dEkdwAPMTjT5nrPlJGktbVouFfV54HM0XT3AvvcCNw4xrgkSWPwE6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA6Ncpm97UnuTfJQkqNJ3tLqL0xyKMmj7f6CVp8k70kyk+RIkktXexKSpB81yiv3Z4A/qapLgMuB65NcAuwD7qmqHcA9bRvgagbXTd0B7AVuWfFRS5IWtGi4V9XJqvpSK38XeBjYCuwGDrRuB4BrWnk38P4a+AJw/hkX05YkrbIlrbknmQReAtwHbK6qk63pm8DmVt4KPDa02/FWd+Zj7U0ynWR6dnZ2icOWJC1k5HBP8jzgI8Bbq+o7w21VVUAt5Ymran9VTVXV1MTExFJ2lSQtYqRwT3Iug2D/QFV9tFU/fnq5pd2favUngO1Du29rdZKkNTLK2TIBbgUerqp3DzUdBPa08h7gzqH6N7azZi4HnhpavpEkrYFzRujzUuD3ga8kOdzq3g7cBNyR5DrgG8DrW9vdwC5gBvg+8KYVHbEkaVGLhntVfR7IPM1XztG/gOvHHJckaQx+QlWSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KFRLrN3W5JTSR4cqntHkhNJDrfbrqG2G5LMJHkkyatWa+CSpPmN8sr9fcBVc9TfXFU72+1ugCSXANcCv9L2+cckm1ZqsJKk0Swa7lX1OeCJER9vN3B7VT1dVV9ncB3Vy8YYnyRpGcZZc39zkiNt2eaCVrcVeGyoz/FW92OS7E0ynWR6dnZ2jGH80OS+u1bkcSRpo1tuuN8CvAjYCZwE3rXUB6iq/VU1VVVTExMTyxyGJGkuywr3qnq8qp6tqh8A7+WHSy8ngO1DXbe1OknSGlpWuCfZMrT5GuD0mTQHgWuTnJfkYmAHcP94Q5QkLdU5i3VI8kHg5cBFSY4Dfwm8PMlOoIBjwB8CVNXRJHcADwHPANdX1bOrM/T5Lbb2fuymV6/RSCRpfSwa7lX1hjmqb12g/43AjeMMSpI0Hj+hKkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nq0KLhnuS2JKeSPDhUd2GSQ0kebfcXtPokeU+SmSRHkly6moOXJM1tlFfu7wOuOqNuH3BPVe0A7mnbAFczuG7qD
mAvcMvKDFOStBSLhntVfQ544ozq3cCBVj4AXDNU//4a+AJw/hkX05YkrYHlrrlvrqqTrfxNYHMrbwUeG+p3vNX9mCR7k0wnmZ6dnV3mMCRJcxn7DdWqKqCWsd/+qpqqqqmJiYlxhyFJGrLccH/89HJLuz/V6k8A24f6bWt1a2Zy311r+XSS9BNpueF+ENjTynuAO4fq39jOmrkceGpo+UaStEbOWaxDkg8CLwcuSnIc+EvgJuCOJNcB3wBe37rfDewCZoDvA29ahTFLkhaxaLhX1Rvmabpyjr4FXD/uoCRJ4/ETqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVo0e9z79FGvxTfsZtevd5DkPQTbqxwT3IM+C7wLPBMVU0luRD4EDAJHANeX1VPjjdMSdJSrMSyzG9X1c6qmmrb+4B7qmoHcE/bliStodVYc98NHGjlA8A1q/AckqQFjBvuBXwqyQNJ9ra6zVV1spW/CWyea8cke5NMJ5menZ0dcxiSpGHjvqH6sqo6keTngUNJ/mO4saoqSc21Y1XtB/YDTE1NzdlHkrQ8Y71yr6oT7f4U8DHgMuDxJFsA2v2pcQcpSVqaZYd7kp9J8vzTZeCVwIPAQWBP67YHuHPcQUqSlmacZZnNwMeSnH6cf62qf0vyReCOJNcB3wBeP/4wJUlLsexwr6qvAb82R/23gCvHGZQkaTx+/YAkdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUPjXiB7XkmuAv4e2AT8U1XdtFrPdbaZ3HfXeg9hzR276dXrPQRpQ1mVcE+yCfgH4BXAceCLSQ5W1UOr8Xzq39n4C+1s5C/xlbNar9wvA2bapfhIcjuwGzDcJc3rbPwlvlq/0FYr3LcCjw1tHwd+Y7hDkr3A3rb5vSSPLPO5LgL+a5n7blTO+ezgnM8CeedYc/7F+RpWbc19MVW1H9g/7uMkma6qqRUY0obhnM8OzvnssFpzXq2zZU4A24e2t7U6SdIaWK1w/yKwI8nFSZ4DXAscXKXnkiSdYVWWZarqmSRvBj7J4FTI26rq6Go8FyuwtLMBOeezg3M+O6zKnFNVq/G4kqR15CdUJalDhrskdWhDh3uSq5I8kmQmyb71Hs9KSnIsyVeSHE4y3eouTHIoyaPt/oJWnyTvaf8OR5Jcur6jH02S25KcSvLgUN2S55hkT+v/aJI96zGXUc0z53ckOdGO9eEku4babmhzfiTJq4bqN8TPfpLtSe5N8lCSo0ne0uq7Pc4LzHltj3NVbcgbgzdqvwq8EHgO8GXgkvUe1wrO7xhw0Rl1fwPsa+V9wDtbeRfwCSDA5cB96z3+Eef4W8ClwIPLnSNwIfC1dn9BK1+w3nNb4pzfAfzpHH0vaT/X5wEXt5/3TRvpZx/YAlzays8H/rPNq9vjvMCc1/Q4b+RX7v//FQdV9T/A6a846Nlu4EArHwCuGap/fw18ATg/yZb1GOBSVNXngCfOqF7qHF8FHKqqJ6rqSeAQcNXqj3555pnzfHYDt1fV01X1dWCGwc/9hvnZr6qTVfWlVv4u8DCDT7B3e5wXmPN8VuU4b+Rwn+srDhb6B9xoCvhUkgfaVzUAbK6qk638TWBzK/f0b7HUOfYy9ze3ZYjbTi9R0Nmck0wCLwHu4yw5zmfMGdbwOG/kcO/dy6rqUuBq4PokvzXcWIO/57o+j/VsmGNzC/AiYCdwEnjX+g5n5SV5HvAR4K1V9Z3htl6P8xxzXtPjvJHDveuvOKiqE+3+FPAxBn+iPX56uaXdn2rde/q3WOocN/zcq+rxqnq2qn4AvJfBsYZO5pzkXAYh94Gq+mir7vo4zzXntT7OGzncu/2KgyQ/k+T5p8vAK4EHGczv9FkCe4A7W/kg8MZ2p
sHlwFNDf/JuNEud4yeBVya5oP2Z+8pWt2Gc8f7IaxgcaxjM+dok5yW5GNgB3M8G+tlPEuBW4OGqevdQU7fHeb45r/lxXu93lsd8V3oXg3eivwr8xXqPZwXn9UIG74x/GTh6em7AC4B7gEeBTwMXtvowuDjKV4GvAFPrPYcR5/lBBn+e/i+D9cTrljNH4A8YvAk1A7xpvee1jDn/S5vTkfafd8tQ/79oc34EuHqofkP87AMvY7DkcgQ43G67ej7OC8x5TY+zXz8gSR3ayMsykqR5GO6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ/8H2vJR9Ye8H/sAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "plt.hist(num_neg[0], bins=[0,10,20,30,40,100,500,1000,1500,2000,2500])\n", + "num_neg_9 = np.array(num_neg[0])\n", + "indices_neg_9 = np.argwhere((num_neg_9 <= 20) & (num_neg_9 > 10))" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[642, 1027, 258, 134, 394, 1046, 279, 536, 791, 923, 798, 676, 682, 811, 556, 812, 432, 694, 55, 314, 957, 1093, 197, 591, 93, 735, 357, 1017, 634, 638, 767]\n" + ] + } + ], + "source": [ + "# These are the counterfactuals that have between 10 and 20 pos/neg perturbations > 0.9*max\n", + "overlap = list(set(np.squeeze(indices_pos_9)).intersection(set( np.squeeze(indices_neg_9))))\n", + "print(overlap)" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "NameError", + "evalue": "name 'pickle' is not defined", + "output_type": "error", + "traceback": [ + "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", + "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", + "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mthresholds_9\u001B[0m \u001B[0;34m=\u001B[0m 
\u001B[0mpickle\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mopen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'{}threshold_{}.pkl'\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mparent\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;36m0.9\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m'rb'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0mgenes_pos\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0mgenes_neg\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mperturb_vector\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mcf_class\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", + "\u001B[0;31mNameError\u001B[0m: name 'pickle' is not defined" + ] + } + ], + "source": [ + "thresholds_9 = pickle.load(open('{}threshold_{}.pkl'.format(parent,0.9), 'rb'))\n", + "genes_pos = []\n", + "genes_neg = []\n", + "perturb_vector=[]\n", + "cf_class = []\n", + "for i in overlap:\n", + " genes_pos.append(thresholds_9['positive threshold indices'][i])\n", + " genes_neg.append(thresholds_9['negative threshold indices'][i])\n", + " perturb_vector.append(thresholds_9['perturbation vector'][i])\n", + " cf_class.append(thresholds_9['counterfactual class'][i])" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "18 13\n" + ] + } + ], + "source": [ + "# These are the set of genes (indices) that have 
been perturbed more than 0.9*max perturbation\n", + "# in the counterfactual example \n", + "\n", + "# split by class\n", + "genes_pos_0 = []\n", + "genes_neg_0 = []\n", + "perturb_vector_0=[]\n", + "genes_pos_1 = []\n", + "genes_neg_1 = []\n", + "perturb_vector_1=[]\n", + "for i,j,k,l in zip(genes_pos, genes_neg, perturb_vector, cf_class):\n", + " if l==0:\n", + " genes_pos_0.append(i)\n", + " genes_neg_0.append(j)\n", + " perturb_vector_0.append(k)\n", + " else:\n", + " genes_pos_1.append(i)\n", + " genes_neg_1.append(j)\n", + " perturb_vector_1.append(k)\n", + "print(len(genes_neg_0), len(genes_neg_1))" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "# cluster the counterfactual examples\n", + "num_clusters = 1\n", + "kmeans_0 = KMeans(n_clusters=num_clusters, random_state=0).fit(perturb_vector_0)\n", + "kmeans_1 = KMeans(n_clusters=num_clusters, random_state=0).fit(perturb_vector_1)" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "outputs": [], + "source": [ + "for i in range(len(kmeans_0.cluster_centers_)):\n", + " diff_1=kmeans_1.cluster_centers_[i]\n", + " max_value = np.max(np.abs(diff_1))\n", + " ind_pos = np.where(diff_1 > 0.9*max_value)\n", + " ind_neg = np.where(diff_1 < -0.9*max_value)\n", + " pickle.dump([ind_pos, ind_neg], open(\"{}cf_class_1.pkl\".format(parent), \"wb\"))\n", + " diff_0=kmeans_0.cluster_centers_[i]\n", + " max_value = np.max(np.abs(diff_0))\n", + " ind_pos = np.where(diff_0 > 0.9*max_value)\n", + " ind_neg = np.where(diff_0 < -0.9*max_value)\n", + " pickle.dump([ind_pos, ind_neg],open(\"{}cf_class_0.pkl\".format(parent), \"wb\"))" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 39, + "outputs": [ + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPSklEQVR4nO3dX4hc533G8edZuwlsmoItLapqe3cUowSUQpUwiBbcEFM3kQ1FcUpAZii+CGwKNvTfjcNe1DeCUuqGXKRux0TYNFObQmssGhMnNiWmUOqMUkVexaiWHa0soUir+KIpW5za++vFOasdrWalnZ05c8555/uB4cx5z+68P46PHr96z6szjggBANI0VXYBAIDiEPIAkDBCHgASRsgDQMIIeQBI2K1lF9Br586d0Wg0yi4DAGrl+PHjVyJipt+xSoV8o9FQt9stuwwAqBXbS5sdY7oGABJGyANAwgh5AEgYIQ8ACSPkASBhhDxQV52O1GhIU1PZttMpuyJUUKWWUALYok5Hmp+XVlay/aWlbF+SWq3y6kLlMJIH6mhhYT3g16ysZO1AD0IeqKNz5wZrx8Qi5IE6mp0drB0Ti5AH6ujIEWl6+tq26emsHehByAN11GpJ7bY0NyfZ2bbd5qYrrsPqGqCuWi1CHTfFSB4AEkbIA0DCCHkASNhIQt72UduXbS/2tD1u+4LtE/nrgVH0BQDYulGN5J+WdLBP+9ciYn/+enFEfQEAtmgkIR8Rr0p6dxSfBQAYnaLn5B+1fTKfzrmt3w/Ynrfdtd1dXl4uuBwAmCxFhvyTku6WtF/SRUlP9PuhiGhHRDMimjMzfb9sHACwTYWFfERciogPImJV0lOSDhTVFwCgv8JC3vbunt0HJS1u9rMAgGKM5LEGtp+V9FlJO22fl/Tnkj5re7+kkHRW0ldG0RcAYOtGEvIR8VCf5m+O4rMBANvHv3gFgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQB4CEEfIAkDBCHgASRsgDQMIIeQBIGCEPAAkj5AEgYYQ8ACSMkAeAhBHyAJAwQh4AEkbIA0DCCHkASBghDwAJI+QBIGGEPAAkbCQhb/uo7cu2F3vabrf9Pdtv5tvbRtEXAGDrRjWSf1rSwQ1tj0l6JSL2Snol3wcAjNFIQj4iXpX07obmQ5Keyd8/I+kLo+gLALB1Rc7J74qIi/n7n0ra1e+HbM/b7truLi8vF1gOAEyesdx4jYiQFJsca0dEMyKaMzMz4ygHACZGkSF/yfZuScq3lwvsCwDQR5Ehf0zSw/n7hyW9UGBfAIA+RrWE8llJ/y7pE7bP2/6ypL+Q9Lu235R0X74PABijW0fxIRHx0CaHfmcUnw8A2B7+xSsAJIyQB4CEEfIAkDBCHgASRsgDQMIIeQBIGCEPAAkj5AEgYYQ8ACSMkAeAhBHyAJAwQh6om05HajSkqals2+mUXREqbCQPKAMwJp2OND8vraxk+0tL2b4ktVrl1YXKYiQP1MnCwnrAr1lZydqBPgh5oE7OnRusHROPkAfqZHZ2sHZMPEIeqJMjR6Tp6WvbpqezdqAPQh6ok1ZLareluTnJzrbtNjddsSlW1wB102oR6tgyRvKoF9aIAwNhJI/6YI04MDBG8qgP1ogDAyPkUR+sEQcGRsijPlgjDgyMkEd9sEYcGFjhN15tn5X0c0kfSHo/IppF94lErd1cXVjIpmhmZ7OA56YrsKlxra65NyKujKkvpIw14sBAmK4BgISNI+RD0ndtH7c9v/Gg7XnbXdvd5eXlMZQDAJNjHCF/T0R8WtL9kh6x/ZnegxHRjohmRDRnZmbGUA4ATI7CQz4iLuTby5Kel3Sg6D4BAJlCQ972R2x/dO29pM9JWiyyTwDAuqJH8rsk/ZvtH0l6TdK3I+I7BfeZFh7IBWAIhS6hjIi3Jf1GkX0kjQdyARgSSyirjAdyA
RgSIV9lPJALwJDqH/Ipz1nzQC4AQ6p3yK/NWS8tSRHrc9apBD0P5AIwpHqHfOpz1nxpM4AhOSLKruGqZrMZ3W53678wNZWN4DeypdXV0RUGABVm+/hmT/it90ieOWsAuKF6hzxz1kDaiw8wtHqHPHPWmHSpLz7A0Oo9Jw9MukYjC/aN5uaks2fHXQ1Kku6cPDDp+AdzuAlCHqgzFh/gJgh5oM5YfICbIOSBOmPxAW6i0EcNAxiDVotQx6YYyQNAwgh5AEgYIQ8ACSPkASBhhDwAJIyQB4CEEfIAkDBCHgASRsgDQMIKD3nbB22ftn3G9mNF9wcAWFdoyNu+RdI3JN0vaZ+kh2zvK7JPAMC6okfyBySdiYi3I+IXkp6TdKjgPgEAuaJD/g5J7/Tsn8/brrI9b7tru7u8vFxwOQAwWUq/8RoR7YhoRkRzZmam7HIAIClFh/wFSXf17N+ZtwEAxqDokP+BpL2299j+kKTDko4V3CcAIFfol4ZExPu2H5X0kqRbJB2NiFNF9gkAWFf4N0NFxIuSXiy6HwDA9Uq/8QoAKA4hDwAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQR/E6HanRkKamsm2nU3ZFwMQo/CmUmHCdjjQ/L62sZPtLS9m+JLVa5dUFTAhG8ijWwsJ6wK9ZWcnaARSOkEexzp0brB3ASBHyKNbs7GDtAEaKkEexjhyRpqevbZueztoBFI6QR7FaLandlubmJDvbttvp33RlRREqwhFRdg1XNZvN6Ha7ZZcBDGfjiqI1O3ZIX/96+v+Dw9jZPh4RzX7HGMkDo9ZvRZEk/exnWfgzqscYEfLAqN1o5RDLRzFmhDwwajdbOcTyUYwRIY96q+INzn4rinqxfBRjRMijvtZucC4tSRHrj0woO+jXVhTt2HH9MZaPYqOCByqFhbztx21fsH0ifz1QVF+YUFV+ZEKrJV25In3rW5O3fBRbN4aBSmFLKG0/Lul/IuKvtvo7LKHEQKamsj8YG9nS6ur46wEG1Whkwb7R3Jx09uyWP4YllEgTj0xA3Y3h2U5Fh/yjtk/aPmr7tn4/YHvedtd2d3l5ueBykBQemYC6G8NAZaiQt/2y7cU+r0OSnpR0t6T9ki5KeqLfZ0REOyKaEdGcmZkZphxMmkl9ZALSMYaBylAhHxH3RcSv93m9EBGXIuKDiFiV9JSkA6MpGZKquXSwDK1WNne5upptCXjUyRgGKoV9M5Tt3RFxMd99UNJiUX1NHL5tCUhHq1Xon9si5+T/0vbrtk9KulfSnxTY12Sp8tJBAJVS2Eg+Iv6gqM+eeHzbEoAtYgllHbF0EMAWEfJ1xNJBAFtEyNcRSwcBbFFhc/IoWMF35AGkgZE8ACSMkAeAhBHyAJAwQh4AEkbIA0DCCHkASBghDwAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQB4CEEfIAkDBCHgASNlTI2/6S7VO2V203Nxz7qu0ztk/b/vxwZQIAtuPWIX9/UdIXJf1db6PtfZIOS/qkpF+T9LLtj0fEB0P2BwAYwFAj+Yh4IyJO9zl0SNJzEfFeRPxE0hlJB4bpCwAwuKLm5O+Q9E7P/vm87Tq25213bXeXl5cLKgcAJtNNp2tsvyzpV/scWoiIF4YtICLaktqS1Gw2Y9jPAwCsu2nIR8R92/jcC5Lu6tm/M28DAIxRUdM1xyQdtv1h23sk7ZX0WkF9AQA2MewSygdtn5f0W5K+bfslSYqIU5L+UdKPJX1H0iOsrAGA8RtqCWVEPC/p+U2OHZF0ZJjPBwAMh3/xCgAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQx2TqdKRGQ5qayradTtkVAYUY9jtegfrpdKT5eWllJdtfWsr2JanVKq8uoACM5DF5FhbWA37Ny
krWDiSGkMfkOXdusHagxgh5TJ7Z2cHagRoj5DF5jhyRpqevbZueztqBxBDymDytltRuS3Nzkp1t221uuiJJrK7BZGq1CHVMBEbyAJAwQh4AEkbIA0DCCHkASBghDwAJc0SUXcNVtpclLZVdRx87JV0pu4hNVLk2qdr1Udv2UNv2FFnbXETM9DtQqZCvKtvdiGiWXUc/Va5NqnZ91LY91LY9ZdXGdA0AJIyQB4CEEfJb0y67gBuocm1Steujtu2htu0ppTbm5AEgYYzkASBhhDwAJIyQvwHbX7J9yvaq7WZPe8P2/9o+kb/+tiq15ce+avuM7dO2Pz/u2jbU8rjtCz3n6oEy68lrOpifmzO2Hyu7nl62z9p+PT9X3QrUc9T2ZduLPW232/6e7Tfz7W0Vqq0S15vtu2z/q+0f539O/yhvH/u5I+RvbFHSFyW92ufYWxGxP3/94ZjrkjapzfY+SYclfVLSQUl/Y/uW8Zd3ja/1nKsXyywkPxffkHS/pH2SHsrPWZXcm5+rKqz3flrZddTrMUmvRMReSa/k+2V4WtfXJlXjentf0p9FxD5Jvynpkfw6G/u5I+RvICLeiIjTZdfRzw1qOyTpuYh4LyJ+IumMpAPjra7SDkg6ExFvR8QvJD2n7Jyhj4h4VdK7G5oPSXomf/+MpC+MtajcJrVVQkRcjIgf5u9/LukNSXeohHNHyG/fHtv/afv7tn+77GJ63CHpnZ7983lbmR61fTL/63Upf7XvUcXz0yskfdf2cdvzZReziV0RcTF//1NJu8ospo8qXW+y3ZD0KUn/oRLO3cSHvO2XbS/2ed1odHdR0mxEfErSn0r6B9u/UpHaxu4mdT4p6W5J+5WdtydKLbb67omITyubTnrE9mfKLuhGIluDXaV12JW63mz/sqR/kvTHEfHfvcfGde4m/uv/IuK+bfzOe5Ley98ft/2WpI9LGumNsu3UJumCpLt69u/M2wqz1TptPyXpX4qsZQvGfn4GEREX8u1l288rm17qd0+oTJds746Ii7Z3S7pcdkFrIuLS2vuyrzfbv6Qs4DsR8c9589jP3cSP5LfD9szazUzbH5O0V9Lb5VZ11TFJh21/2PYeZbW9VlYx+YW85kFlN4zL9ANJe23vsf0hZTepj5VckyTJ9kdsf3TtvaTPqfzz1c8xSQ/n7x+W9EKJtVyjKtebbUv6pqQ3IuKvew6N/9xFBK9NXsoukvPKRu2XJL2Ut/++pFOSTkj6oaTfq0pt+bEFSW9JOi3p/pLP4d9Lel3SSWUX+O4K/Hd9QNJ/5edooex6eur6mKQf5a9TVahN0rPKpj3+L7/evixph7KVIW9KelnS7RWqrRLXm6R7lE3FnMxz4kR+3Y393PFYAwBIGNM1AJAwQh4AEkbIA0DCCHkASBghDwAJI+QBIGGEPAAk7P8B5yRZLr0XHV8AAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOC0lEQVR4nO3dT4gc553G8eeRTAITcnDQrCJsz4wJ2gXnsCY0Oi1Lwjqx1hfFgSwKzWLYwORg791hDgkEQQgJWQjJQgeEfZiN8cVYxCEb2xexsIvdAm9WchAWzows4VhjcgkMJNj57aF6rNGkezR/6q3q/tX3A0N1vTXq922X+6Hmfd96yxEhAEBOR9puAACgHEIeABIj5AEgMUIeABIj5AEgsXvabsB2x44di6WlpbabAQAz5dKlS+9HxPy4Y1MV8ktLSxoOh203AwBmiu31ScforgGAxAh5AEiMkAeAxAh5AEiMkAeAxAh5AGjT6qq0tCQdOVJtV1drffupmkIJAJ2yuiotL0ubm9X++nq1L0n9fi1VcCUPAG1ZWbkd8Fs2N6vymhDyANCW69f3V34AhDwAtGVhYX/lB0DIA0Bbzp2T5ubuLJubq8prQsgDQFv6fWkwkBYXJbvaDga1DbpKzK4BgHb1+7WG+k5cyQNAYoQ8ACRGyANAYoQ8ACRGyANAYoQ8ACRGyANAYoQ8ACRGyANAYoQ8ACRGyAO4U+EnFaFZrF0D4LYGnlSEZnElD+C2Bp5UhGYR8gBua+BJRWhWLSFv+7ztW7Yvbyv7lO2Xbb812t5bR10ACmrgSUVoVl1X8s9IOr2j7GlJr0bESUmvjvYBTLMGnlSEZtUS8hFxUdLvdxSfkfTs6PWzkr5cR10ACmrgSUVoVsnZNccj4t3R699JOl6wLgB1KfykIjSrkYHXiAhJMe6Y7WXbQ9vDjY2NJpoDAJ1RMuTfs31CkkbbW+N+KSIGEdGLiN78/HzB5gBA95QM+QuSnhi9fkLSiwXrAgCMUdcUyp9J+m9Jf2P7hu2vS/qupC/afkvSI6N9AECDahl4jYivTTj0D3W8PwDgYLjjFQASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBILF7Sldge03SHyR9KOmDiOiVrhMAUCke8iNfiIj3G6oLADDS7e6a1VVpaUk6cqTarq623SIAqFUTIR+SfmX7ku3lnQdtL9se2h5ubGw00JyR1VVpeVlaX5ciqu3yMkEPIBVHRNkK7Psi4qbtv5L0sqR/jYiL43631+vFcDgs2p6PLC1Vwb7T4qK0ttZMGwCgBrYvTRrvLH4lHxE3R9tbkl6QdKp0nXty/fr+yqcZ3U4AJiga8rY/YfuTW68lfUnS5ZJ17tnCwv7KpxXdTgB2UfpK/rik/7L9v5Jek/RSRPyycJ17c+6cNDd3Z9ncXFU+S1ZWpM3NO8s2N6tyAJ1XdAplRLwt6W9L1nFg/X61XVmpumgWFqqA3yqfFZm6nQDUrql58tOp35+9UN9pYWH8APKsdTsBKKLb8+QzyNLtBKAIQn7W9fvSYFBN/bSr7WCwt79QmJUDpNft7posDtLttDUrZ2vQdmtWztb7AUiBK/muYlYO0AmEfFcxKwfoBEK+q7LcDAZgV4R8VzErB+gEQr6rDjMrB8DMYHZNl2W4GQzArnJcyTPfGwDGmv0reeZ7A8BEs38lz3xvAJho9kOe+d4AMNHshzzzvQFgotkPeeZ7A8BEsx/yzPcGgIlmf3aNxHxvAJhg9q/k68A8+27gPKODclzJHwbz7LuB84yOckS03YaP9Hq9GA6HzVa6tDT+GamLi9LaWrNtQ
TmcZyRm+1JE9MYdo7uGefbdwHlGRxHyzLPvBs4zOoqQZ559N3Ce0VGEPPPsu4HzjI5i4BUAZhwDrwDQUYQ8ACRGyANAYoQ8ACRWPORtn7Z91fY120+Xrg/ALli/p3OKrl1j+6ikH0v6oqQbkl63fSEi3ixZL4AxWL+nk0pfyZ+SdC0i3o6IP0l6TtKZwnUCGIfnIXdS6ZC/T9I72/ZvjMo+YnvZ9tD2cGNjo3BzgA5j/Z5Oan3gNSIGEdGLiN78/HzbzQHyYv2eTiod8jclPbBt//5RGWZJE4N1DAiWx/o9nVQ65F+XdNL2g7Y/JumspAuF60Sdtgbr1teliNuDdXWGcBN1gPV7Oqr42jW2H5P0b5KOSjofERMvG1i7Zgo18bANHugBHEqra9dExC8i4q8j4jO7BTxqUKLLo4nBOgYEgWJaH3hFTUp1eTQxWMeAIFAMIZ9FqTnQTQzWMSAIFEPIZ1Gqy6OJwToGBIFieGhIFgxeAp3FQ0O6YJq6PHYbAGY+PNCooguUoUFbXRsrK1UXzcJCFfBNd3nstgiWxAJZQMPorkG9dus2kuhSAgrYrbuGK3nU6yADwMyHB4qhTx712m3OO/PhgcYR8qjXbgPA0zQ4DHQE3TWo114GgNseHAY6hIFXAJhxzJMHgI4i5AEgMUIeABIj5AEgMUIeABIj5AEgMUIeABIj5DEZywIDM487XjHebksGc4cqMDO4ksd4pZ4ZC6BRhDzGK/XMWACNIuQxHssCAykQ8hiPZYGBFAh5jNfvS4NB9Wg+u9oOBgy6AjOG2TWYrN8n1IEZx5U8ACRGyANAYoQ8ACRWLORtf9v2TdtvjH4eK1UXAGC80gOvP4yI7xeuAwAwAd01AJBY6ZB/yvavbZ+3fe+4X7C9bHtoe7ixsVG4OQDQLY6Ig/9j+xVJnx5zaEXS/0h6X1JI+o6kExHxL7u9X6/Xi+FweOD2AEAX2b4UEb1xxw7VJx8Rj+yxAT+V9PPD1AUA2L+Ss2tObNt9XNLlUnUBAMYrObvme7YfVtVdsybpGwXrAgCMUSzkI+KfS703AGBvmEIJAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8sCsWV2VlpakI0eq7epq2y3CFLun7QYA2IfVVWl5WdrcrPbX16t9Ser322sXphZX8sAsWVm5HfBbNjercmCMQ4W87a/avmL7z7Z7O4590/Y121dtP3q4ZgKQJF2/vr9ydN5hr+QvS/qKpIvbC20/JOmspM9KOi3pJ7aPHrIuAAsL+ytH5x0q5CPiNxFxdcyhM5Kei4g/RsRvJV2TdOowdQGQdO6cNDd3Z9ncXFUOjFGqT/4+Se9s278xKvsLtpdtD20PNzY2CjUHSKLflwYDaXFRsqvtYMCgKya66+wa269I+vSYQysR8eJhGxARA0kDSer1enHY9wPS6/cJdezZXUM+Ih45wPvelPTAtv37R2UAgAaV6q65IOms7Y/bflDSSUmvFaoL04KbdICpc6iboWw/LulHkuYlvWT7jYh4NCKu2H5e0puSPpD0ZER8ePjmYmpxkw4wlRwxPd3gvV4vhsNh283AQSwtVcG+0+KitLbWdGuATrF9KSJ6445xxyvqwU06wFQi5FEPbtIBphIhj3pwkw4wlQh51IObdICpxFLDqA836QBThyt5AEiMkAeAxAh5AEiMkAeAxAh5AEiMkAeAxAh5AHmxMirz5AEkxcqokriSB5DVysrtgN+yuVmVdwghDyAnVkaVRMgDyIqVUSUR8gCyYmVUSYQ8gKxYGVUSs2sAZMbKqFzJA0BmhDwAJEbIA0BihDwAJEbIA0Bijoi22/AR2xuS1vf468ckvV+wOdOKz90tfO5uOejnXoyI+XEHp
irk98P2MCJ6bbejaXzubuFzd0uJz013DQAkRsgDQGKzHPKDthvQEj53t/C5u6X2zz2zffIAgLub5St5AMBdEPIAkNhMhbztr9q+YvvPtns7jn3T9jXbV20/2lYbm2D727Zv2n5j9PNY220qxfbp0Tm9ZvvpttvTFNtrtv9vdH6HbbenJNvnbd+yfXlb2adsv2z7rdH23jbbWMKEz137d3umQl7SZUlfkXRxe6HthySdlfRZSacl/cT20eab16gfRsTDo59ftN2YEkbn8MeS/lHSQ5K+NjrXXfGF0fnNPl/8GVXf2+2elvRqRJyU9OpoP5tn9JefW6r5uz1TIR8Rv4mIq2MOnZH0XET8MSJ+K+mapFPNtg4FnJJ0LSLejog/SXpO1blGIhFxUdLvdxSfkfTs6PWzkr7caKMaMOFz126mQn4X90l6Z9v+jVFZZk/Z/vXoT750f8qOdPG8bglJv7J9yfZy241pwfGIeHf0+neSjrfZmIbV+t2eupC3/Yrty2N+OnUFd5f/Dv8u6TOSHpb0rqQftNpYlPB3EfE5VV1VT9r++7Yb1Jao5nl3Za537d/tqXv8X0Q8coB/dlPSA9v27x+Vzay9/new/VNJPy/cnLakO697FRE3R9tbtl9Q1XV1cfd/lcp7tk9ExLu2T0i61XaDmhAR7229ruu7PXVX8gd0QdJZ2x+3/aCkk5Jea7lNxYz+p9/yuKoB6Yxel3TS9oO2P6ZqcP1Cy20qzvYnbH9y67WkLynvOZ7kgqQnRq+fkPRii21pTInv9tRdye/G9uOSfiRpXtJLtt+IiEcj4ort5yW9KekDSU9GxIdttrWw79l+WNWfsGuSvtFuc8qIiA9sPyXpPyUdlXQ+Iq603KwmHJf0gm2p+o7+R0T8st0mlWP7Z5I+L+mY7RuSviXpu5Ket/11VcuP/1N7LSxjwuf+fN3fbZY1AIDEsnTXAADGIOQBIDFCHgASI+QBIDFCHgASI+QBIDFCHgAS+3+scocs/QSQcwAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASp0lEQVR4nO3df4hl533f8fdn15bSISGyrUEW+7tkSbsJbSwGRaGhGMutV66xnBKHNQNWEsFQkMEhBXeV+aMUulA3ELfGTtqhMlXKYEXkB7skThVZtnHyh2SP8kOxpCgeyx7tLrK1jm2lZYjStb79456VrlazuzNz79x75z7vF1zuOd/n7L3Pc7X6zNnnnHluqgpJUlv2jLsDkqTRM/wlqUGGvyQ1yPCXpAYZ/pLUIMNfkho0tPBPsjfJnyX5/W7/SJLHkqwm+a0k13X167v91a798LD6IEnanGGe+X8YeLpv/6PAx6rqR4DvAnd39buB73b1j3XHSZJGaCjhn2Q/8K+A/9HtB3gH8NvdIfcD7+u27+z26dpv746XJI3IG4b0Ov8F+AjwQ93+W4DvVdXFbv8csK/b3gecBaiqi0le7I7/9pVe/MYbb6zDhw8PqauS1IbHH3/821U1u1HbwOGf5D3AC1X1eJK3D/p6fa+7ACwAHDx4kJWVlWG9tCQ1IcnaldqGMe3zz4D3JvkG8AC96Z7/CtyQ5NIPl/3A+W77PHCg69gbgB8G/ubyF62qpaqaq6q52dkNf3BJkrZp4PCvqnuran9VHQZOAJ+rqnng88DPdofdBZzuts90+3TtnytXl5OkkdrJ+/z/HfDLSVbpzenf19XvA97S1X8ZOLmDfZAkbWBYF3wBqKovAF/otp8Fbt3gmL8D3j/M95UkbY2/4StJDTL8tX3Ly3D4MOzZ03teXh53jyRt0lCnfdSQ5WVYWID19d7+2lpvH2B+fnz9krQpnvlrexYXXw3+S9bXe3VJE8/w1/Y899zW6pImiuGv7Tl4cGt1SRPF8Nf2nDoFMzOvrc3M9OqSJp7hr+2Zn4elJTh0CJLe89KSF3ulXcK7fbR98/OGvbRLeeYvSQ0y/CWpQYa/JDXI8JekBhn+ktQgw1/S5rmY39TwVk9Jm+NiflPFM39Jm+NiflPF8Je0OS7mN1UMf0mb42J+U8Xwl7Q5LuY3VQx/SZvjYn5Txbt9JG2ei/lNjYHP/JP8QJIvJfmLJE8m+Q9d/UiSx5KsJvmtJNd19eu7/dWu/fCgfZAkbc0wpn1eAt5RVf8U+AngeJLbgI8CH6uqHwG+C9zdHX838N2u/rHuOEnSCA0c/tXzf7vdN3aPAt4B/HZXvx94X7d9Z7dP1357kgzaD0nS5g3lgm+SvUn+HHgBeBj4GvC9qrrYHXIO2Ndt7wPOAnTtLwJvGUY/JEmbM5Twr6rvV9VPAPuBW4F/NOhrJllIspJk5cKFCwP3UZL0qqHe6llV3wM+D/wUcEOSS3cT7QfOd9vngQMAXfsPA3+zwWstVdVcVc3Nzs4Os5uS1Lxh3O0zm+SGbvsfAP8CeJreD4Gf7Q67CzjdbZ/p9unaP1dVNWg/JEmbN4z7/G8G7k+yl94Pkwer6veTPAU8kOQ/An8G3Ncdfx/wv5KsAt8BTgyhD5KkLRg4/KvqCeBtG9SfpTf/f3n974D3D/q+kqTtc3kHSWqQ4S9JDTL8JalBhr8kNcjwl6QGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0y/CWpQYa/JDXI8JekBhn+ktQgw1+SGmT4S1KDDH9JapDhr+mzvAyHD8OePb3n5eVx90iaOMP4Dl9pciwvw8ICrK/39tfWevsA8/Pj65c0YTzz13RZXHw1+C9ZX+/VJb3C8Nd0ee65rdWlSbXD05eGv6bLwYNbq0uT6NL05doaVL06fTnEHwADh3+SA0k+n+SpJE8m+XBXf3OSh5N8tXt+U1dPko8nW
U3yRJJbBu3D0HnBcPc6dQpmZl5bm5np1aXdYgTTl8M4878I/NuqOgbcBtyT5BhwEnikqo4Cj3T7AHcAR7vHAvAbQ+jD8IzgJ6520Pw8LC3BoUOQ9J6XlrzYq91lBNOXqaqhvRhAktPAJ7rH26vq+SQ3A1+oqh9N8t+77U93xz9z6bgrvebc3FytrKwMtZ9XdPhwL/Avd+gQfOMbo+mDpLYNKYeSPF5Vcxu1DXXOP8lh4G3AY8BNfYH+TeCmbnsfcLbvj53rapPBC4aSxm0E05dDC/8kPwj8DvBLVfW3/W3V++fFlv6JkWQhyUqSlQsXLgyrm9fmBUNJ4zaC6cuhhH+SN9IL/uWq+t2u/K1uuofu+YWufh440PfH93e116iqpaqaq6q52dnZYXRzc7xgKGkSzM/3pnhefrn3POTrVsO42yfAfcDTVfVrfU1ngLu67buA0331D3Z3/dwGvHi1+f6R84KhpAYMfME3yU8Dfwz8JfByV/4VevP+DwIHgTXg56rqO90Pi08Ax4F14Beq6qpXc0d6wVeSpsTVLvgOvLZPVf0JkCs0377B8QXcM+j7SpK2z9/wlaQGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0y/CWpQYa/JDXI8JekBhn+ktQgw1+SGmT4S1KDDH9JG1te7n2d4J49vWe/x3qqDLyqp6QptLwMCwuwvt7bX1vr7YPfbTElPPOX9HqLi68G/yXr6726poLhL+n1nntua3XtOoa/pNc7eHBrde06hr+k1zt1CmZmXlubmenVNRUMf0mvNz8PS0tw6BAkveelJS/2ThHv9pG0sfl5w36KeeYvSQ0y/CWpQUMJ/ySfSvJCkq/01d6c5OEkX+2e39TVk+TjSVaTPJHklmH0QZK0ecM68/+fwPHLaieBR6rqKPBItw9wB3C0eywAvzGkPkiSNmko4V9VXwS+c1n5TuD+bvt+4H199d+snkeBG5LcPIx+SJI2Zyfn/G+qque77W8CN3Xb+4Czfced62qSpBEZyQXfqiqgtvJnkiwkWUmycuHChR3qmSS1aSfD/1uXpnO65xe6+nngQN9x+7vaa1TVUlXNVdXc7OzsDnZTktqzk+F/Brir274LON1X/2B3189twIt900OSpBEYym/4Jvk08HbgxiTngH8P/CfgwSR3A2vAz3WHfwZ4N7AKrAO/MIw+SJI2byjhX1UfuELT7RscW8A9w3hfSdL2+Bu+ktQgw1+SGmT4a/T8YnBp7FzSWaPlF4NLE8Ezf42WXwwuTQTDX6PlF4NLE8Hw12j5xeDSRDD8NVp+Mbg0EQx/jZZfDC5NBO/20ej5xeDS2HnmL0kNMvwlqUGGvyQ1yPCXpAYZ/pLUIMNfkhpk+EtSgwx/SRqHMS9tPt3h77rxkibRpaXN19ag6tWlzUeYUdMb/hPw4UrShiZgafPpDf8J+HAlaUMTsLT59Ib/BHy4krShCVjafGzhn+R4kmeSrCY5OfQ3mIAPV5I2NAFLm48l/JPsBT4J3AEcAz6Q5NhQ32QCPlxJ2tAELG0+riWdbwVWq+pZgCQPAHcCTw3tHS59iIuLvamegwd7we9SwpImwZiXNh9X+O8DzvbtnwN+cujv4rrxkrShib3gm2QhyUqSlQsXLoy7O5I0VcYV/ueBA337+7vaK6pqqarmqmpudnZ2pJ2TpGk3rvD/MnA0yZEk1wEngDNj6oskNWcsc/5VdTHJh4CHgL3Ap6rqyXH0RZJaNLYvcK+qzwCfGdf7S1LLJvaCryRp5xj+0rRwFVttwdimfSQN0aVVbC8tZnhpFVvwd120Ic/8pWngKrbaIsNfmgauYqstMvylaeAqttoiw1+aBq5iqy0y/KVpMAFLBGt38W4faVq4iq22wDN/SWqQ4S9JDTL8JalBhr8kNcjwl6QGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0y/CWpQYa/JDXI8JekBg0U/knen+TJJC8nmbus7d4kq0meSfKuvvrxrraa5
OQg7y9J2p5Bz/y/Avxr4Iv9xSTHgBPAjwHHgV9PsjfJXuCTwB3AMeAD3bGSpBEaaD3/qnoaIMnlTXcCD1TVS8DXk6wCt3Ztq1X1bPfnHuiOfWqQfkiStman5vz3AWf79s91tSvVJUkjdM0z/ySfBd66QdNiVZ0efpdeed8FYAHgoF9CLUlDdc0z/6p6Z1X9+AaPqwX/eeBA3/7+rnal+kbvu1RVc1U1Nzs7e+2RaPSWl+HwYdizp/e8vDzuHknapJ2a9jkDnEhyfZIjwFHgS8CXgaNJjiS5jt5F4TM71AftpOVlWFiAtTWo6j0vLPgDQNolBr3V82eSnAN+CviDJA8BVNWTwIP0LuT+b+Ceqvp+VV0EPgQ8BDwNPNgdq91mcRHW119bW1/v1SVNvFTVuPtwTXNzc7WysjLubqjfnj29M/7LJfDyy6Pvj6TXSfJ4Vc1t1OZv+Gp7rnQR3ovz0q5g+Gt7Tp2CmZnX1mZmenVJE8/w1/bMz8PSEhw61JvqOXSotz8/P+6eSdqEgX7DV42bnzfspV3KM39JapDhL0kNMvwlqUGGvyQ1yPCXpAYZ/pLUIMNfkhpk+EtSgwx/SWqQ4S9JDTL8JalBhr8kNcjwl6QGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0aKPyT/GqSv0ryRJLfS3JDX9u9SVaTPJPkXX31411tNcnJQd5fkrQ9g575Pwz8eFX9E+CvgXsBkhwDTgA/BhwHfj3J3iR7gU8CdwDHgA90x0qSRmig8K+qP6qqi93uo8D+bvtO4IGqeqmqvg6sArd2j9Wqeraq/h54oDtWkjRCw5zz/0XgD7vtfcDZvrZzXe1KdUnSCL3hWgck+Szw1g2aFqvqdHfMInARWB5Wx5IsAAsABw8eHNbLSpLYRPhX1Tuv1p7k54H3ALdXVXXl88CBvsP2dzWuUr/8fZeAJYC5ubna6BhJ0vYMerfPceAjwHurar2v6QxwIsn1SY4AR4EvAV8GjiY5kuQ6eheFzwzSB0nS1l3zzP8aPgFcDzycBODRqvo3VfVkkgeBp+hNB91TVd8HSPIh4CFgL/CpqnpywD5IkrYor87UTK65ublaWVkZdzckaVdJ8nhVzW3U5m/4SlKDDH9JapDhL0kNMvwlqUGGvyQ1yPCX1K7lZTh8GPbs6T0vD22Rgok36H3+krQ7LS/DwgKsd7+furbW2weYnx9fv0bEM39JbVpcfDX4L1lf79UbYPhLatNzz22tPmUMf0ltutJqwY2sImz4S2rTqVMwM/Pa2sxMr94Aw19Sm+bnYWkJDh2CpPe8tNTExV7wbh9JLZufbybsL+eZv6Sd0/B99JPOM39JO6Px++gnnWf+knZG4/fRTzrDX9LOaPw++kln+EvaGY3fRz/pDH9JO6Px++gnneEvaWc0fh/9pPNuH0k7p+H76CedZ/6S1CDDX5IaZPhLUoMMf0lqkOEvSQ1KVY27D9eU5AKwNu5+bMGNwLfH3YkxcextcuyT6VBVzW7UsCvCf7dJslJVc+Puxzg4dsfemt06dqd9JKlBhr8kNcjw3xlL4+7AGDn2Njn2XcY5f0lqkGf+ktQgw38ASX41yV8leSLJ7yW5oa/t3iSrSZ5J8q6++vGutprk5Hh6Prgk70/yZJKXk8xd1jbVY7/ctI7rkiSfSvJCkq/01d6c5OEkX+2e39TVk+Tj3WfxRJJbxtfzwSU5kOTzSZ7q/r5/uKvv/vFXlY9tPoB/Cbyh2/4o8NFu+xjwF8D1wBHga8De7vE14B8C13XHHBv3OLY59n8M/CjwBWCurz71Y7/sc5jKcV02xn8O3AJ8pa/2n4GT3fbJvr/77wb+EAhwG/DYuPs/4NhvBm7ptn8I+Ovu7/iuH79n/gOoqj+qqovd7qPA/m77TuCBqnqpqr4OrAK3do/Vqnq2qv4eeKA7dtepqqer6pkNmqZ+7JeZ1nG9oqq+CHznsvKdwP3d9v3A+/rqv1k9jwI3JLl5ND0dvqp6v
qr+tNv+P8DTwD6mYPyG//D8Ir2f+ND7y3G2r+1cV7tSfZq0NvZpHde13FRVz3fb3wRu6ran9vNIchh4G/AYUzB+v8zlGpJ8FnjrBk2LVXW6O2YRuAgsj7JvO20zY5eqqpJM9W2DSX4Q+B3gl6rqb5O80rZbx2/4X0NVvfNq7Ul+HngPcHt1k37AeeBA32H7uxpXqU+ca439CqZi7FtwtfFOs28lubmqnu+mNV7o6lP3eSR5I73gX66q3+3Ku378TvsMIMlx4CPAe6tqva/pDHAiyfVJjgBHgS8BXwaOJjmS5DrgRHfsNGlt7NM6rms5A9zVbd8FnO6rf7C76+U24MW+6ZFdJ71T/PuAp6vq1/qadv/4x33FeTc/6F3MPAv8eff4b31ti/TuAnkGuKOv/m56dwx8jd70ydjHsc2x/wy9+cyXgG8BD7Uy9g0+i6kcV9/4Pg08D/y/7r/53cBbgEeArwKfBd7cHRvgk91n8Zf03Qm2Gx/ATwMFPNH3//m7p2H8/oavJDXIaR9JapDhL0kNMvwlqUGGvyQ1yPCXpAYZ/pLUIMNfkhpk+EtSg/4/XIFY5kNG/sYAAAAASUVORK5CYII=\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + }, + { + "data": { + "text/plain": "
", + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAUPklEQVR4nO3db4xdd33n8fcnSZPKbcGkGdLIjjNGG7oK2xbSIUpF/0BCS0gR5kGFIs0uKWVrKUoRULRsgqWV9kEk/qk0aFvYEWkVtLObZgMlFqKUJIVKfRCHcUhC/hDiQhw7OGSQFlqttYnSfPfBOa5vkrHNzP137pn3Sxrdc37nzpyvz9gfn3vu73xvqgpJUj+dNu0CJEnjY8hLUo8Z8pLUY4a8JPWYIS9JPXbGtAsYdM4559T8/Py0y5CkmbJ///4fVtXcWts6FfLz8/OsrKxMuwxJmilJDp5om5drJKnHDHlJ6jFDXpJ6zJCXpB4z5CWpxwx5bS7LyzA/D6ed1jwuL0+7ImmsOjWFUhqr5WXYvRuOHm3WDx5s1gEWF6dXlzRGnslr89iz53jAH3P0aDMu9dRIQj7J1iS3Jfl2kkeS/FqSs5PckeSx9vEVo9iXtGFPPLG+cakHRnUmfyPwlar6t8CvAI8A1wF3VdWFwF3tujQ9O3asb1zqgaFDPsnLgd8EbgKoqmer6kfALuDm9mk3A+8Ydl/SUG64AbZseeHYli3NuNRToziT3wmsAn+Z5JtJPpvkZ4Bzq+pI+5yngHPX+uYku5OsJFlZXV0dQTnSCSwuwtISXHABJM3j0pJvuqrXRhHyZwAXA5+uqtcB/5cXXZqp5oNk1/ww2apaqqqFqlqYm1uziZo0OouL8Pjj8PzzzaMBPx5OVe2MUYT8YeBwVe1r12+jCf0fJDkPoH18egT7ktR1x6aqHjwIVcenqhr0UzF0yFfVU8ChJL/YDl0OPAzsBa5ux64Gbh92X5JmgFNVO2VUN0O9F1hOcibwXeDdNP+B3JrkPcBB4J0j2pekLnOqaqeMJOSr6j5gYY1Nl4/i50uaITt2NJdo1hrXxHnHq6TRcqpqpxjykkbLqaqdYoMySaO3uGiod4Rn8pLUY4b8LPDGEkkb5OWarrMHuqQheCbfdd5YImkIhnzXeWOJpCEY8l1nD3RJQzDku84bSyQNwZDvOm8skTQEZ9fMAm8skbRBnslLUo8Z8pLUY4a8JPWYIS9JPWbIS1KPjSzkk5ye5JtJvtSu70yyL8mBJH/VfjSgJGmCRnkm/z7gkYH1jwKfrKp/A/wf4D0j3Jck6ScwkpBPsh34XeCz7XqAy4Db2qfcDLxjFPuSJP3kRnUm/6fAh4Dn2/WfB35UVc+164eBbWt9Y5LdSVaSrKyuro6oHEkSjCDkk7wNeLqq9m/k+6tqqaoWqmphbm5u2HIkSQNG0dbgDcDbk1wJ/DTwMuBGYGuSM9qz+e3AkyPYlyRpHYY+k6+q66tqe1XNA1cBf1dVi8DXgN9rn3Y1cPuw+5Ikrc8458n/Z+CPkxyguUZ/0xj3JUlaw0i7UFbV14Gvt8vfBS4Z5c+XJK2Pd7xKUo8Z8pLUY4a8JPWYIS9JPWbIS1KPGfKS1GOGvCT1mCEvST1myEtSjxnyUhcsL8P8PJx2WvO4vDztirQRHfw9jrStgaQNWF6G3bvh6NFm/eDBZh1gcXF6dWl9Ovp7TFVNbecvtrCwUCsrK9MuQ5qs+fkmEF7sggvg8ccnXY02aoq/xyT7q2phrW1erpGm7Ykn1jeuburo79GQl6Ztx471jaubOvp7NOSlabvhBtiy5YVjW7Y045odHf09GvLStC0uwtJSc+02aR6XlnzTddZ09PfoG6+SNOPG+sZrkvOTfC3Jw0keSvK+dvzsJHckeax9fMWw+5Ikrc8oLtc8B3ywqi4CLgWuTXIRcB1wV1VdCNzVrkuSJmjokK+qI1V1b7v8z
8AjwDZgF3Bz+7SbgXcMuy9J0vqM9I3XJPPA64B9wLlVdaTd9BRw7gm+Z3eSlSQrq6uroyxHkja9kYV8kp8FPg+8v6r+aXBbNe/urvkOb1UtVdVCVS3Mzc2Nqhx1RQd7eUibyUh61yT5KZqAX66qL7TDP0hyXlUdSXIe8PQo9qUZ0tFeHtJmMorZNQFuAh6pqj8Z2LQXuLpdvhq4fdh9acbs2XM84I85erQZlzQRoziTfwPwH4BvJbmvHfsw8BHg1iTvAQ4C7xzBvjRLOtrLQ9pMhg75qvoHICfYfPmwP18zbMeOtbvy2ZNFmhjbGmh8OtrLQ9pMDHmNT0d7eUibiZ8MpfFaXDTUpSnyTP7FnNctqUc8kx/kvG5JPeOZ/CDndUvqGUN+kPO6JfWMIT+oo5/RKEkbZcgPcl63pJ4x5Ac5r1tSzzi75sWc1y2pRzyTl6QeM+QlqccMeUnqMUNeknrMkJekHjPkJanHxh7ySa5I8miSA0muG/f+JHWEHV07Yazz5JOcDvwZ8NvAYeAbSfZW1cPj3K+kKbOja2eM+0z+EuBAVX23qp4FbgF2jXmfkqbNjq6dMe6Q3wYcGlg/3I79qyS7k6wkWVldXR1zOZImwo6unTH1N16raqmqFqpqYW5ubtrlSBoFO7p2xrhD/kng/IH17e2YpD6zo2tnjDvkvwFcmGRnkjOBq4C9Y96npGmzo2tnjHV2TVU9l+SPgL8FTgf+oqoeGuc+JXWEHV07Yeythqvqy8CXx70fSdJLTf2NV0nS+BjyktRjhrwk9ZghL0k9ZshLUo8Z8pLUY4a8tF620NUMGfs8ealXbKGrGeOZvLQettDVjDHkpfWwha5mjCEvrYctdDVjDHlpPWyhqxljyEvrYQtdzRhn10jrZQtdzRDP5CWpxwx5SZqmMd9c5+UaSZqWCdxcN9SZfJKPJ/l2kgeS/HWSrQPbrk9yIMmjSd4yfKmS1DMTuLlu2Ms1dwD/rqp+GfgOcD1AkotoPrT7NcAVwJ8nOX3IfWnc7MkiTdYEbq4bKuSr6qtV9Vy7ejewvV3eBdxSVc9U1feAA8Alw+xLY3bsZePBg1B1/GWjQS+NzwRurhvlG69/APxNu7wNODSw7XA7pq6yJ4s0eRO4ue6UIZ/kziQPrvG1a+A5e4DngHWf9iXZnWQlycrq6up6v12jYk8WafImcHPdKWfXVNWbT7Y9ye8DbwMur6pqh58Ezh942vZ2bK2fvwQsASwsLNRaz9EE7NjRXKJZa1zS+Iz55rphZ9dcAXwIeHtVDb7W3wtcleSsJDuBC4F7htmXxsyeLFIvDXtN/r8BPwfckeS+JJ8BqKqHgFuBh4GvANdW1b8MuS+Nkz1ZpF7K8Sss07ewsFArKyvTLkOSZkqS/VW1sNY22xpIUo8Z8pLUY4a8JPWYIS9JPWbIS1KPzX7I21RLkk5otvvJT6AXsyTNstk+k7epliSd1GyHvE21JOmkZjvkJ9CLWZJm2WyHvE21JOmkZjvkbaolSSc127NrYOy9mCVpls32mbyk8fEelF6Y/TN5SaPnPSi94Zm8pJfyHpTeMOQlvZT3oPSGIS/ppbwHpTdGEvJJPpikkpzTrifJp5IcSPJAkotHsR9JE+I9KL0xdMgnOR/4HWDwddxbgQvbr93Ap4fdj6QJ8h6U3hjF7JpPAh8Cbh8Y2wV8rppPCb87ydYk51XVkRHsT9IkeA9KLwx1Jp9kF/BkVd3/ok3bgEMD64fbsbV+xu4kK0lWVldXhylHkvQipzyTT3In8AtrbNoDfJjmUs2GVdUSsASwsLBQw/wsSdILnTLkq+rNa40n+SVgJ3B/EoDtwL1JLgGeBM4fePr2dkySNEEbvlxTVd+qqldW1XxVzdNckrm4qp4C9gLvamfZXAr82OvxkjR542pr8GXgSuAAcBR495j2I0k6iZGFfHs2f2y5gGtH9bMlSRvjHa+S1GOGvDYPW+dqE7LVsDYHW+dqk/JMXpuDrXO1SRny2
hxsnatNypDX5mDrXG1Shrw2B1vnapMy5LU52DpXm5Sza7R52DpXm5Bn8pLUY4a8JPWYIS9JPWbIS1KPGfKS1GOGvCT1mCEvST1myEvSqcxwm+qhQz7Je5N8O8lDST42MH59kgNJHk3ylmH3I0lTcaxN9cGDUHW8TfWMBP1QIZ/kTcAu4Feq6jXAJ9rxi4CrgNcAVwB/nuT0IWuVpMmb8TbVw57JXwN8pKqeAaiqp9vxXcAtVfVMVX2P5gO9LxlyX5vbDL9clGbajLepHjbkXw38RpJ9Sf4+yevb8W3AoYHnHW7HXiLJ7iQrSVZWV1eHLKenZvzlojTTZrxN9SlDPsmdSR5c42sXTYOzs4FLgf8E3Jok6ymgqpaqaqGqFubm5jb0h+i9GX+5KM20GW9TfcoulFX15hNtS3IN8IWqKuCeJM8D5wBPAucPPHV7O6aNmPGXi9JMO9a5dM+e5t/cjh1NwM9IR9NhL9d8EXgTQJJXA2cCPwT2AlclOSvJTuBC4J4h97V5zfjLRWnmLS7C44/D8883jzMS8DB8yP8F8KokDwK3AFdX4yHgVuBh4CvAtVX1L0Pua/Oa8ZeLkqZnqA8NqapngX9/gm03AKbQKMz4y0VJ0+MnQ80KP9VI0gbY1kCSesyQl6QeM+QlqccMeUnqMUNeknrMkJekHjPkJanHDHlJ6jFDXpJ6zJCXpB4z5CWpxwx5SeoxQ16SesyQl6QeM+SlzWR5Gebn4bTTmkc/DL737CcvbRbLy7B79/EPhT94sFkHP6ugx4Y6k0/y2iR3J7kvyUqSS9rxJPlUkgNJHkhy8WjKlbRhe/YcD/hjjh5txtVbw16u+RjwX6vqtcB/adcB3krz4d0XAruBTw+5H0nDeuKJ9Y2rF4YN+QJe1i6/HPh+u7wL+Fz7od53A1uTnDfkviQNY8eO9Y2rF4YN+fcDH09yCPgEcH07vg04NPC8w+3YSyTZ3V7qWVldXR2yHEkndMMNsGXLC8e2bGnG1VunDPkkdyZ5cI2vXcA1wAeq6nzgA8BN6y2gqpaqaqGqFubm5tb/J5D0k1lchKUluOACSJrHpSXfdO25VNXGvzn5MbC1qipJgB9X1cuS/Hfg61X1v9rnPQq8saqOnOznLSws1MrKyobrkaTNKMn+qlpYa9uwl2u+D/xWu3wZ8Fi7vBd4VzvL5lKa8D9pwEuSRm/YefJ/CNyY5Azg/9HMpAH4MnAlcAA4Crx7yP1IkjZgqJCvqn8AfnWN8QKuHeZnS5KGZ1sDSeoxQ16SesyQl6QeM+Q1G+yeKG2IXSjVfXZPlDbMM3l1n90TpQ0z5NV9dk+UNsyQV/fZPVHaMENe3Wf3RGnDDHl1n90TpQ1zdo1mw+KioS5tgGfyktRjhrwk9ZghL0k9ZshLUo8Z8pLUY0N9xuuoJVkFDk6xhHOAH05x/6fS9fqg+zVa3/C6XuNmrO+Cqppba0OnQn7akqyc6MNwu6Dr9UH3a7S+4XW9Rut7IS/XSFKPGfKS1GOG/AstTbuAU+h6fdD9Gq1veF2v0foGeE1eknrMM3lJ6jFDXpJ6zJAHkrw2yd1J7kuykuSSdjxJPpXkQJIHklw8xRrfm+TbSR5K8rGB8evb+h5N8pZp1dfW8sEkleScdr1Lx+/j7fF7IMlfJ9k6sK0TxzDJFW0NB5JcN606Buo5P8nXkjzc/r17Xzt+dpI7kjzWPr5iynWenuSbSb7Uru9Msq89jn+V5Mwp17c1yW3t379HkvzaRI9hVW36L+CrwFvb5SuBrw8s/w0Q4FJg35TqexNwJ3BWu/7K9vEi4H7gLGAn8I/A6VOq8Xzgb2luZjunS8evreV3gDPa5Y8CH+3SMQROb/f9KuDMtqaLpnW82prOAy5ul38O+E57vD4GXNeOX3fsWE6xzj8G/ifwpXb9VuCqdvkzwDVTru9m4D+2y2cCWyd5DD2TbxTwsnb55cD32+VdwOeqcTewNcl5U6jvGuAjVfUMQFU9PVDfLVX1T
FV9DzgAXDKF+gA+CXyI5lge05XjR1V9taqea1fvBrYP1NiFY3gJcKCqvltVzwK3tLVNTVUdqap72+V/Bh4BtrV13dw+7WbgHdOpEJJsB34X+Gy7HuAy4Lb2KdOu7+XAbwI3AVTVs1X1IyZ4DA35xvuBjyc5BHwCuL4d3wYcGnje4XZs0l4N/Eb7EvTvk7y+He9EfUl2AU9W1f0v2tSJ+tbwBzSvMKA7NXaljjUlmQdeB+wDzq2qI+2mp4Bzp1QWwJ/SnFw8367/PPCjgf/Qp30cdwKrwF+2l5Q+m+RnmOAx3DSfDJXkTuAX1ti0B7gc+EBVfT7JO2n+131zh+o7Azib5pLH64Fbk7xqguWdqr4P01wOmaqT1VhVt7fP2QM8ByxPsrZZluRngc8D76+qf2pOlhtVVUmmMg87yduAp6tqf5I3TqOGn8AZwMXAe6tqX5IbaS7P/KtxH8NNE/JVdcLQTvI54H3t6v+mfekHPElzrfmY7e3YpOu7BvhCNRfw7knyPE2To6nXl+SXaM5W7m//8W8H7m3fvJ5YfSer8Zgkvw+8Dbi8PZYw4RpPoit1vECSn6IJ+OWq+kI7/IMk51XVkfby29Mn/glj9Qbg7UmuBH6a5pLrjTSXBc9oz+anfRwPA4eral+7fhtNyE/sGHq5pvF94Lfa5cuAx9rlvcC72lkilwI/HniJNUlfpHnzlSSvpnnz5odtfVclOSvJTuBC4J5JFlZV36qqV1bVfFXN0/ylvriqnqI7x48kV9C8rH97VR0d2DT1Y9j6BnBhOzPkTOCqtrapaa9v3wQ8UlV/MrBpL3B1u3w1cPukawOoquuranv79+4q4O+qahH4GvB7064PoP13cCjJL7ZDlwMPM8ljOM13nbvyBfw6sJ9mRsM+4Ffb8QB/RjPr4VvAwpTqOxP4H8CDwL3AZQPb9rT1PUo7Q2jKx/Jxjs+u6cTxa2s5QHPN+7726zNdO4Y0s5G+09aypwO/y1+neSP9gYHjdiXNde+7aE6G7gTO7kCtb+T47JpX0fxHfYDmlflZU67ttcBKexy/CLxiksfQtgaS1GNerpGkHjPkJanHDHlJ6jFDXpJ6zJCXpB4z5CWpxwx5Seqx/w/1zAjqr4SiiAAAAABJRU5ErkJggg==\n" + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "# PCA/tsne for high dimensional data (60k dims) does not really work well...\n", + "pca = PCA(2)\n", + "tsne = TSNE(n_components=2)\n", + "tsne_fit_1 = tsne.fit_transform(perturb_vector_1)\n", + "tsne_fit_0 = tsne.fit_transform(perturb_vector_0)\n", + "pca_fit_1 = pca.fit_transform(perturb_vector_1)\n", + "pca_fit_0 = pca.fit_transform(perturb_vector_0)\n", + "label_1 = kmeans_1.predict(perturb_vector_1)\n", + "label_0 = kmeans_0.predict(perturb_vector_0)\n", + "colors = ['r','g','b','c','y']\n", + "for i in np.arange(num_clusters):\n", + " plt.scatter(pca_fit_1[label_1==i][:,0], pca_fit_1[label_1==i][:,1], color=colors[i])\n", + "plt.show()\n", + "for i in np.arange(num_clusters):\n", + " 
plt.scatter(pca_fit_0[label_0==i][:,0], pca_fit_0[label_0==i][:,1], color=colors[i])\n", + "plt.show()\n", + "for i in np.arange(num_clusters):\n", + " plt.scatter(tsne_fit_1[label_1==i][:,0], tsne_fit_1[label_1==i][:,1], color=colors[i])\n", + "plt.show()\n", + "for i in np.arange(num_clusters):\n", + " plt.scatter(tsne_fit_0[label_0==i][:,0], tsne_fit_0[label_0==i][:,1], color=colors[i])\n", + "plt.show()" + ], + "metadata": { + "collapsed": false, + "pycharm": { + "name": "#%%\n" + } + } + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[ 1981 3535 5008 8364 11864 13916 15069 15253 21367 24170 28987 33326\n", + " 34237 37660 39210 40684 43098 45503]\n" + ] + } + ], + "source": [ + "print(genes_pos_0[0][0])" + ] + }, + { + "cell_type": "code", + "execution_count": 129, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "CF=0, Number of overlapping positive genes for cluster 0: 0\n", + "CF=0, Number of overlapping negative genes for cluster 0: 0\n", + "CF=0, Number of overlapping positive genes for cluster 1: 0\n", + "CF=0, Number of overlapping negative genes for cluster 1: 0\n", + "CF=0, Number of overlapping positive genes for cluster 2: 0\n", + "CF=0, Number of overlapping negative genes for cluster 2: 0\n", + "CF=0, Number of overlapping positive genes for cluster 3: 0\n", + "CF=0, Number of overlapping negative genes for cluster 3: 0\n", + "CF=0, Number of overlapping positive genes for cluster 4: 0\n", + "CF=0, Number of overlapping negative genes for cluster 4: 0\n", + "CF=1, Number of overlapping positive genes for cluster 0: 2\n", + "CF=1, Number of overlapping negative genes for cluster 0: 0\n", + "CF=1, Number of overlapping positive genes for cluster 1: 0\n", + "CF=1, Number of overlapping negative genes for cluster 1: 0\n", + "CF=1, Number of overlapping positive genes for cluster 2: 0\n", + "CF=1, 
Number of overlapping negative genes for cluster 2: 0\n", + "CF=1, Number of overlapping positive genes for cluster 3: 0\n", + "CF=1, Number of overlapping negative genes for cluster 3: 0\n", + "CF=1, Number of overlapping positive genes for cluster 4: 0\n", + "CF=1, Number of overlapping negative genes for cluster 4: 0\n" + ] + } + ], + "source": [ + "# within each cluster determine the overlapping gene indices that have been perturbed\n", + "for i in np.arange(num_clusters):\n", + " in_cluster_0 = np.argwhere(kmeans_0.labels_==i).squeeze()\n", + " genes_pos_0_sets = []\n", + " genes_neg_0_sets = []\n", + " for j in in_cluster_0:\n", + " genes_pos_0_sets.append(set(genes_pos_0[j][0]))\n", + " genes_neg_0_sets.append(set(genes_neg_0[j][0]))\n", + " overlap_pos_0 = list(genes_pos_0_sets[0].intersection(*genes_pos_0_sets))\n", + " overlap_neg_0 = list(genes_neg_0_sets[0].intersection(*genes_neg_0_sets))\n", + " print(\"CF=0, Number of overlapping positive genes for cluster {}: {}\".format(i,len(overlap_pos_0)))\n", + " print(\"CF=0, Number of overlapping negative genes for cluster {}: {}\".format(i,len(overlap_neg_0)))\n", + "\n", + "for i in np.arange(num_clusters):\n", + " in_cluster_1 = np.argwhere(kmeans_1.labels_==i).squeeze()\n", + " genes_pos_1_sets = []\n", + " genes_neg_1_sets = []\n", + " for j in in_cluster_1:\n", + " genes_pos_1_sets.append(set(genes_pos_1[j][0]))\n", + " genes_neg_1_sets.append(set(genes_neg_1[j][0]))\n", + " overlap_pos_1 = list(genes_pos_1_sets[0].intersection(*genes_pos_1_sets))\n", + " overlap_neg_1 = list(genes_neg_1_sets[0].intersection(*genes_neg_1_sets))\n", + " print(\"CF=1, Number of overlapping positive genes for cluster {}: {}\".format(i,len(overlap_pos_1)))\n", + " print(\"CF=1, Number of overlapping negative genes for cluster {}: {}\".format(i,len(overlap_neg_1)))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Not many overlapping genes which means that 
clustering doesnt give <> correlation to \n", + "# highly perturbed genes, but there may be some statistical methods to find which gene perturbations\n", + "# are <> in each cluster" + ] + }, + { + "cell_type": "code", + "execution_count": 126, + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAYsAAAEGCAYAAACUzrmNAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAY0UlEQVR4nO3dfbBkdX3n8fcHcDGLIjNhiiI8jVKERF1FnAA+RK5PgMQEK6VZXR8gxS5xxUSS3drgPsg1uls+rG6MG62gUmji464aWUWRQgbiMzMECKAGCoeC2UFGZ0SNWyQO3/2jz6V7xntv35m5554+t9+vqq4+/bunu3/3V939Oed3fud3UlVIkrSYA7qugCRp8hkWkqSxDAtJ0liGhSRpLMNCkjTWQV1XoA2HH354rV+/vutqSFKvbN68+ftVtW6+v63KsFi/fj2bNm3quhqS1CtJ7l7ob3ZDSZLGMiwkSWMZFpKksQwLSdJYhoUkaSzDQpI0lmEhSRrLsGjR7Oxs11WQpGVhWLRkdnaWN77xjSQxNCT1XlbjxY82bNhQk3AGdxJWY/tKWp2SbK6qDfP9zT2LFl1yySVdV0GSloVh0SK7nyStFoaFJGksw0KSNJZhIUkay7CQJI1lWEiSxjIsJEljGRaSpLEMC0nSWIaFJGksw0KSNJZhIUkay7DQinCeLKnfDAu1zmt7SP3n9Sy0Iry2hzT5vJ6FOue1PaR+Myy0Iux+kvrNsJAkjWVYSJLGMiwkSWMZFpKksQwLSdJYhoUkaSzDQpI0VmthkeSYJNcmuT3JbUle15SvTXJ1kjua+zVNeZL8WZI7k9yS5OSR1zq3Wf+OJOe2VWdJ0vza3LP4GfDvqurxwGnAhUkeD1wMXFNVJwDXNI8BXgCc0NwuAN4Lg3ABLgFOBU4BLpkLGEnSymgtLKpqW1Xd2Cz/GPgWcBRwDvDBZrUPAi9qls8BPlQDXwcOS3IkcCZwdVXtqKqdwNXAWW3VW5L081bkmEWS9cBTgG8AR1TVtuZP9wFHNMtHAfeMPO3epmyh8j3f44Ikm5Js2r59+7LWX5KmXethkeRRwCeBi6rqR6N/q8E0pMsyFWlVXVpVG6pqw7p165bjJSVJjVbDIskjGATFh6vqU03x95ruJZr7+5vyrcAxI08/uilbqFyStELaHA0V4APAt6rqnSN/ugKYG9F0LvCZkfJXNaOiTgMeaLqrrgLOSLKmObB9RlMm9ZIz8KqP2tyzeAbwSuA5SW5qbmcDbwGen+QO4HnNY4ArgbuAO4H3Aa8BqKodwJuAG5rbnzRlUu941UD1lVfKk1aYVw3UpPJKedIE8aqB6iPDQlphdj+pjwwLSdJYhoUkaSzDQpI0lmEhSRrLsJAkjWVYSJLGMiwkSWMZFpKksQwLSdJYhoUkaSzDQpI0lmEhSRrLsGiRE8ZJWi0Mi5Z4kRtJq4kXP2qRF7mR1Cde/KgjXuRG0mphWLTI7idJq4VhIUkay7CQJI1lWEjqjF21/WFYSOqEw8v7xaGzkjrj8PLJ4tBZSRPJ4eX9YVi0yF1raXF+R/rDsGiJ/bGSVhOPWbTI/lhJfeIxi47YHytptTAsWmT3k6TVwrCQJI1lWLTIPQtJq4Vh0RJHQ0laTVoLiySXJbk/ya0jZbNJtia5qbmdPfK31ye5M8l3kp
w5Un5WU3Znkovbqu9ymwuIqjIsJPVem3sWlwNnzVP+P6rqpOZ2JUCSxwMvBZ7QPOc9SQ5MciDw58ALgMcDL2vW7QVHQw0ZmFK/tRYWVXU9sGOJq58DfKyqHqyq7wJ3Aqc0tzur6q6q+kfgY826veAP5IBdclL/dXHM4rVJbmm6qdY0ZUcB94ysc29TtlD5z0lyQZJNSTZt3769jXprH9klJ/XfSofFe4HjgZOAbcA7luuFq+rSqtpQVRvWrVu3XC+7X/xhHDr99NO7roKk/bCiYVFV36uqXVX1EPA+Bt1MAFuBY0ZWPbopW6h84tn1MjQ7O8t1111nW0g91urcUEnWA5+tqic2j4+sqm3N8h8Cp1bVS5M8AfgIg/D4JeAa4AQgwN8Dz2UQEjcA/6qqblvsfZ0bavLYFtLkW2xuqINafNOPAjPA4UnuBS4BZpKcBBSwBfg9gKq6LckngNuBnwEXVtWu5nVeC1wFHAhcNi4oJoldL0OODJP6beyeRZJDgP9XVQ8l+WXgV4DPV9U/rUQF98Uk7FnMdUPB4IfS7hdJk26xPYulhMVm4NeBNcBXGHQF/WNVvXy5K7pcJiEswK4XSf2yv1OUp6p+Cvw28J6qegmDk+c0hl0vklaLJYVFkqcBLwc+15Qd2F6VVg+7niStFksJi4uA1wOfbg5EPw64tt1qSZImydiwqKrrquq3gHc3j++qqj9ovWZaVdzLGpqZmem6CtJeGxsWSZ6W5Hbg283jJyd5T+s106rhCYpDMzMzD5+gaGioT5ZynsWfAmcCVwBU1c1JntVqraRVauPGjY6SUy8tabqPqrpnj6JdLdRFq5QTCe7OkzXVR0vZs7gnydOBSvII4HXAt9qtllYbhxEPbdy4sesqSHttKXsWrwYuZDA1+FYGM8Ze2GaltPq4RyH129g9i6r6PoNzLCRJU2opo6E+mOSwkcdrklzWbrUkSZNkKd1QT6qqH849qKqdwFPaq5IkadIsJSwOGLn8KUnW0uLU5tJq5/Eb9dFSwuIdwNeSvCnJm4GvAm9rt1pabfyBHPAERfXVUqb7+BCDGWe/B9wH/HZV/WXbFdPq4Q+k1H8LXs8iyaFV9aOm2+nnVNWOVmu2HyblehYa8qzlIdtiaHZ21g2ICbKv17P4SHO/Gdg0cpt7LC2ZJ+UN2RYD7nH2y4JhUVUvbO4fW1WPG7k9tqoet3JV7C8nihvyx2DItlAfLdYNdfJiT6yqG1up0TKYhG6oudlFYTAXkFM8SD/PLrnJslg31GJDYN/R3D8S2ADcDAR4EoNuqKctZyVXG2cXlcazS64/FuuGenZVPRvYBpxcVRuq6qkMTsjbulIV7DNnF5UWZ5dcfyzlPIsTq+rv5h5U1a3Ar7ZXpdXDridJq8VSwuKWJO9PMtPc3gfc0nbFJK1+7ln0x1LC4neB2xhcx+J1wO1NmSTtM4fO9suCo6H6bBJGQ0kaz0Egk2VfR0PNPfkZwCxw3Oj6nmshaX85Gqo/ljJ77AeAP2Rw5rbX3pa0bOx+6o+lHLN4oKo+X1X3V9UP5m6t10yrij8KUr8tJSyuTfL2JE9LcvLcrfWaadXwQKYW4uehP8Ye4E5y7TzFVVXPaadK+88D3JPHA5na09xGBAyOXRgc3VvsALejobQinIpa83EjYrLs02ioJK+oqr9K8kfz/b2q3rlcFdTqZ1BoPo6G6o/Fjlkc0tw/eoHbopJcluT+JLeOlK1NcnWSO5r7NU15kvxZkjuT3DJ6TCTJuc36dyQ5dx/+R0kTyo2I/mitGyrJs4CfAB+qqic2ZW8DdlTVW5JcDKypqj9Ocjbw+8DZwKnAu6rq1OYqfZsYzHpbDIbvPrWqdi723nZDSdLe29cr5e2Xqroe2PPSq+cAH2yWPwi8aKT8QzXwdeCwJEcCZwJXV9WOJiCuBs5qq86SpPm1FhYLOKKqtjXL9wFHNMtHAfeMrHdvU7ZQ+c9JckGSTUk2bd++fXlrLS0ju17URysdFg
+rQf/XsvWBVdWlzTU3Nqxbt265XlZaVp5zor4aGxZJ/vPI8sH7+X7fa7qXaO7vb8q3AseMrHd0U7ZQudRLcwFRVYaFemXBsEjyx0meBrx4pPhr+/l+VwBzI5rOBT4zUv6qZlTUaQymGNkGXAWckWRNM3LqjKZM6i2voKg+WmzP4tvAS4DHJfmb5qJHv5jkxKW8cJKPMgiXE5Pcm+R84C3A85PcATyveQxwJXAXcCfwPuA1AFW1A3gTcENz+5OmrBfcctSeZmdnue666+yGUu8sOHQ2yenAN4CvAr/G4FKqnwO+xOBSq09fqUrurUkYOutUBlqIZy1rUu3r9SzOBN4AHA+8k8GlVP+hqrxK3hLMhYU/CtqTZy2rjxbshqqq/1hVzwW2AH8JHAisS/LlJP9nherXa/ZND83MzHRdhYnhXqb6aClDZ6+qqk1VdSlwb1U9E6/BPZZ900MzMzMPt4WhoVHT/t3ok72a7iPJk6vq5hbrsywm4ZgF2Dc9yrbQnjyuN3mWbbqPPgTFJLFvesguuSF/FAc856RfOjuDexr4BRjauHFj11WYCJ7BvTs3qPrDix9JK8wuOU2qTmadlTQ/t6bVR4aFVoRdLkO2hfrIsFDr7KfXQvw89IdhodY56kXzcSOiXwwLrQj76aV+czSUpM44MmyyOBpK0kRyj7M/DAtJnfFYRX8YFtIK8wdSfWRYSCvIEUDqK8OiRf4YaE8OI1ZfGRYtcQtSC/GgrvrIobMtcligpD5x6GxH3IKUtFoYFi2y+0nSamFYaEUYnFK/GRYt8gdywIP9u7MN1EeGRUv8gRxyuOiQn4vd2Qb94WioFjkaamh2dtYfhoafi4G54ITBYBA/H91bbDTUQStdmWniaKghfwiG/FwMzIWFwdkPdkO1yB/IIdtiyLYYMjj7w7BQ6+yn10L8PPSHYaHWeYB7d7aB+siw0Iqwu2HAvSz1laOhWuQIIM3H0VCaVM4N1QG3IDWfuc+Cnwv1TSd7Fkm2AD8GdgE/q6oNSdYCHwfWA1uA36mqnUkCvAs4G/gpcF5V3bjY60/KnoVbkJqPnwtNqknds3h2VZ00UrGLgWuq6gTgmuYxwAuAE5rbBcB7V7ym+8AtSC3E4zdDfjf6o8s9iw1V9f2Rsu8AM1W1LcmRwMaqOjHJXzTLH91zvYVe3z0LafJ5BvfkmcQ9iwK+mGRzkguasiNGAuA+4Ihm+SjgnpHn3tuU7SbJBUk2Jdm0ffv2tuq9V04//fSuqyBJy6KrsHhmVZ3MoIvpwiTPGv1jDTbH92qTvKouraoNVbVh3bp1y1jVfTM7O8t1111nN1RjZmam6ypownj+Tb90EhZVtbW5vx/4NHAK8L2m+4nm/v5m9a3AMSNPP7opm2h+EYZmZmYeDk5DQ3M8rtcvKx4WSQ5J8ui5ZeAM4FbgCuDcZrVzgc80y1cAr8rAacADix2vmCQeyBzYuHEjMAjOuWXJDap+6WLW2SOATw9GxHIQ8JGq+kKSG4BPJDkfuBv4nWb9KxkMm72TwdDZ3135Ku8bvwBDxx13XNdV0ARyg6o/VnzPoqruqqonN7cnVNV/bcp/UFXPraoTqup5VbWjKa+qurCqjq+qf1FV3Q9zWiLDYmB2dpa7777b7oaGbTBkW/SHZ3C3xDO4h+xuGPJzob5ybqgWeZ7FkPNkDfm50KSaxPMspoL9sdqTI4B2Zxv0h2HRIr8IA3a9aD5+LvrFbiitCLtehmyLIdtistgNpc7ZJTdkWww5JU5/GBYt8mzlIbsZtCenxOkXw6IlTnGxO38MBuynH3JIdb8YFi1xioshfyCH/IHcnV1y/WFYtMQhklqIP5BDfjf6w9FQLXKkx5Btofl4suZkcTRURxzpMeTW9JDHsAbsnuwXw6IljvTQfBz4MOTxm36xG6pFdr0MeK3l3fm5GJqZmZn6ASCTxG6oDqxfvx4Y/DDMLU8rtyB3Z/fkgHvf/dLFxY+mwpYtW9yCHOHFj4bcklYfuW
fREofODs3MzDx88aNp76cHh4uqnzxm0SL3LIZsiwGP3+zOz8Vk8ZhFB+a2oN2a9vjNqNEuKLujHFLdJ4ZFS0YDYtrD4rzzzpt3eRo5DYz6yrBoyZvf/OZ5l6eRW9ND7mUNeVJevzgaqiUHHXQQu3btenhZAkfJjXIjol/cs2jJaaedNu/yNLrpppvmXZ5GHssasqu2XwyLllx//fXzLk+jww47bN7labRly5Z5l6fR5ZdfPu+yJpNh0ZJDDz103uVpdN999827PI1Gj1NM+zEL26JfDIuWXHTRRfMuT6MHH3xw3uVp5B6n+sqwaImjoYYe85jHzLs8jY499th5l6eRxyz6xTO4W5Jkt8ersZ2X6oADDnj4/0/CQw891HGNumNbDPkdmTyewd2Bgw8+eN7laXTAAQfMuzyNRn8Qp/3H0T3Ofpnub26L7KcfmjvfZM9lTbcHHnhg3mVNJs8W2w977kbvy7rTvnUpqR8Mi/0w7od+ms7UNTil1c2w0LIwOAf2JjQXW38a2kr90ptjFknOSvKdJHcmubjt91u7di1J9uvW1Hu/bmvXrm37X9Uyqqqxt6Wstxos5fuxlPU1GXqxZ5HkQODPgecD9wI3JLmiqm5v6z137tw5EV/aSfiyrF27lp07d+736+zv/7JmzRp27Nix3/XYH7bF0HK1xTjj2moS2mIa9CIsgFOAO6vqLoAkHwPOAVoLi7rkUJjtfjhfXdL9VCE7/mAX0H09oPuRVLbFkG0xXXpxUl6SFwNnVdW/bh6/Eji1ql47ss4FwAUAxx577FPvvvvulajXfr9GH9p/KaalLSZhTw8mY2vathgxARuWD5vd92HIi52U15c9i7Gq6lLgUhicwb1C77kSb9ML09IW0/J/LoVtMWI/fqD7oi8HuLcCx4w8PropkyStgL6ExQ3ACUkem+SfAS8Frui4TpI0NXrRDVVVP0vyWuAq4EDgsqq6reNqSdLU6EVYAFTVlcCVXddDkqZRX7qhJEkdMiwkSWMZFpKksQwLSdJYvTiDe28l2Q60fwr3eIcD3++6EhPCthiyLYZsi6FJaIvjqmrdfH9YlWExKZJsWujU+WljWwzZFkO2xdCkt4XdUJKksQwLSdJYhkW7Lu26AhPEthiyLYZsi6GJbguPWUiSxnLPQpI0lmEhSRrLsGhBksuS3J/k1q7r0qUkxyS5NsntSW5L8rqu69SVJI9M8s0kNzdt8cau69S1JAcm+dskn+26Ll1KsiXJ3yW5KcmmruuzEI9ZtCDJs4CfAB+qqid2XZ+uJDkSOLKqbkzyaGAz8KKqau3a6ZMqg2uQHlJVP0nyCODLwOuq6usdV60zSf4I2AAcWlUv7Lo+XUmyBdhQVV2fkLco9yxaUFXXAx1fFLh7VbWtqm5sln8MfAs4qttadaMGftI8fERzm9ottSRHA78BvL/rumhpDAutiCTrgacA3+i2Jt1pul1uAu4Hrq6qqW0L4E+B/wA81HVFJkABX0yyOckFXVdmIYaFWpfkUcAngYuq6kdd16crVbWrqk5icA35U5JMZRdlkhcC91fV5q7rMiGeWVUnAy8ALmy6sSeOYaFWNf3znwQ+XFWf6ro+k6CqfghcC5zVdV068gzgt5q++o8Bz0nyV91WqTtVtbW5vx/4NHBKtzWan2Gh1jQHdT8AfKuq3tl1fbqUZF2Sw5rlXwCeD3y721p1o6peX1VHV9V64KXAl6rqFR1XqxNJDmkGf5DkEOAMYCJHURoWLUjyUeBrwIlJ7k1yftd16sgzgFcy2HK8qbmd3XWlOnIkcG2SW4AbGByzmOohowLgCODLSW4Gvgl8rqq+0HGd5uXQWUnSWO5ZSJLGMiwkSWMZFpKksQwLSdJYhoUkaSzDQqtSksOSvGbk8Uwbs5smOS/J/9zL52xJcvg85bNJ/v0y1GlZXkcaZVhotToMeM3YtfaQ5MAW6iL1nmGh1eotwPHNiYBvb8oeleR/J/l2kg83Z5jPbem/NcmNwE
uSHJ/kC83Ebn+T5Fea9V6S5NbmmhTXj7zXLzXr35HkbXOFSV7WXKfg1iRvna+SSf5Tkr9P8mXgxHn+/pgkdyc5oHl8SJJ7kjwiyb9JckNTn08m+efzPH9jkg3N8uHNFBtzkxq+vXn+LUl+b++bWNPkoK4rILXkYuCJzcR9JJlhMOvtE4D/C3yFwRnmX27W/0EzmRtJrgFeXVV3JDkVeA/wHOANwJlVtXVu6o7GSc1rPwh8J8m7gV3AW4GnAjsZzCr6oqr667knJXkqg+kuTmLwXbyRwTU/HlZVDzQz1Z7OYD6pFwJXVdU/JflUVb2vea03A+cD715i+5wPPFBVv5bkYOArSb5YVd9d4vM1ZQwLTZNvVtW9AM0P8HqGYfHxpvxRwNOB/9XseAAc3Nx/Bbg8ySeA0UkRr6mqB5rn3w4cB/wisLGqtjflHwaeBfz1yPN+Hfh0Vf20WeeKBer9ceBfMgiLlzIIL4AnNiFxGPAo4KqlNgSDOYielOTFzePHACcAhoXmZVhomjw4sryL3T///9DcHwD8cG6PZFRVvbrZ0/gNYHOzZzDudZfDFcB/S7KWwZ7Kl5ryyxlcefDmJOcBM/M892cMu5sfOVIe4Peram8CRlPMYxZarX4MPHpvn9Rcb+O7SV4Cg5lzkzy5WT6+qr5RVW8AtgPHLPJS3wROb44THAi8DLhuj3WuB16U5BeamUd/c4E6/YTB5IPvAj5bVbuaPz0a2NZMA//yBeqxhUHAALx4pPwq4N82zyXJLzeznkrzMiy0KlXVDxj0w986coB7qV4OnN/MBHobcE5T/va5A9bAV4GbF3n/bQyOm1zbrLe5qj6zxzo3Muhiuhn4PINAWMjHgVc093P+C4MrD36Fhac7/+8MQuFvgdHhuu8HbgdubP6fv8CeBi3CWWclSWO5ZyFJGsuwkCSNZVhIksYyLCRJYxkWkqSxDAtJ0liGhSRprP8PjYSVEziWFNoAAAAASUVORK5CYII=\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots()\n", + "pos = np.arange(len(num_pos)) + 1\n", + "bp = ax.boxplot(num_neg, sym='k+', positions=pos)\n", + "\n", + "ax.set_xlabel('threshold value')\n", + "ax.set_ylabel('# indices')\n", + "plt.setp(bp['whiskers'], color='k', linestyle='-')\n", + "plt.setp(bp['fliers'], markersize=3.0)\n", + "plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [], + "source": [ + "thresholds = pickle.load(open(\"threshold.complete.pkl\",'rb'))" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "683\n" + ] + } + ], + "source": [ + "print(thresholds['positive threshold indices'][0][0].shape[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "1120\n" + ] + } + ], + "source": [ + "diff = thresholds['perturbation vector']\n", + "print(len(diff))" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([0, 1, 1, ..., 0, 1, 0], dtype=int32)" + ] + }, + "execution_count": 16, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "kmeans.labels_ " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + 
"pygments_lexer": "ipython3", + "version": "3.7.1" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/Pilot1/NT3/nt3_cf/analyze.py b/Pilot1/NT3/nt3_cf/analyze.py new file mode 100644 index 00000000..6e62783b --- /dev/null +++ b/Pilot1/NT3/nt3_cf/analyze.py @@ -0,0 +1,33 @@ +# Script to analyze perturbation by cluster +# Plot the perturbations by cluster +# Plot the pertubation centroids + +import os +import pickle +import matplotlib.pyplot as plt +import numpy as np +directory = 'clusters_0911_0.5/' +orig_dataset = pickle.load(open("nt3.autosave.data.pkl", 'rb'))[0] +cf_dataset = pickle.load(open("threshold_0905.pkl", 'rb'))['perturbation vector'] +for filename in os.listdir(directory): + if filename.startswith("cf_class_0") or filename.startswith("cf_class_1") : + data = pickle.load(open(os.path.join(directory, filename), 'rb')) + x_range = np.arange(len(data['centroid perturb vector'])) + ind_in_cluster = data['sample indices in this cluster'][0:5] + fig,ax = plt.subplots(3, figsize=(20,15)) + fig.suptitle("Perturbation Vectors for counterfactual class 1, cluster 1", fontsize=25) + for i,ax_i in zip(ind_in_cluster,ax): + d = cf_dataset[i] + ax_i.plot(x_range, d, label='perturbation vector') + ax_i.plot(x_range ,data['centroid perturb vector'], label='centroid') + #ax_i.axhline(y=0.5*np.max(np.abs(d)), color='r', linestyle='-') + #ax_i.axhline(y=-0.5*np.max(np.abs(d)), color='r', linestyle='-') + ax_i.axvline(x=9603, color='r', linestyle='-', linewidth=5, alpha=0.3) + + ax_i.set_title("sample {}".format(i)) + ax_i.legend() + fig.supxlabel("Feature index", fontsize=18) + plt.savefig("centroids_{}.png".format(filename)) + + else: + continue diff --git a/Pilot1/NT3/nt3_cf/cf_nb.py b/Pilot1/NT3/nt3_cf/cf_nb.py new file mode 100644 index 00000000..d30b3613 --- /dev/null +++ b/Pilot1/NT3/nt3_cf/cf_nb.py @@ -0,0 +1,70 @@ +import tensorflow as tf +tf.get_logger().setLevel(40) # suppress deprecation messages 
+tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs +from tensorflow.keras.models import Model, load_model +import matplotlib.pyplot as plt +import numpy as np +import os +os.environ["CUDA_VISIBLE_DEVICES"]="1" +from time import time +from alibi.explainers import CounterFactual, CounterFactualProto +print('TF version: ', tf.__version__) +print('Eager execution enabled: ', tf.executing_eagerly()) # False +print(tf.test.is_gpu_available()) +import pickle +model_nt3 = tf.keras.models.load_model('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.model') +with open('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.data.pkl', 'rb') as pickle_file: + X_train,Y_train,X_test,Y_test = pickle.load(pickle_file) + +shape_cf = (1,) + X_train.shape[1:] +print(shape_cf) +target_proba = 0.9 +tol = 0.1 # want counterfactuals with p(class)>0.90 +target_class = 'other' # any class other than will do +max_iter = 1000 +lam_init = 1e-1 +max_lam_steps = 20 +learning_rate_init = 0.1 +feature_range = (0,1) +cf = CounterFactual(model_nt3, shape=shape_cf, target_proba=target_proba, tol=tol, + target_class=target_class, max_iter=max_iter, lam_init=lam_init, + max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init, + feature_range=feature_range) +shape = X_train[0].shape[0] +results=[] +X = np.concatenate([X_train,X_test]) + +for i in np.arange(902,903): + print(i) + x_sample=X[i:i+1] + print(x_sample.shape) + start = time() + try: + explanation = cf.explain(x_sample) + print('Counterfactual prediction: {}, {}'.format(explanation.cf['class'], explanation.cf['proba'])) + print("Actual prediction: {}".format(model_nt3.predict(x_sample))) + results.append([explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']]) + test = model_nt3.predict(explanation.cf['X']) + print(test, explanation.cf['proba'], explanation.cf['class']) + except: + print("Failed cf generation") + results.append([None, None, None]) + #if i %100==0: 
+pickle.dump(results, open("redo_cf_rest.pkl", "wb")) + # results = [] +#for i in range(len(results)): +# plt.figure(figsize=(20, 4)) +# sample = X_train[i].flatten() +# y = results[i][0].flatten() +# x = np.arange(y.shape[0]) +# plt.plot(x,y,alpha=0.5, label='counterfactual') +# plt.plot(x,sample,alpha=0.5, label='input') +# plt.plot(x,sample-y, label='diff') +# props = dict(boxstyle='round', facecolor='wheat', alpha=1) +# prediction = model_nt3.predict(X_test[i:i+1]) +# plt.text(0.05, 0.95, "original input: {} {} \n counterfactual: {} {}".format(np.argmax(prediction), +# prediction,results[i][1] ,results[i][2]), +# fontsize=16, +# verticalalignment='top', bbox=props) +# plt.legend() +# plt.savefig("fig_{}.png".format(i)) diff --git a/Pilot1/NT3/nt3_cf/cf_script.py b/Pilot1/NT3/nt3_cf/cf_script.py new file mode 100644 index 00000000..0b74a2d5 --- /dev/null +++ b/Pilot1/NT3/nt3_cf/cf_script.py @@ -0,0 +1,65 @@ +import tensorflow as tf +tf.get_logger().setLevel(40) # suppress deprecation messages +tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs +from tensorflow.keras.models import Model, load_model +import matplotlib.pyplot as plt +import numpy as np +import os +from time import time +from alibi.explainers import CounterFactual, CounterFactualProto +#print('TF version: ', tf.__version__) +#print('Eager execution enabled: ', tf.executing_eagerly()) # False +import pickle + +model_nt3 = tf.keras.models.load_model('./nt3.autosave.model') +with open('./nt3.autosave.data.pkl', 'rb') as pickle_file: + X_train,Y_train,X_test,Y_test = pickle.load(pickle_file) + + +shape_cf = (1,) + X_train.shape[1:] +print(shape_cf) +target_proba = 0.9 +tol = 0.1 # want counterfactuals with p(class)>0.90 +target_class = 'other' # any class other than will do +max_iter = 1000 +lam_init = 1e-1 +max_lam_steps = 20 +learning_rate_init = 0.1 +feature_range = (0,1) + + +cf = CounterFactual(model_nt3, shape=shape_cf, 
target_proba=target_proba, tol=tol, + target_class=target_class, max_iter=max_iter, lam_init=lam_init, + max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init, + feature_range=feature_range) + +shape = X_train[0].shape[0] +results=[] + +X = np.concatenate([X_train,X_test]) +#X=X_test +print(X.shape[0], "-x shape 0") +for i in np.arange(0,X.shape[0]): +# for i in range(4): + + x_sample=X[i:i+1] + print(x_sample.shape) + start = time() + explanation = cf.explain(x_sample) + iter_cf = 0 + n = len(explanation['all'][iter_cf]) + + print('Counterfactual prediction: {}, {}'.format(explanation.cf['class'], explanation.cf['proba'])) + print("Actual prediction: {}".format(model_nt3.predict(x_sample))) + + results.append([i, explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']]) + + if ((i+1)%2 == 0): + print("saving i=", i) + filename = "save.p" + str(i) + pickle.dump(results, open(filename, "wb")) + results=[] + + if i==1: + print("before exit", i) + break diff --git a/Pilot1/NT3/nt3_cf/environment.yml b/Pilot1/NT3/nt3_cf/environment.yml new file mode 100644 index 00000000..669ca25e --- /dev/null +++ b/Pilot1/NT3/nt3_cf/environment.yml @@ -0,0 +1,264 @@ +name: xai-geom-tf +channels: + - anaconda + - conda-forge + - defaults +dependencies: + - _libgcc_mutex=0.1=conda_forge + - _openmp_mutex=4.5=1_gnu + - _tflow_select=2.1.0=gpu + - absl-py=0.11.0=py38h578d9bd_0 + - aiohttp=3.7.3=py38h497a2fe_1 + - anyio=2.0.2=py38h578d9bd_4 + - argon2-cffi=20.1.0=py38h497a2fe_2 + - astor=0.8.1=pyh9f0ad1d_0 + - astunparse=1.6.3=pyhd8ed1ab_0 + - async-timeout=3.0.1=py_1000 + - async_generator=1.10=py_0 + - attrs=20.3.0=pyhd3deb0d_0 + - babel=2.9.0=pyhd3deb0d_0 + - backcall=0.2.0=pyh9f0ad1d_0 + - backports=1.0=py_2 + - backports.functools_lru_cache=1.6.1=py_0 + - bleach=3.3.0=pyh44b312d_0 + - blinker=1.4=py_1 + - brotlipy=0.7.0=py38h497a2fe_1001 + - c-ares=1.17.1=h36c2ea0_0 + - ca-certificates=2020.10.14=0 + - cachetools=4.2.1=pyhd8ed1ab_0 + - 
certifi=2020.6.20=py38_0 + - cffi=1.14.4=py38ha65f79e_1 + - chardet=3.0.4=py38h924ce5b_1008 + - click=7.1.2=pyh9f0ad1d_0 + - cryptography=3.3.1=py38h2b97feb_1 + - cudatoolkit=10.1.243=h036e899_7 + - cudnn=7.6.5.32=hc0a50b0_1 + - cupti=10.1.168=0 + - cycler=0.10.0=py_2 + - dbus=1.13.6=hfdff14a_1 + - decorator=4.4.2=py_0 + - defusedxml=0.6.0=py_0 + - entrypoints=0.3=pyhd8ed1ab_1003 + - expat=2.2.10=h9c3ff4c_0 + - fontconfig=2.13.1=hba837de_1004 + - freetype=2.10.4=h0708190_1 + - gast=0.3.3=py_0 + - gettext=0.19.8.1=h0b5b191_1005 + - glib=2.66.4=hc4f0c31_2 + - glib-tools=2.66.4=hc4f0c31_2 + - google-auth=1.24.0=pyhd3deb0d_0 + - google-auth-oauthlib=0.4.1=py_2 + - google-pasta=0.2.0=pyh8c360ce_0 + - grpcio=1.35.0=py38hdd6454d_0 + - gst-plugins-base=1.14.5=h0935bb2_2 + - gstreamer=1.18.3=h3560a44_0 + - h5py=2.10.0=nompi_py38h7442b35_105 + - hdf5=1.10.6=nompi_h6a2412b_1114 + - icu=68.1=h58526e2_0 + - idna=2.10=pyh9f0ad1d_0 + - importlib-metadata=3.4.0=py38h578d9bd_0 + - importlib_metadata=3.4.0=hd8ed1ab_0 + - intel-openmp=2020.2=254 + - ipykernel=5.3.4=py38h5ca1d4c_0 + - ipython=7.20.0=py38h81c977d_0 + - ipython_genutils=0.2.0=py_1 + - jedi=0.18.0=py38h578d9bd_2 + - jinja2=2.11.3=pyh44b312d_0 + - jpeg=9d=h36c2ea0_0 + - json5=0.9.5=pyh9f0ad1d_0 + - jsonschema=3.2.0=py_2 + - jupyter_client=6.1.11=pyhd8ed1ab_1 + - jupyter_core=4.7.1=py38h578d9bd_0 + - jupyter_server=1.2.3=py38h578d9bd_1 + - jupyterlab=3.0.6=pyhd8ed1ab_0 + - jupyterlab_pygments=0.1.2=pyh9f0ad1d_0 + - jupyterlab_server=2.1.3=pyhd8ed1ab_0 + - keras-preprocessing=1.1.2=pyhd8ed1ab_0 + - kiwisolver=1.3.1=py38h1fd1430_1 + - krb5=1.17.2=h926e7f8_0 + - lcms2=2.11=hcbb858e_1 + - ld_impl_linux-64=2.35.1=hea4e1c9_2 + - libblas=3.9.0=7_openblas + - libcblas=3.9.0=7_openblas + - libclang=11.0.1=default_ha53f305_1 + - libcurl=7.71.1=hcdd3856_8 + - libedit=3.1.20191231=he28a2e2_2 + - libev=4.33=h516909a_1 + - libevent=2.1.10=hcdb4288_3 + - libffi=3.3=h58526e2_2 + - libgcc-ng=9.3.0=h2828fa1_18 + - 
libgfortran-ng=9.3.0=hff62375_18 + - libgfortran5=9.3.0=hff62375_18 + - libglib=2.66.4=h748fe8e_2 + - libgomp=9.3.0=h2828fa1_18 + - libiconv=1.16=h516909a_0 + - liblapack=3.9.0=7_openblas + - libllvm11=11.0.1=hf817b99_0 + - libnghttp2=1.43.0=h812cca2_0 + - libopenblas=0.3.12=pthreads_h4812303_1 + - libpng=1.6.37=h21135ba_2 + - libpq=12.3=h255efa7_3 + - libprotobuf=3.14.0=h780b84a_0 + - libsodium=1.0.18=h36c2ea0_1 + - libssh2=1.9.0=hab1572f_5 + - libstdcxx-ng=9.3.0=h6de172a_18 + - libtiff=4.2.0=hdc55705_0 + - libuuid=2.32.1=h7f98852_1000 + - libwebp-base=1.2.0=h7f98852_0 + - libxcb=1.13=h7f98852_1003 + - libxkbcommon=1.0.3=he3ba5ed_0 + - libxml2=2.9.10=h72842e0_3 + - lz4-c=1.9.3=h9c3ff4c_0 + - markdown=3.3.3=pyh9f0ad1d_0 + - markupsafe=1.1.1=py38h497a2fe_3 + - matplotlib=3.3.4=py38h578d9bd_0 + - matplotlib-base=3.3.4=py38h0efea84_0 + - mistune=0.8.4=py38h497a2fe_1003 + - mkl=2020.2=256 + - multidict=5.1.0=py38h497a2fe_1 + - mysql-common=8.0.22=ha770c72_3 + - mysql-libs=8.0.22=h935591d_3 + - nbclassic=0.2.6=pyhd8ed1ab_0 + - nbclient=0.5.1=py_0 + - nbconvert=6.0.7=py38h578d9bd_3 + - nbformat=5.1.2=pyhd8ed1ab_1 + - ncurses=6.2=h58526e2_4 + - nest-asyncio=1.4.3=pyhd8ed1ab_0 + - ninja=1.10.2=h4bd325d_0 + - notebook=6.2.0=py38h578d9bd_0 + - nspr=4.29=h9c3ff4c_1 + - nss=3.61=hb5efdd6_0 + - numpy=1.20.0=py38h18fd61f_0 + - oauthlib=3.0.1=py_0 + - olefile=0.46=pyh9f0ad1d_1 + - openssl=1.1.1i=h7f98852_0 + - opt_einsum=3.3.0=py_0 + - packaging=20.8=pyhd3deb0d_0 + - pandoc=2.11.4=h7f98852_0 + - pandocfilters=1.4.2=py_1 + - parso=0.8.1=pyhd8ed1ab_0 + - pcre=8.44=he1b5a44_0 + - pexpect=4.8.0=pyh9f0ad1d_2 + - pickleshare=0.7.5=py_1003 + - pillow=8.1.0=py38h357d4e7_1 + - pip=21.0.1=pyhd8ed1ab_0 + - prometheus_client=0.9.0=pyhd3deb0d_0 + - prompt-toolkit=3.0.14=pyha770c72_0 + - protobuf=3.14.0=py38h709712a_1 + - pthread-stubs=0.4=h36c2ea0_1001 + - ptyprocess=0.7.0=pyhd3deb0d_0 + - pyasn1=0.4.8=py_0 + - pyasn1-modules=0.2.7=py_0 + - pycparser=2.20=pyh9f0ad1d_2 + - 
pygments=2.7.4=pyhd8ed1ab_0 + - pyjwt=2.0.1=pyhd8ed1ab_0 + - pyopenssl=20.0.1=pyhd8ed1ab_0 + - pyparsing=2.4.7=pyh9f0ad1d_0 + - pyqt=5.12.3=py38h578d9bd_7 + - pyqt-impl=5.12.3=py38h7400c14_7 + - pyqt5-sip=4.19.18=py38h709712a_7 + - pyqtchart=5.12=py38h7400c14_7 + - pyqtwebengine=5.12.1=py38h7400c14_7 + - pyrsistent=0.17.3=py38h497a2fe_2 + - pysocks=1.7.1=py38h578d9bd_3 + - python=3.8.6=hffdb5ce_5_cpython + - python-dateutil=2.8.1=py_0 + - python_abi=3.8=1_cp38 + - pytz=2021.1=pyhd8ed1ab_0 + - pyzmq=22.0.1=py38h3d7ac18_0 + - qt=5.12.9=h9d6b050_2 + - readline=8.0=he28a2e2_2 + - requests=2.25.1=pyhd3deb0d_0 + - requests-oauthlib=1.3.0=pyh9f0ad1d_0 + - rsa=4.7=pyhd3deb0d_0 + - send2trash=1.5.0=py_0 + - setuptools=49.6.0=py38h578d9bd_3 + - sip=4.19.13=py38he6710b0_0 + - six=1.15.0=pyh9f0ad1d_0 + - sniffio=1.2.0=py38h578d9bd_1 + - sqlite=3.34.0=h74cdb3f_0 + - tensorboard-plugin-wit=1.8.0=pyh44b312d_0 + - tensorflow=2.2.0=gpu_py38hb782248_0 + - tensorflow-base=2.2.0=gpu_py38h83e3d50_0 + - tensorflow-gpu=2.2.0=h0d30ee6_0 + - termcolor=1.1.0=py_2 + - terminado=0.9.2=py38h578d9bd_0 + - testpath=0.4.4=py_0 + - tk=8.6.10=h21135ba_1 + - tornado=6.1=py38h497a2fe_1 + - traitlets=5.0.5=py_0 + - typing-extensions=3.7.4.3=0 + - typing_extensions=3.7.4.3=py_0 + - urllib3=1.26.3=pyhd8ed1ab_0 + - wcwidth=0.2.5=pyh9f0ad1d_2 + - webencodings=0.5.1=py_1 + - werkzeug=1.0.1=pyh9f0ad1d_0 + - wheel=0.36.2=pyhd3deb0d_0 + - wrapt=1.12.1=py38h497a2fe_3 + - xorg-libxau=1.0.9=h7f98852_0 + - xorg-libxdmcp=1.1.3=h7f98852_0 + - xz=5.2.5=h516909a_1 + - yarl=1.6.3=py38h497a2fe_1 + - zeromq=4.3.3=h58526e2_3 + - zipp=3.4.0=py_0 + - zlib=1.2.11=h516909a_1010 + - zstd=1.4.8=ha95c52a_1 + - pip: + - alibi==0.5.5 + - altair==4.1.0 + - astropy==4.2 + - beautifulsoup4==4.9.3 + - blis==0.7.4 + - catalogue==2.0.1 + - click-plugins==1.1.1 + - cligj==0.7.1 + - cloudpickle==1.6.0 + - cymem==2.0.5 + - descartes==1.1.0 + - eli5==0.11.0 + - fiona==1.8.18 + - geopandas==0.8.2 + - imageio==2.9.0 + - joblib==1.0.0 + - 
keras==2.4.3 + - llvmlite==0.35.0 + - munch==2.5.0 + - murmurhash==1.0.5 + - networkx==2.5 + - numba==0.52.0 + - opt-einsum==3.3.0 + - pandas==1.2.1 + - pathy==0.3.4 + - patsy==0.5.1 + - preshed==3.0.5 + - pydantic==1.7.3 + - pyerfa==1.7.1.1 + - pyproj==3.0.0.post1 + - python-graphviz==0.16 + - pywavelets==1.1.1 + - pyyaml==5.4.1 + - scikit-image==0.18.1 + - scikit-learn==0.24.1 + - scipy==1.4.1 + - shap==0.38.1 + - shapely==1.7.1 + - slicer==0.0.7 + - smart-open==3.0.0 + - soupsieve==2.1 + - spacy==3.0.0 + - spacy-legacy==3.0.1 + - spacy-lookups-data==1.0.0 + - srsly==2.4.0 + - statsmodels==0.12.2 + - tabulate==0.8.7 + - tensorboard==2.2.2 + - tensorflow-estimator==2.2.0 + - thinc==8.0.1 + - threadpoolctl==2.1.0 + - tifffile==2021.1.14 + - toolz==0.11.1 + - tqdm==4.56.0 + - typer==0.3.2 + - wasabi==0.8.2 +prefix: /vol/ml/shahashka/anaconda3/envs/xai-geom-tf + diff --git a/Pilot1/NT3/nt3_cf/gen_clusters.py b/Pilot1/NT3/nt3_cf/gen_clusters.py new file mode 100644 index 00000000..cf1aa43a --- /dev/null +++ b/Pilot1/NT3/nt3_cf/gen_clusters.py @@ -0,0 +1,109 @@ + +import numpy as np +import pickle +import matplotlib.pyplot as plt +from sklearn.cluster import KMeans +from sklearn.decomposition import PCA +from sklearn.metrics import silhouette_score +import argparse + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-t", type=str, help="threshod input file") + parser.add_argument("-t_value", type=float, help="threshold value") + args = parser.parse_args() + return args + +if __name__ == '__main__': + + args = get_args() + + thresholds_9 = pickle.load(open(args.t, 'rb')) + + perturb_vector=thresholds_9['perturbation vector'] + cf_class = thresholds_9['counterfactual class'] + indices = thresholds_9['sample index'] + + # split by class + perturb_vector_0=[] + perturb_vector_1=[] + indices_0 = [] + indices_1 = [] + for i,j,k in zip(perturb_vector, cf_class, indices): + if j==0: + perturb_vector_0.append(i) + indices_0.append(k) + else: + 
perturb_vector_1.append(i) + indices_1.append(k) + + indices_0 = np.array(indices_0) + indices_1 = np.array(indices_1) + sil = [] + kmax = 10 + + # dissimilarity would not be defined for a single cluster, thus, minimum number of clusters should be 2 + for k in range(2, kmax + 1): + kmeans = KMeans(n_clusters=k).fit(perturb_vector_0) + labels = kmeans.labels_ + sil.append(silhouette_score(perturb_vector_0, labels, metric='euclidean')) + plt.plot(np.arange(2, kmax+1), sil) + plt.title("Silhouette scores to determine optimal k") + plt.xlabel("k") + plt.show() + k = np.argmax(sil) + 2 + print(k) + data_2D = PCA(2).fit_transform(perturb_vector_0) + kmeans_0 = KMeans(n_clusters=k).fit(perturb_vector_0) + labels_0 = kmeans_0.labels_ + colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] + for i in range(k): + plt.scatter(data_2D[:,0][labels_0==i], data_2D[:,1][labels_0==i], c=colors[i%len(colors)]) + plt.title("CF 0 KMeans clusters with 2D PCA") + plt.savefig("CF_0.png") + + sil=[] + for k in range(2, kmax + 1): + kmeans = KMeans(n_clusters=k).fit(perturb_vector_1) + labels = kmeans.labels_ + sil.append(silhouette_score(perturb_vector_1, labels, metric='euclidean')) + plt.plot(np.arange(2, kmax+1), sil) + plt.title("Silhouette scores to determine optimal k") + plt.xlabel("k") + plt.show() + k = np.argmax(sil) + 2 + print(k) + data_2D = PCA(2).fit_transform(perturb_vector_1) + kmeans_1 = KMeans(n_clusters=k).fit(perturb_vector_1) + labels_1 = kmeans_1.labels_ + colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] + for i in range(k): + plt.scatter(data_2D[:,0][labels_1==i], data_2D[:,1][labels_1==i], c=colors[i%len(colors)]) + plt.title("CF 1 KMeans clusters with 2D PCA") + plt.savefig("CF_1.png") + +for i in range(len(kmeans_0.cluster_centers_)): + diff_0=kmeans_0.cluster_centers_[i] + max_value = np.max(np.abs(diff_0)) + ind_pos = np.where(diff_0 > args.t_value*max_value) + ind_neg = np.where(diff_0 < -1*args.t_value*max_value) + output = {'centroid perturb vector': diff_0, + 'positive 
threshold indices':ind_pos, + 'negative threshold indices':ind_neg, + 'sample indices in this cluster':indices_0[labels_0==i]} + print(output) + pickle.dump(output, + open("cf_class_0_cluster{}.pkl".format(i), "wb")) + +for i in range(len(kmeans_1.cluster_centers_)): + diff_1=kmeans_1.cluster_centers_[i] + max_value = np.max(np.abs(diff_1)) + ind_pos = np.where(diff_1 > args.t_value*max_value) + ind_neg = np.where(diff_1 < -1*args.t_value*max_value) + output = {'centroid perturb vector': diff_1, + 'positive threshold indices':ind_pos, + 'negative threshold indices':ind_neg, + 'sample indices in this cluster':indices_1[labels_1==i]} + print(output) + pickle.dump(output, + open("cf_class_1_cluster{}.pkl".format(i), "wb")) diff --git a/Pilot1/NT3/nt3_cf/inject_noise.py b/Pilot1/NT3/nt3_cf/inject_noise.py new file mode 100644 index 00000000..c2481eff --- /dev/null +++ b/Pilot1/NT3/nt3_cf/inject_noise.py @@ -0,0 +1,81 @@ +import pickle +import numpy as np +import copy +import argparse + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-t", type=str, help="threshold pickle file") + parser.add_argument("-c1", type=list, help="cluster 1") + parser.add_argument("-c2", type=list, help="cluster 2") + parser.add_argument("-scale", type=float, help="scale factor for noise injection") + parser.add_argument("-r", type=bool, help="flag to add random noise") + args = parser.parse_args() + return args +def random_noise(c1,c2,scale,size, cluster_inds): + X_train, y_train, X_test, y_test = pickle.load(open("nt3.autosave.data.pkl", 'rb')) + X_data = np.concatenate([X_train, X_test]) + y_data = np.concatenate([y_train, y_test]) + genes = np.random.choice(np.arange(X_data.shape[0]), replace=False, size=size) + noise = np.random.normal(0,1,size) + X_data_noise = copy.deepcopy(X_data) + print(c1,c2) + for p in np.arange(0.1,1.0, 0.1): + for i in cluster_inds: + for j in range(size): + X_data_noise[i][genes[j]]+=noise[j] + pickle.dump([X_data_noise, y_data, [], 
cluster_inds], open("nt3.data.random.scale_{}_cluster_{}_{}.noise_{}.pkl".format(scale,c1,c2,round(p,1)), "wb")) + +def main(): + args = get_args() + # For 2 clusters (with sparse injection feature vector) add CF noise to x% of samples + X_train, y_train, X_test, y_test = pickle.load(open("nt3.autosave.data.pkl", 'rb')) + threshold_dataset = pickle.load(open(args.t, 'rb')) + perturb_dataset = threshold_dataset['perturbation vector'] + #failed index + perturb_dataset.insert(919, np.zeros(X_train.shape[1])) + perturb_dataset = np.array(perturb_dataset) + X_data = np.concatenate([X_train, X_test]) + y_data = np.concatenate([y_train, y_test]) + clusters = [(0,1),(1,1)] + cluster_files = [] + for c in clusters: + cluster_files.append(pickle.load(open("clusters_0911_0.5/cf_class_{}_cluster{}.pkl".format(c[0], c[1]), 'rb'))) + for i in range(len(cluster_files)): + d=cluster_files[i] + cluster_inds = d['sample indices in this cluster'] + random_noise(clusters[i][0],clusters[i][1],args.scale,20, cluster_inds) + #return + for p in np.arange(0.1,1.0, 0.1): + print("p={}".format(p)) + X_data_noise = copy.deepcopy(X_data) + # Full cf injection + # Choose x% of the indices to be perturbed + selector = np.random.choice(a=cluster_inds, replace=False, size = (int)(p*len(cluster_inds))) + #print(perturb_dataset[selector]) + X_data_noise[selector]-= args.scale*perturb_dataset[selector][:,:,None] + #print(np.sum(X_data_noise - X_data)) + pickle.dump([X_data_noise, y_data, selector, cluster_inds], open("nt3.data.scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale,clusters[i][0], clusters[i][1], round(p,1)), "wb")) + + # Threshold cf injection + inds = [] + print(d) + for j in d['positive threshold indices'][0]: + inds.append(j) + for j in d['negative threshold indices'][0]: + inds.append(j) + print(len(inds)) + X_data_noise_2 = copy.deepcopy(X_data) + for j in inds: + perturb_dataset[:,j]=0 + X_data_noise_2[selector]-= args.scale*perturb_dataset[selector][:,:,None] + 
pickle.dump([X_data_noise_2, y_data, selector, cluster_inds], open("nt3.data.threshold_scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale, clusters[i][0], clusters[i][1], round(p,1)), "wb")) + +if __name__ == "__main__": + main() + + + + + +# Save dataset file diff --git a/Pilot1/NT3/nt3_cf/nt3.ipynb b/Pilot1/NT3/nt3_cf/nt3.ipynb new file mode 100644 index 00000000..6e5cd05c --- /dev/null +++ b/Pilot1/NT3/nt3_cf/nt3.ipynb @@ -0,0 +1,426 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "hazardous-tokyo", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "TF version: 2.2.0\n", + "Eager execution enabled: False\n" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "tf.get_logger().setLevel(40) # suppress deprecation messages\n", + "tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs \n", + "from tensorflow.keras.models import Model, load_model\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", + "from time import time\n", + "from alibi.explainers import CounterFactual, CounterFactualProto\n", + "print('TF version: ', tf.__version__)\n", + "print('Eager execution enabled: ', tf.executing_eagerly()) # False\n", + "import pickle" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "polar-netherlands", + "metadata": {}, + "outputs": [], + "source": [ + "model_nt3 = tf.keras.models.load_model('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.model')\n", + "with open('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.data.pkl', 'rb') as pickle_file:\n", + " X_train,Y_train,X_test,Y_test = pickle.load(pickle_file)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "satellite-passage", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1, 60483, 1)\n" + ] + }, + { + "ename": 
"UnknownError", + "evalue": "2 root error(s) found.\n (0) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n\t [[activation_4/Softmax/_83]]\n (1) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n0 successful operations.\n0 derived errors ignored.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mUnknownError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mfeature_range\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m cf = CounterFactual(model_nt3, shape=shape_cf, target_proba=target_proba, tol=tol,\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mtarget_class\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtarget_class\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlam_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlam_init\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mmax_lam_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_lam_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate_init\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/alibi/explainers/counterfactual.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, predict_fn, shape, distance_fn, target_proba, target_class, max_iter, early_stop, lam_init, max_lam_steps, tol, learning_rate_init, feature_range, eps, init, decay, write_dir, debug, sess)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_classes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;31m# flag to keep track if explainer is fit or not\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_v1.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m 955\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 956\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_select_training_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 957\u001b[0;31m return func.predict(\n\u001b[0m\u001b[1;32m 958\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 959\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, model, x, batch_size, verbose, steps, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m 706\u001b[0m x, _, _ = model._standardize_user_data(\n\u001b[1;32m 707\u001b[0m x, check_steps=True, steps_name='steps', steps=steps)\n\u001b[0;32m--> 708\u001b[0;31m return predict_loop(\n\u001b[0m\u001b[1;32m 709\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 710\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mmodel_iteration\u001b[0;34m(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 385\u001b[0m \u001b[0;31m# Get outputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 386\u001b[0;31m 
\u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 387\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 3629\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeed_arrays\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_symbols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msymbol_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3630\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3631\u001b[0;31m fetched = self._callable_fn(*array_vals,\n\u001b[0m\u001b[1;32m 3632\u001b[0m run_metadata=self.run_metadata)\n\u001b[1;32m 3633\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_fetch_callbacks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1468\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1469\u001b[0m \u001b[0mrun_metadata_ptr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_NewBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1470\u001b[0;31m ret = tf_session.TF_SessionRunCallable(self._session._session,\n\u001b[0m\u001b[1;32m 1471\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1472\u001b[0m run_metadata_ptr)\n", + "\u001b[0;31mUnknownError\u001b[0m: 2 root error(s) found.\n (0) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n\t [[activation_4/Softmax/_83]]\n (1) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n0 successful operations.\n0 derived errors ignored." 
+ ] + } + ], + "source": [ + "shape_cf = (1,) + X_train.shape[1:] \n", + "print(shape_cf)\n", + "target_proba = 0.9\n", + "tol = 0.1 # want counterfactuals with p(class)>0.90\n", + "target_class = 'other' # any class other than will do\n", + "max_iter = 1000\n", + "lam_init = 1e-1\n", + "max_lam_steps = 20\n", + "learning_rate_init = 0.1\n", + "feature_range = (0,1)\n", + "\n", + "cf = CounterFactual(model_nt3, shape=shape_cf, target_proba=target_proba, tol=tol,\n", + " target_class=target_class, max_iter=max_iter, lam_init=lam_init,\n", + " max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,\n", + " feature_range=feature_range)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "closing-quarterly", + "metadata": {}, + "outputs": [], + "source": [ + "shape = X_train[0].shape[0]\n", + "results=[]\n", + "for i in np.arange(0,5):\n", + " x_sample=X_train[i:i+1]\n", + " print(x_sample.shape)\n", + " start = time()\n", + " try:\n", + " explanation = cf.explain(x_sample)\n", + " print('Counterfactual prediction: {}, {}'.format(explanation.cf['class'], explanation.cf['proba']))\n", + " print(\"Actual prediction: {}\".format(model_nt3.predict(x_sample)))\n", + " results.append([explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']])\n", + " # if counterfactual not found make a dummy array \n", + " except IndexError:\n", + " dummy = np.empty(x_sample.shape)\n", + " dummy[:] = np.nan\n", + " results.append(dummy)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dirty-hebrew", + "metadata": {}, + "outputs": [], + "source": [ + "print(X_train.shape, X_test.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "reliable-video", + "metadata": {}, + "outputs": [], + "source": [ + "pickle.dump(results, open(\"small_cf_test.pkl\", \"wb\"))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "rapid-beauty", + "metadata": {}, + "outputs": [], + "source": [ + "for i 
in range(len(results)):\n", + " plt.figure(figsize=(20, 4))\n", + " sample = X_train[i].flatten()\n", + " y = results[i][0].flatten()\n", + " x = np.arange(y.shape[0])\n", + " plt.plot(x,y,alpha=0.5, label='counterfactual')\n", + " plt.plot(x,sample,alpha=0.5, label='input')\n", + " plt.plot(x,sample-y, label='diff')\n", + " props = dict(boxstyle='round', facecolor='wheat', alpha=1)\n", + " prediction = model_nt3.predict(X_test[i:i+1])\n", + " plt.text(0.05, 0.95, \"original input: {} {} \\n counterfactual: {} {}\".format(np.argmax(prediction), \n", + " prediction,results[i][1] ,results[i][2]), \n", + " fontsize=16,\n", + " verticalalignment='top', bbox=props)\n", + " plt.legend()\n", + " plt.show()" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "defensive-seeker", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.4542566392434886\n", + "0.4086305624547959\n", + "\n", + "\n", + "0.5043837824693964\n", + "0.46664387832620924\n", + "\n", + "\n", + "0.6335310691943951\n", + "0.6069955054804324\n", + "\n", + "\n", + "0.6563886318117149\n", + "0.5876853023454891\n", + "\n", + "\n", + "0.16572694427639978\n", + "0.1858917752190248\n", + "\n", + "\n" + ] + } + ], + "source": [ + "from scipy.stats import pearsonr\n", + "Y_flag = np.argmax(Y_test,axis=1)\n", + "\n", + "for r in range(len(results)):\n", + " pearson_0 = []\n", + " pearson_1 = []\n", + " for i in range(len(Y_flag)):\n", + " if Y_flag[i]==0:\n", + " pearson_0.append(pearsonr(results[r][0].flatten(), X_test[i].flatten())[0])\n", + " else:\n", + " pearson_1.append(pearsonr(results[r][0].flatten(), X_test[i].flatten())[0])\n", + "\n", + " print(np.average(pearson_0))\n", + " print(np.average(pearson_1))\n", + " print(\"\\n\")" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "corporate-future", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[0.36569104 0.634309 ]\n" + ] + 
} + ], + "source": [ + "Y_predict = model_nt3.predict(X_test)\n", + "noise = np.random.uniform(-0.3, 0.3, X_test.shape)\n", + "Y_predict_noise = model_nt3.predict(X_test+noise)\n", + "print(Y_predict_noise[0])" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "subtle-blood", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Text(0.5, 1.0, 'Class 0 predictions with uniform random noise [-0.2,0.2]')" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "\n", + "text/plain": [ + "
" + ] + }, + "metadata": { + "needs_background": "light" + }, + "output_type": "display_data" + } + ], + "source": [ + "Y_flag_predict = np.argmax(Y_predict,axis=1)\n", + "Y_flag_predict_noise = np.argmax(Y_predict_noise,axis=1)\n", + "fig, ax = plt.subplots(figsize=(10,7))\n", + "# Example data\n", + "num_samples=5\n", + "y_pos = np.arange(num_samples)\n", + "y_pos_off = y_pos+0.25\n", + "ax.barh(y_pos, Y_predict[0:num_samples][:,0], height=0.25,align='center', label=\"predict\", alpha=0.5)\n", + "ax.barh(y_pos_off, Y_predict_noise[0:num_samples][:,0],height=0.25, align='center', label=\"predict with noise\")\n", + "\n", + "ax.set_yticks(y_pos)\n", + "ax.invert_yaxis() # labels read top-to-bottom\n", + "plt.legend()\n", + "plt.xlabel(\"Class 0 prediction probability\")\n", + "plt.ylabel(\"Input index\")\n", + "plt.title(\"Class 0 predictions with uniform random noise [-0.2,0.2]\")" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "oriented-factor", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "0.0858112 0.2\n", + "2986\n", + "3631\n", + "1\n", + "[0.01942994 0.01748677 0.01767595 ... 
0.05898305 0.01972752 0.01901901]\n" + ] + } + ], + "source": [ + "test_y = X_test[1].flatten()\n", + "test_cf = results[1][0].flatten()\n", + "diff = test_y-test_cf\n", + "threshold=0.2\n", + "max_value = np.max(np.abs(diff))\n", + "print(max_value ,threshold)\n", + "ind_pos = np.where(diff > threshold*max_value)\n", + "ind_neg = np.where(diff < -threshold*max_value)\n", + "print(len(ind_pos[0]))\n", + "print(len(ind_neg[0]))\n", + "cf_class = np.abs(1-np.argmax(Y_test[0]))\n", + "print(cf_class)\n", + "print(diff[ind_pos[0]])" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "behind-associate", + "metadata": {}, + "outputs": [], + "source": [ + "test = np.concatenate([X_train,X_test])" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "id": "flexible-sussex", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1400, 60483, 1)\n" + ] + } + ], + "source": [ + "print(test.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "id": "specific-director", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1120, 60483, 1) (280, 60483, 1)\n" + ] + } + ], + "source": [ + "print(X_train.shape, X_test.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "id": "muslim-mortgage", + "metadata": {}, + "outputs": [], + "source": [ + "test = np.concatenate([Y_train,Y_test])" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "id": "alternate-wrestling", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(1400, 2)\n" + ] + } + ], + "source": [ + "print(test.shape)" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "sweet-venezuela", + "metadata": {}, + "outputs": [], + "source": [ + "with open('small_threshold.pkl', 'rb') as pickle_file:\n", + " t_results = pickle.load(pickle_file)" + ] + } + ], + "metadata": { + "kernelspec": { + 
"display_name": "xai-geom-tf", + "language": "python", + "name": "xai-geom-tf" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/Pilot1/NT3/nt3_cf/test_cf_accuracy.py b/Pilot1/NT3/nt3_cf/test_cf_accuracy.py new file mode 100644 index 00000000..d90a834c --- /dev/null +++ b/Pilot1/NT3/nt3_cf/test_cf_accuracy.py @@ -0,0 +1,66 @@ +import tensorflow as tf +tf.get_logger().setLevel(40) # suppress deprecation messages +tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs +from tensorflow.keras.models import Model, load_model +import matplotlib.pyplot as plt +import numpy as np +import os +os.environ["CUDA_VISIBLE_DEVICES"]="1" +from time import time +from alibi.explainers import CounterFactual, CounterFactualProto +print('TF version: ', tf.__version__) +print('Eager execution enabled: ', tf.executing_eagerly()) # False +print(tf.test.is_gpu_available()) +import pickle + +model_nt3 = tf.keras.models.load_model('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.model') +# results = [] +# for i in np.arange(0.1,1.0, 0.1): +# cf_dataset = pickle.load(open("nt3.data.scale_1.0.cluster_0_1.noise_{}.pkl".format(round(i,2)), "rb")) +# X_cf_dataset = cf_dataset[0] +# y_cf_dataset = cf_dataset[1] +# cluster_inds = cf_dataset[-1] +# print(model_nt3.metrics_names) +# acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) +# cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) +# print(i, acc, cluster_acc) +# results.append([acc[1], cluster_acc[1]]) +# plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy", marker='o') +# plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy", marker='o') + +results = [] +for i in 
np.arange(0.1,1.0, 0.1): + cf_dataset = pickle.load(open("nt3.data.threshold_scale_1.0_cluster_0_1.noise_{}.pkl".format(round(i,2)), "rb")) + X_cf_dataset = cf_dataset[0] + y_cf_dataset = cf_dataset[1] + cluster_inds = cf_dataset[-1] + print(model_nt3.metrics_names) + acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) + cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) + print(i, acc, cluster_acc) + results.append([acc[1], cluster_acc[1]]) +results = np.array(results) +plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy", marker='o') +plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy", marker='o') + +results = [] +for i in np.arange(0.1,1.0, 0.1): + cf_dataset = pickle.load(open("nt3.data.random.scale_1.0_cluster_0_1.noise_{}.pkl".format(round(i,2)), "rb")) + X_cf_dataset = cf_dataset[0] + y_cf_dataset = cf_dataset[1] + cluster_inds = cf_dataset[-1] + print(model_nt3.metrics_names) + acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) + cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) + print(i, acc, cluster_acc) + results.append([acc[1], cluster_acc[1]]) +results = np.array(results) +plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy with Gaussian noise", marker='o') +plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with Gaussian noise", marker='o') + + +plt.xlabel("Noise fraction in cluster") +plt.ylabel("Accuracy") +plt.legend() +plt.title("Model accuracy with counterfactual noise injection for class 0, cluster 1") +plt.savefig("abstract_plot.png") diff --git a/Pilot1/NT3/nt3_cf/threshold.py b/Pilot1/NT3/nt3_cf/threshold.py new file mode 100644 index 00000000..92841f12 --- /dev/null +++ b/Pilot1/NT3/nt3_cf/threshold.py @@ -0,0 +1,70 @@ +# Example run python threshold.py -d nt3.autosave.data.pkl -c small_cf.pkl -t 0.2 -o small_threshold.pkl +import pickle +import numpy as np +import argparse + +def get_args(): + 
parser = argparse.ArgumentParser() + parser.add_argument('-d', type=str, + help='data input file', required=True) + parser.add_argument('-c', type=str, + help='counterfactual input file', required=True) + parser.add_argument('-o', type=str, + help='output file', required=True) + parser.add_argument('-t', type=float, + help='threshold value', required=True) + + args = parser.parse_args() + return args + +def threshold(t_value, X, y, cf): + pos = [] + neg = [] + cf_classes = [] + inds = [] + diffs = [] + for i in range(len(cf)): + test_y = X[i].flatten() + test_cf = cf[i][0].flatten() + + diff = test_y-test_cf + max_value = np.max(np.abs(diff)) + + ind_pos = np.where(diff > t_value*max_value) + ind_neg = np.where(diff < -t_value*max_value) + + cf_class = np.abs(1-np.argmax(y[i])) + + pos.append(ind_pos) + neg.append(ind_neg) + cf_classes.append(cf_class) + inds.append(i) + diffs.append(diff) + + return pos,neg,cf_classes,inds, diffs + +def main(): + args = get_args() + with open(args.d, 'rb') as pickle_file: + X_train,Y_train,X_test,Y_test = pickle.load(pickle_file) + + with open(args.c, 'rb') as pickle_file: + cf = pickle.load(pickle_file) + + X = np.concatenate([X_train,X_test]) + Y = np.concatenate([Y_train, Y_test]) +# X=X_test +# Y=Y_test + pos,neg,cf_classes,inds, diff = threshold(args.t, X, Y, cf) + + # Note that sample index is here to keep track of counterfactuals that succeeded, counterfactuals that failed are not included here + results = {'sample index': inds, + 'positive threshold indices': pos, + 'negative threshold indices':neg, + 'counterfactual class':cf_classes, + 'perturbation vector': diff} + pickle.dump(results, open(args.o, "wb")) + + +if __name__ == "__main__": + main() \ No newline at end of file From 3a82f29625cdec37be930a35c47a3adf528f5f15 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Mon, 4 Oct 2021 10:07:15 -0500 Subject: [PATCH 02/12] modify baseline to save model and data files for cf in TF and pickle formats --- 
Pilot1/NT3/nt3_baseline_keras2.py | 6 +++++- Pilot1/NT3/nt3_default_model.txt | 1 + common/parsing_utils.py | 6 +++++- 3 files changed, 11 insertions(+), 2 deletions(-) diff --git a/Pilot1/NT3/nt3_baseline_keras2.py b/Pilot1/NT3/nt3_baseline_keras2.py index 4ab49ce7..ac25eaf0 100644 --- a/Pilot1/NT3/nt3_baseline_keras2.py +++ b/Pilot1/NT3/nt3_baseline_keras2.py @@ -15,7 +15,7 @@ import nt3 as bmk import candle - +import pickle def initialize_parameters(default_model='nt3_default_model.txt'): @@ -278,6 +278,10 @@ def run(gParameters): print("yaml %s: %.2f%%" % (loaded_model_yaml.metrics_names[1], score_yaml[1] * 100)) + if gParameters['save_cf']: + model.save('{}/{}.autosave.model'.format(output_dir, model_name)) + pickle.dump([X_train, Y_train], open('{}/{}.autosave.data.train.pkl'.format(output_dir, model_name), "wb")) + pickle.dump( [X_test, Y_test], open( '{}/{}.autosave.data.test.pkl'.format(output_dir, model_name), "wb" )) return history diff --git a/Pilot1/NT3/nt3_default_model.txt b/Pilot1/NT3/nt3_default_model.txt index 61f00b4d..11f22a73 100644 --- a/Pilot1/NT3/nt3_default_model.txt +++ b/Pilot1/NT3/nt3_default_model.txt @@ -27,3 +27,4 @@ std_dev = 0.5 timeout = 3600 ckpt_restart_mode = 'off' ckpt_checksum = True +save_cf = True diff --git a/common/parsing_utils.py b/common/parsing_utils.py index f89e3900..e812f93f 100644 --- a/common/parsing_utils.py +++ b/common/parsing_utils.py @@ -94,7 +94,11 @@ {'name': 'run_id', 'type': str, 'default': 'RUN000', - 'help': 'set the run unique identifier.'} + 'help': 'set the run unique identifier.'}, + {'name': 'save_cf', + 'type': bool, + 'default': False, + 'help': 'save the model (Tensoflow saved model format) and data (pickle) objects for cf runs'} ] logging_conf = [ From 1e81d37d6f19b55c61cbde33f40ef3629e81579f Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Mon, 4 Oct 2021 13:17:17 -0500 Subject: [PATCH 03/12] update abstention model to handle pickle and cf data --- Pilot1/NT3/nt3_abstention_keras2.py | 48 
+++++++++++++++++++++++++++-- Pilot1/NT3/nt3_baseline_keras2.py | 5 ++- Pilot1/NT3/nt3_cf/inject_noise.py | 18 +++++++++-- Pilot1/NT3/nt3_default_model.txt | 4 +-- Pilot1/NT3/nt3_noise_model.txt | 2 ++ common/parsing_utils.py | 15 ++++++--- 6 files changed, 79 insertions(+), 13 deletions(-) diff --git a/Pilot1/NT3/nt3_abstention_keras2.py b/Pilot1/NT3/nt3_abstention_keras2.py index e9cafc37..032730b2 100644 --- a/Pilot1/NT3/nt3_abstention_keras2.py +++ b/Pilot1/NT3/nt3_abstention_keras2.py @@ -16,6 +16,7 @@ import nt3 as bmk import candle +import pickle additional_definitions = abs_definitions @@ -50,7 +51,13 @@ def initialize_parameters(default_model='nt3_noise_model.txt'): gParameters = candle.finalize_parameters(nt3Bmk) return gParameters - + +def load_data_cf(cf_path): + # Pickle file holds the test train split and cf index info + print("Loading data...") + X_train, X_test, Y_train, Y_test, polluted_inds, cluster_inds = pickle.load(open(cf_path, 'rb')) + print('done') + return X_train, Y_train, X_test, Y_test, polluted_inds, cluster_inds def load_data(train_path, test_path, gParameters): @@ -104,6 +111,38 @@ def load_data(train_path, test_path, gParameters): return X_train, Y_train, X_test, Y_test +def evaluate_cf(model, nb_classes, output_dir, y_pred, y, polluted_inds, cluster_inds, gParameters): + if len(polluted_inds) > 0: + y_pred = model.predict(X_test) + abstain_inds = [] + for i in range(y_pred.shape[0]): + if np.argmax(y_pred[i]) == nb_classes: + abstain_inds.append(i) + + # Cluster indices and polluted indices are wrt to entire train + test dataset + # whereas y_pred only contains test dataset so add offset for correct indexing + offset_testset = Y_train.shape[0] + abstain_inds=[i+offset_testset for i in abstain_inds] + polluted_percentage = np.sum([el in polluted_inds for el in abstain_inds])/np.max([len(abstain_inds),\ +1]) + print("Percentage of abstained samples that were polluted {}".format(polluted_percentage)) + + cluster_inds_test = 
list(filter(lambda cluster_inds: cluster_inds >= offset_testset, cluster_inds)) + cluster_inds_test_abstain = [el in abstain_inds for el in cluster_inds_test] + cluster_percentage = c = np.sum(cluster_inds_test_abstain)/len(cluster_inds_test) + print("Percentage of cluster (in test set) that was abstained {}".format(cluster_percentage)) + + unabstain_inds = [] + for i in range(y_pred.shape[0]): + if np.argmax(y_pred[i]) != nb_classes and (i+offset_testset in cluster_inds_test): + unabstain_inds.append(i) + # Make sure number of unabstained indices in cluster test set plus number of abstainsed indices in cluster test set + # equals number of indices in cluster in the test set + assert(len(unabstain_inds)+np.sum(cluster_inds_test_abstain) == len(cluster_inds_test)) + score_cluster = 1 if len(unabstain_inds)==0 else model.evaluate(X_test[unabstain_inds], Y_test[unabstain_inds])[1] + print("Accuracy of unabastained cluster {}".format(score_cluster)) + if gParameters['noise_save_cf']: + pickle.dump({'Abs polluted': polluted_percentage, 'Abs val cluster': cluster_percentage, 'Abs val acc': score_cluster}, open("{}/cluster_trace.pkl".format(output_dir), "wb")) def run(gParameters): @@ -116,7 +155,10 @@ def run(gParameters): train_file = candle.get_file(file_train, url + file_train, cache_subdir='Pilot1') test_file = candle.get_file(file_test, url + file_test, cache_subdir='Pilot1') - X_train, Y_train, X_test, Y_test = load_data(train_file, test_file, gParameters) + if gParameters['noise_cf'] is not None: + X_train, Y_train, X_test, Y_test, polluted_inds, cluster_inds = load_data_cf(gParameters['noise_cf']) + else: + X_train, Y_train, X_test, Y_test = load_data(train_file, test_file, gParameters) # add extra class for abstention # first reverse the to_categorical @@ -291,6 +333,8 @@ def run(gParameters): score = model.evaluate(X_test, Y_test, verbose=0) + if gParameters['noise_cf'] is not None: + evaluate_cf(model, nb_classes, output_dir, y_pred, y, polluted_inds, 
cluster_inds, gParameters) alpha_trace = open(output_dir + "/alpha_trace", "w+") for alpha in abstention_cbk.alphavalues: alpha_trace.write(str(alpha) + '\n') diff --git a/Pilot1/NT3/nt3_baseline_keras2.py b/Pilot1/NT3/nt3_baseline_keras2.py index ac25eaf0..efe2e7d4 100644 --- a/Pilot1/NT3/nt3_baseline_keras2.py +++ b/Pilot1/NT3/nt3_baseline_keras2.py @@ -278,10 +278,9 @@ def run(gParameters): print("yaml %s: %.2f%%" % (loaded_model_yaml.metrics_names[1], score_yaml[1] * 100)) - if gParameters['save_cf']: + if gParameters['noise_save_cf']: model.save('{}/{}.autosave.model'.format(output_dir, model_name)) - pickle.dump([X_train, Y_train], open('{}/{}.autosave.data.train.pkl'.format(output_dir, model_name), "wb")) - pickle.dump( [X_test, Y_test], open( '{}/{}.autosave.data.test.pkl'.format(output_dir, model_name), "wb" )) + pickle.dump([X_train, X_test, Y_train, Y_test], open('{}/{}.autosave.data.pkl'.format(output_dir, model_name), "wb")) return history diff --git a/Pilot1/NT3/nt3_cf/inject_noise.py b/Pilot1/NT3/nt3_cf/inject_noise.py index c2481eff..870096f4 100644 --- a/Pilot1/NT3/nt3_cf/inject_noise.py +++ b/Pilot1/NT3/nt3_cf/inject_noise.py @@ -55,7 +55,15 @@ def main(): #print(perturb_dataset[selector]) X_data_noise[selector]-= args.scale*perturb_dataset[selector][:,:,None] #print(np.sum(X_data_noise - X_data)) - pickle.dump([X_data_noise, y_data, selector, cluster_inds], open("nt3.data.scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale,clusters[i][0], clusters[i][1], round(p,1)), "wb")) + + # Now split back into train test + X_train = X_data[0:(int)(0.8*X_data.shape[0])] + X_test = X_data[0.8*X_data.shape[0]:] + + y_train = y_data[0:(int)(0.8*y_data.shape[0])] + y_test = y_data[0.8*y_data.shape[0]:] + + pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("nt3.data.scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale,clusters[i][0], clusters[i][1], round(p,1)), "wb")) # Threshold cf injection inds = [] @@ -69,7 +77,13 @@ def 
main(): for j in inds: perturb_dataset[:,j]=0 X_data_noise_2[selector]-= args.scale*perturb_dataset[selector][:,:,None] - pickle.dump([X_data_noise_2, y_data, selector, cluster_inds], open("nt3.data.threshold_scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale, clusters[i][0], clusters[i][1], round(p,1)), "wb")) + # Now split back into train test + X_train = X_data[0:(int)(0.8*X_data.shape[0])] + X_test = X_data[0.8*X_data.shape[0]:] + + y_train = y_data[0:(int)(0.8*y_data.shape[0])] + y_test = y_data[0.8*y_data.shape[0]:] + pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("nt3.data.threshold_scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale, clusters[i][0], clusters[i][1], round(p,1)), "wb")) if __name__ == "__main__": main() diff --git a/Pilot1/NT3/nt3_default_model.txt b/Pilot1/NT3/nt3_default_model.txt index 11f22a73..c762310d 100644 --- a/Pilot1/NT3/nt3_default_model.txt +++ b/Pilot1/NT3/nt3_default_model.txt @@ -10,7 +10,7 @@ out_activation = 'softmax' loss = 'categorical_crossentropy' optimizer = 'sgd' metrics = 'accuracy' -epochs = 400 +epochs = 10 batch_size = 20 learning_rate = 0.001 dropout = 0.1 @@ -27,4 +27,4 @@ std_dev = 0.5 timeout = 3600 ckpt_restart_mode = 'off' ckpt_checksum = True -save_cf = True +noise_save_cf = True diff --git a/Pilot1/NT3/nt3_noise_model.txt b/Pilot1/NT3/nt3_noise_model.txt index 88c5f19d..e4d929fa 100644 --- a/Pilot1/NT3/nt3_noise_model.txt +++ b/Pilot1/NT3/nt3_noise_model.txt @@ -33,3 +33,5 @@ alpha_scale_factor = 0.95 init_abs_epoch = 2 task_list = 0 task_names = ['activation_5'] +noise_save_cf = True +noise_cf = '/vol/ml/shahashka/nt3_cf/cf_data/data_with_cf_noise/nt3.data.scale_1.0_cluster_0_1.noise_0.1.pkl' diff --git a/common/parsing_utils.py b/common/parsing_utils.py index e812f93f..27006000 100644 --- a/common/parsing_utils.py +++ b/common/parsing_utils.py @@ -94,11 +94,18 @@ {'name': 'run_id', 'type': str, 'default': 'RUN000', - 'help': 'set the run unique identifier.'}, - {'name': 
'save_cf', + 'help': 'set the run unique identifier.'} +] + +noise_conf = [ + {'name': 'noise_save_cf', 'type': bool, 'default': False, - 'help': 'save the model (Tensoflow saved model format) and data (pickle) objects for cf runs'} + 'help': 'save the model (Tensoflow saved model format) and data (pickle) objects for cf runs'}, + {'name': 'noise_cf', + 'type': str, + 'default': None, + 'help': 'pickle file to hold dataset with noise already added through counterfactuals'} ] logging_conf = [ @@ -315,7 +322,7 @@ ] -registered_conf = [basic_conf, input_output_conf, logging_conf, data_preprocess_conf, model_conf, training_conf, cyclic_learning_conf, ckpt_conf] +registered_conf = [basic_conf, input_output_conf, logging_conf, data_preprocess_conf, model_conf, training_conf, cyclic_learning_conf, ckpt_conf, noise_conf] def extract_keywords(lst_dict, kw): From 76ea9092c71390202b927f6ec1f7ac48de6bbb4e Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Mon, 4 Oct 2021 13:21:14 -0500 Subject: [PATCH 04/12] remove extra files --- .../nt3_cf/abstention/abstain_functions.py | 203 ------ .../abstention/nt3_abstention_keras2_cf.py | 378 ---------- .../nt3_cf/abstention/nt3_baseline_keras2.py | 318 --------- Pilot1/NT3/nt3_cf/analyze.ipynb | 655 ------------------ Pilot1/NT3/nt3_cf/nt3.ipynb | 426 ------------ 5 files changed, 1980 deletions(-) delete mode 100644 Pilot1/NT3/nt3_cf/abstention/abstain_functions.py delete mode 100644 Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py delete mode 100644 Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py delete mode 100644 Pilot1/NT3/nt3_cf/analyze.ipynb delete mode 100644 Pilot1/NT3/nt3_cf/nt3.ipynb diff --git a/Pilot1/NT3/nt3_cf/abstention/abstain_functions.py b/Pilot1/NT3/nt3_cf/abstention/abstain_functions.py deleted file mode 100644 index 8ee5c84b..00000000 --- a/Pilot1/NT3/nt3_cf/abstention/abstain_functions.py +++ /dev/null @@ -1,203 +0,0 @@ -from tensorflow.keras import backend as K - -abs_definitions = [ - {'name': 
'add_class', - 'nargs': '+', - 'type': int, - 'help': 'flag to add abstention (per task)'}, - {'name': 'alpha', - 'nargs': '+', - 'type': float, - 'help': 'abstention penalty coefficient (per task)'}, - {'name': 'min_acc', - 'nargs': '+', - 'type': float, - 'help': 'minimum accuracy required (per task)'}, - {'name': 'max_abs', - 'nargs': '+', - 'type': float, - 'help': 'maximum abstention fraction allowed (per task)'}, - {'name': 'alpha_scale_factor', - 'nargs': '+', - 'type': float, - 'help': 'scaling factor for modifying alpha (per task)'}, - {'name': 'init_abs_epoch', - 'action': 'store', - 'type': int, - 'help': 'number of epochs to skip before modifying alpha'}, - {'name': 'n_iters', - 'action': 'store', - 'type': int, - 'help': 'number of iterations to iterate alpha'}, - {'name': 'acc_gain', - 'type': float, - 'default': 5.0, - 'help': 'factor to weight accuracy when determining new alpha scale'}, - {'name': 'abs_gain', - 'type': float, - 'default': 1.0, - 'help': 'factor to weight abstention fraction when determining new alpha scale'}, - {'name': 'task_list', - 'nargs': '+', - 'type': int, - 'help': 'list of task indices to use'}, - {'name': 'task_names', - 'nargs': '+', - 'type': int, - 'help': 'list of names corresponding to each task to use'}, - {'name': 'cf_noise', - 'type': str, - 'help': 'input file with cf noise'} -] - - -def adjust_alpha(gParameters, X_test, truths_test, labels_val, model, alpha, add_index): - - task_names = gParameters['task_names'] - task_list = gParameters['task_list'] - # retrieve truth-pred pair - avg_loss = 0.0 - ret = [] - ret_k = [] - - # set abstaining classifier parameters - max_abs = gParameters['max_abs'] - min_acc = gParameters['min_acc'] - alpha_scale_factor = gParameters['alpha_scale_factor'] - - # print('labels_test', labels_test) - # print('Add_index', add_index) - - feature_test = X_test - # label_test = keras.utils.to_categorical(truths_test) - - # loss = model.evaluate(feature_test, [label_test[0], 
label_test[1],label_test[2], label_test[3]]) - loss = model.evaluate(feature_test, labels_val) - avg_loss = avg_loss + loss[0] - - pred = model.predict(feature_test) - # print('pred',pred.shape, pred) - - abs_gain = gParameters['abs_gain'] - acc_gain = gParameters['acc_gain'] - - accs = [] - abst = [] - - for k in range((alpha.shape[0])): - if k in task_list: - truth_test = truths_test[:, k] - alpha_k = K.eval(alpha[k]) - pred_classes = pred[k].argmax(axis=-1) - # true_classes = labels_test[k].argmax(axis=-1) - true_classes = truth_test - - # print('pred_classes',pred_classes.shape, pred_classes) - # print('true_classes',true_classes.shape, true_classes) - # print('labels',label_test.shape, label_test) - - true = K.eval(K.sum(K.cast(K.equal(pred_classes, true_classes), 'int64'))) - false = K.eval(K.sum(K.cast(K.not_equal(pred_classes, true_classes), 'int64'))) - abstain = K.eval(K.sum(K.cast(K.equal(pred_classes, add_index[k] - 1), 'int64'))) - - print(true, false, abstain) - - total = false + true - tot_pred = total - abstain - abs_acc = 0.0 - abs_frac = abstain / total - - if tot_pred > 0: - abs_acc = true / tot_pred - - scale_k = alpha_scale_factor[k] - min_scale = scale_k - max_scale = 1. 
/ scale_k - - acc_error = abs_acc - min_acc[k] - acc_error = min(acc_error, 0.0) - abs_error = abs_frac - max_abs[k] - abs_error = max(abs_error, 0.0) - new_scale = 1.0 + acc_gain * acc_error + abs_gain * abs_error - - # threshold to avoid huge swings - new_scale = min(new_scale, max_scale) - new_scale = max(new_scale, min_scale) - - print('Scaling factor: ', new_scale) - K.set_value(alpha[k], new_scale * alpha_k) - - print_abs_stats(task_names[k], new_scale * alpha_k, true, false, abstain, max_abs[k]) - - ret_k.append(truth_test) - ret_k.append(pred) - - ret.append(ret_k) - - accs.append(abs_acc) - abst.append(abs_frac) - else: - accs.append(1.0) - accs.append(0.0) - - write_abs_stats(gParameters['output_dir'] + 'abs_stats.csv', alpha, accs, abst) - - return ret, alpha - - -def loss_param(alpha, mask): - def loss(y_true, y_pred): - - cost = 0 - - base_pred = (1 - mask) * y_pred - # base_true = (1 - mask) * y_true - base_true = y_true - - base_cost = K.sparse_categorical_crossentropy(base_true, base_pred) - - abs_pred = K.mean(mask * (y_pred), axis=-1) - # add some small value to prevent NaN when prediction is abstained - abs_pred = K.clip(abs_pred, K.epsilon(), 1. - K.epsilon()) - cost = (1. - abs_pred) * base_cost - (alpha) * K.log(1. 
- abs_pred) - - return cost - return loss - - -def print_abs_stats( - task_name, - alpha, - num_true, - num_false, - num_abstain, - max_abs): - - # Compute interesting values - total = num_true + num_false - tot_pred = total - num_abstain - abs_frac = num_abstain / total - abs_acc = 1.0 - if tot_pred > 0: - abs_acc = num_true / tot_pred - - print(' task, alpha, true, false, abstain, total, tot_pred, abs_frac, max_abs, abs_acc') - print('{:>12s}, {:10.5e}, {:8d}, {:8d}, {:8d}, {:8d}, {:8d}, {:10.5f}, {:10.5f}, {:10.5f}' - .format(task_name, alpha, - num_true, num_false - num_abstain, num_abstain, total, - tot_pred, abs_frac, max_abs, abs_acc)) - - -def write_abs_stats(stats_file, alphas, accs, abst): - - # Open file for appending - abs_file = open(stats_file, 'a') - - # we write all the results - for k in range((alphas.shape[0])): - abs_file.write("%10.5e," % K.get_value(alphas[k])) - for k in range((alphas.shape[0])): - abs_file.write("%10.5e," % accs[k]) - for k in range((alphas.shape[0])): - abs_file.write("%10.5e," % abst[k]) - abs_file.write("\n") diff --git a/Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py b/Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py deleted file mode 100644 index 8ec8c5d8..00000000 --- a/Pilot1/NT3/nt3_cf/abstention/nt3_abstention_keras2_cf.py +++ /dev/null @@ -1,378 +0,0 @@ -from __future__ import print_function -import pandas as pd -import numpy as np -import os -import tensorflow -from tensorflow.keras import backend as K -os.environ["CUDA_VISIBLE_DEVICES"]="2" -from tensorflow.keras.layers import Dense, Dropout, Activation, Conv1D, MaxPooling1D, Flatten, LocallyConnected1D -from tensorflow.keras.models import Sequential, model_from_json, model_from_yaml -from tensorflow.keras.utils import to_categorical -from tensorflow.keras.callbacks import CSVLogger, ReduceLROnPlateau - -from sklearn.preprocessing import MaxAbsScaler -from abstain_functions import abs_definitions - -import nt3 as bmk -import candle -import pickle 
-additional_definitions = abs_definitions - -required = bmk.required - - -class BenchmarkNT3Abs(candle.Benchmark): - def set_locals(self): - """Functionality to set variables specific for the benchmark - - required: set of required parameters for the benchmark. - - additional_definitions: list of dictionaries describing the additional parameters for the - benchmark. - """ - - if required is not None: - self.required = set(bmk.required) - if additional_definitions is not None: - self.additional_definitions = abs_definitions + bmk.additional_definitions - - -def initialize_parameters(default_model='nt3_noise_model.txt'): - - # Build benchmark object - nt3Bmk = BenchmarkNT3Abs( - bmk.file_path, - default_model, - 'keras', - prog='nt3_abstention', - desc='1D CNN to classify RNA sequence data in normal or tumor classes') - - # Initialize parameters - gParameters = candle.finalize_parameters(nt3Bmk) - - return gParameters - - -def load_data(path, gParameters): - - # Rewrite this function to handle pickle files instead - print("Loading data...") - data = pickle.load(open(path, 'rb')) - X=data[0] - y=data[1] - polluted_inds = data[2] - cluster_inds = data[3] - size = X.shape[0] - X_train = X[0:(int)(0.8*size)] - X_test = X[(int)(0.8*size):] - Y_train = y[0:(int)(0.8*size)] - Y_test = y[(int)(0.8*size):] - #df_train = (pd.read_csv(train_path, header=None).values).astype('float32') - #df_test = (pd.read_csv(test_path, header=None).values).astype('float32') - #X_train,Y_train, X_test, Y_test = data - #polluted_inds = [] - #cluster_inds=[] - print('done') - - - print('df_train shape:', X_train.shape) - print('df_test shape:', X_test.shape) - - return X_train, Y_train, X_test, Y_test, polluted_inds, cluster_inds - - -def run(gParameters): - - print('Params:', gParameters) - - data_file = gParameters['cf_noise'] - # file_test = gParameters['test_data'] - url = gParameters['data_url'] - - #train_file = candle.get_file(file_train, url + file_train, cache_subdir='Pilot1') - 
#test_file = candle.get_file(file_test, url + file_test, cache_subdir='Pilot1') - - X_train, Y_train, X_test, Y_test, polluted_inds, cluster_inds = load_data(data_file, gParameters) - - # add extra class for abstention - # first reverse the to_categorical - Y_train = np.argmax(Y_train, axis=1) - Y_test = np.argmax(Y_test, axis=1) - Y_train, Y_test = candle.modify_labels(gParameters['classes'] + 1, Y_train, Y_test) - # print(Y_test) - - print('X_train shape:', X_train.shape) - print('X_test shape:', X_test.shape) - - print('Y_train shape:', Y_train.shape) - print('Y_test shape:', Y_test.shape) - - x_train_len = X_train.shape[1] - - # this reshaping is critical for the Conv1D to work - - #X_train = np.expand_dims(X_train, axis=2) - #X_test = np.expand_dims(X_test, axis=2) - - print('X_train shape:', X_train.shape) - print('X_test shape:', X_test.shape) - - model = Sequential() - - layer_list = list(range(0, len(gParameters['conv']), 3)) - for _, i in enumerate(layer_list): - filters = gParameters['conv'][i] - filter_len = gParameters['conv'][i + 1] - stride = gParameters['conv'][i + 2] - print(int(i / 3), filters, filter_len, stride) - if gParameters['pool']: - pool_list = gParameters['pool'] - if type(pool_list) != list: - pool_list = list(pool_list) - - if filters <= 0 or filter_len <= 0 or stride <= 0: - break - if 'locally_connected' in gParameters: - model.add(LocallyConnected1D(filters, filter_len, strides=stride, padding='valid', input_shape=(x_train_len, 1))) - else: - # input layer - if i == 0: - model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid', input_shape=(x_train_len, 1))) - else: - model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid')) - model.add(Activation(gParameters['activation'])) - if gParameters['pool']: - model.add(MaxPooling1D(pool_size=pool_list[int(i / 3)])) - - model.add(Flatten()) - - for layer in gParameters['dense']: - if layer: - model.add(Dense(layer)) - 
model.add(Activation(gParameters['activation'])) - if gParameters['dropout']: - model.add(Dropout(gParameters['dropout'])) - model.add(Dense(gParameters['classes'])) - model.add(Activation(gParameters['out_activation'])) - - # modify the model for abstention - model = candle.add_model_output(model, mode='abstain', num_add=1, activation=gParameters['out_activation']) - -# Reference case -# model.add(Conv1D(filters=128, kernel_size=20, strides=1, padding='valid', input_shape=(P, 1))) -# model.add(Activation('relu')) -# model.add(MaxPooling1D(pool_size=1)) -# model.add(Conv1D(filters=128, kernel_size=10, strides=1, padding='valid')) -# model.add(Activation('relu')) -# model.add(MaxPooling1D(pool_size=10)) -# model.add(Flatten()) -# model.add(Dense(200)) -# model.add(Activation('relu')) -# model.add(Dropout(0.1)) -# model.add(Dense(20)) -# model.add(Activation('relu')) -# model.add(Dropout(0.1)) -# model.add(Dense(CLASSES)) -# model.add(Activation('softmax')) - - kerasDefaults = candle.keras_default_config() - - # Define optimizer - optimizer = candle.build_optimizer(gParameters['optimizer'], - gParameters['learning_rate'], - kerasDefaults) - - model.summary() - - # Configure abstention model - nb_classes = gParameters['classes'] - mask = np.zeros(nb_classes + 1) - mask[nb_classes] = 1.0 - print("Mask is ", mask) - alpha0 = gParameters['alpha'] - if isinstance(gParameters['max_abs'], list): - max_abs = gParameters['max_abs'][0] - else: - max_abs = gParameters['max_abs'] - - print("Initializing abstention callback with: \n") - print("alpha0 ", alpha0) - print("alpha_scale_factor ", gParameters['alpha_scale_factor']) - print("min_abs_acc ", gParameters['min_acc']) - print("max_abs_frac ", max_abs) - print("acc_gain ", gParameters['acc_gain']) - print("abs_gain ", gParameters['abs_gain']) - - abstention_cbk = candle.AbstentionAdapt_Callback(acc_monitor='val_abstention_acc', - abs_monitor='val_abstention', - init_abs_epoch=gParameters['init_abs_epoch'], - alpha0=alpha0, - 
alpha_scale_factor=gParameters['alpha_scale_factor'], - min_abs_acc=gParameters['min_acc'], - max_abs_frac=max_abs, - acc_gain=gParameters['acc_gain'], - abs_gain=gParameters['abs_gain']) - - model.compile(loss=candle.abstention_loss(abstention_cbk.alpha, mask), - optimizer=optimizer, - metrics=[candle.abstention_acc_metric(nb_classes), - # candle.acc_class_i_metric(1), - # candle.abstention_acc_class_i_metric(nb_classes, 1), - candle.abstention_metric(nb_classes)]) - - # model.compile(loss=abs_loss, - # optimizer=optimizer, - # metrics=abs_acc) - - # model.compile(loss=gParameters['loss'], - # optimizer=optimizer, - # metrics=[gParameters['metrics']]) - - output_dir = gParameters['output_dir'] - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - # calculate trainable and non-trainable params - gParameters.update(candle.compute_trainable_params(model)) - - # set up a bunch of callbacks to do work during model training.. - model_name = gParameters['model_name'] - # path = '{}/{}.autosave.model.h5'.format(output_dir, model_name) - # checkpointer = ModelCheckpoint(filepath=path, verbose=1, save_weights_only=False, save_best_only=True) - print(output_dir) - csv_logger = CSVLogger("{}/training.log".format(output_dir)) - reduce_lr = ReduceLROnPlateau(monitor='val_loss', - factor=0.1, patience=10, verbose=1, mode='auto', - epsilon=0.0001, cooldown=0, min_lr=0) - - candleRemoteMonitor = candle.CandleRemoteMonitor(params=gParameters) - timeoutMonitor = candle.TerminateOnTimeOut(gParameters['timeout']) - - # n_iters = 1 - - # val_labels = {"activation_5": Y_test} - # for epoch in range(gParameters['epochs']): - # print('Iteration = ', epoch) - history = model.fit(X_train, Y_train, - batch_size=gParameters['batch_size'], - epochs=gParameters['epochs'], - # initial_epoch=epoch, - # epochs=epoch + n_iters, - verbose=1, - validation_data=(X_test, Y_test), - # callbacks=[csv_logger, reduce_lr, candleRemoteMonitor, timeoutMonitor]) # , abstention_cbk]) - 
callbacks=[csv_logger, reduce_lr, candleRemoteMonitor, timeoutMonitor, abstention_cbk]) - - # ret, alpha = adjust_alpha(gParameters, X_test, Y_test, val_labels, model, alpha, [nb_classes+1]) - - score = model.evaluate(X_test, Y_test, verbose=0) - - if len(polluted_inds) > 0: - y_pred = model.predict(X_test) - abstain_inds = [] - for i in range(y_pred.shape[0]): - if np.argmax(y_pred[i]) == nb_classes: - abstain_inds.append(i) - - # Cluster indices and polluted indices are wrt to entire train + test dataset - # whereas y_pred only contains test dataset so add offset for correct indexing - offset_testset = Y_train.shape[0] - abstain_inds=[i+offset_testset for i in abstain_inds] - - polluted_percentage = c = np.sum([el in polluted_inds for el in abstain_inds])/np.max([len(abstain_inds),1]) - print("Percentage of abstained samples that were polluted {}".format(polluted_percentage)) - - cluster_inds_test = list(filter(lambda cluster_inds: cluster_inds >= offset_testset, cluster_inds)) - cluster_inds_test_abstain = [el in abstain_inds for el in cluster_inds_test] - cluster_percentage = c = np.sum(cluster_inds_test_abstain)/len(cluster_inds_test) - print("Percentage of cluster (in test set) that was abstained {}".format(cluster_percentage)) - - unabstain_inds = [] - for i in range(y_pred.shape[0]): - if np.argmax(y_pred[i]) != nb_classes and (i+offset_testset in cluster_inds_test): - unabstain_inds.append(i) - # Make sure number of unabstained indices in cluster test set plus number of abstainsed indices in cluster test set - # equals number of indices in cluster in the test set - assert(len(unabstain_inds)+np.sum(cluster_inds_test_abstain) == len(cluster_inds_test)) - score_cluster = 1 if len(unabstain_inds)==0 else model.evaluate(X_test[unabstain_inds], Y_test[unabstain_inds])[1] - print("Accuracy of unabastained cluster {}".format(score_cluster)) - - pickle.dump({'Abs polluted': polluted_percentage, 'Abs val cluster': cluster_percentage, 'Abs val acc': score_cluster}, 
open("{}/cluster_trace.pkl".format(output_dir), "wb")) - - alpha_trace = open(output_dir + "/alpha_trace", "w+") - for alpha in abstention_cbk.alphavalues: - alpha_trace.write(str(alpha) + '\n') - alpha_trace.close() - - if False: - print('Test score:', score[0]) - print('Test accuracy:', score[1]) - # serialize model to JSON - model_json = model.to_json() - with open("{}/{}.model.json".format(output_dir, model_name), "w") as json_file: - json_file.write(model_json) - - # serialize model to YAML - model_yaml = model.to_yaml() - with open("{}/{}.model.yaml".format(output_dir, model_name), "w") as yaml_file: - yaml_file.write(model_yaml) - - # serialize weights to HDF5 - model.save_weights("{}/{}.weights.h5".format(output_dir, model_name)) - print("Saved model to disk") - - # load json and create model - json_file = open('{}/{}.model.json'.format(output_dir, model_name), 'r') - loaded_model_json = json_file.read() - json_file.close() - loaded_model_json = model_from_json(loaded_model_json) - - # load yaml and create model - yaml_file = open('{}/{}.model.yaml'.format(output_dir, model_name), 'r') - loaded_model_yaml = yaml_file.read() - yaml_file.close() - loaded_model_yaml = model_from_yaml(loaded_model_yaml) - - # load weights into new model - loaded_model_json.load_weights('{}/{}.weights.h5'.format(output_dir, model_name)) - print("Loaded json model from disk") - - # evaluate json loaded model on test data - loaded_model_json.compile(loss=gParameters['loss'], - optimizer=gParameters['optimizer'], - metrics=[gParameters['metrics']]) - score_json = loaded_model_json.evaluate(X_test, Y_test, verbose=0) - - print('json Test score:', score_json[0]) - print('json Test accuracy:', score_json[1]) - - print("json %s: %.2f%%" % (loaded_model_json.metrics_names[1], score_json[1] * 100)) - - # load weights into new model - loaded_model_yaml.load_weights('{}/{}.weights.h5'.format(output_dir, model_name)) - print("Loaded yaml model from disk") - - # evaluate loaded model on test 
data - loaded_model_yaml.compile(loss=gParameters['loss'], - optimizer=gParameters['optimizer'], - metrics=[gParameters['metrics']]) - score_yaml = loaded_model_yaml.evaluate(X_test, Y_test, verbose=0) - - print('yaml Test score:', score_yaml[0]) - print('yaml Test accuracy:', score_yaml[1]) - - print("yaml %s: %.2f%%" % (loaded_model_yaml.metrics_names[1], score_yaml[1] * 100)) - - return history - - -def main(): - gParameters = initialize_parameters() - run(gParameters) - - -if __name__ == '__main__': - main() - try: - K.clear_session() - except AttributeError: # theano does not have this function - pass diff --git a/Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py b/Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py deleted file mode 100644 index 8d1227f5..00000000 --- a/Pilot1/NT3/nt3_cf/abstention/nt3_baseline_keras2.py +++ /dev/null @@ -1,318 +0,0 @@ -from __future__ import print_function - -import pandas as pd -import numpy as np -import os -import pickle - -from tensorflow.keras import backend as K - -from tensorflow.keras.layers import Dense, Dropout, Activation, Conv1D, MaxPooling1D, Flatten, LocallyConnected1D -from tensorflow.keras.models import Sequential, model_from_json, model_from_yaml -from tensorflow.keras.utils import to_categorical -from tensorflow.keras.callbacks import CSVLogger, ReduceLROnPlateau - -from sklearn.preprocessing import MaxAbsScaler - -import nt3 as bmk -import candle - - -def initialize_parameters(default_model='nt3_default_model.txt'): - - # Build benchmark object - nt3Bmk = bmk.BenchmarkNT3( - bmk.file_path, - default_model, - 'keras', - prog='nt3_baseline', - desc='1D CNN to classify RNA sequence data in normal or tumor classes') - - # Initialize parameters - gParameters = candle.finalize_parameters(nt3Bmk) - - return gParameters - -def load_data_pickle(path, gParameters): - # Rewrite this function to handle pickle files instead - print("Loading data...") - data = pickle.load(open(path, 'rb')) - X=data[0] - y=data[1] - 
polluted_inds = data[2] - cluster_inds = data[3] - size = X.shape[0] - X_train = X[0:(int)(0.8*size)] - X_test = X[(int)(0.8*size):] - Y_train = y[0:(int)(0.8*size)] - Y_test = y[(int)(0.8*size):] - #df_train = (pd.read_csv(train_path, header=None).values).astype('float32') - #df_test = (pd.read_csv(test_path, header=None).values).astype('float32') - #X_train,Y_train, X_test, Y_test = data - print('done') - return X_train, Y_train, X_test, Y_test - -def load_data(train_path, test_path, gParameters): - - print('Loading data...') - df_train = (pd.read_csv(train_path, header=None).values).astype('float32') - df_test = (pd.read_csv(test_path, header=None).values).astype('float32') - print('done') - - print('df_train shape:', df_train.shape) - print('df_test shape:', df_test.shape) - - seqlen = df_train.shape[1] - - df_y_train = df_train[:, 0].astype('int') - df_y_test = df_test[:, 0].astype('int') - - # only training set has noise - Y_train = to_categorical(df_y_train, gParameters['classes']) - Y_test = to_categorical(df_y_test, gParameters['classes']) - - df_x_train = df_train[:, 1:seqlen].astype(np.float32) - df_x_test = df_test[:, 1:seqlen].astype(np.float32) - - X_train = df_x_train - X_test = df_x_test - - scaler = MaxAbsScaler() - mat = np.concatenate((X_train, X_test), axis=0) - mat = scaler.fit_transform(mat) - - X_train = mat[:X_train.shape[0], :] - X_test = mat[X_train.shape[0]:, :] - - # TODO: Add better names for noise boolean, make a featue for both RNA seq and label noise together - # check if noise is on (this is for label) - if gParameters['add_noise']: - # check if we want noise correlated with a feature - if gParameters['noise_correlated']: - Y_train, y_train_noise_gen = candle.label_flip_correlated(Y_train, - gParameters['label_noise'], X_train, - gParameters['feature_col'], - gParameters['feature_threshold']) - # else add uncorrelated noise - else: - Y_train, y_train_noise_gen = candle.label_flip(Y_train, gParameters['label_noise']) - # check if 
noise is on for RNA-seq data - elif gParameters['noise_gaussian']: - X_train = candle.add_gaussian_noise(X_train, 0, gParameters['std_dev']) - - return X_train, Y_train, X_test, Y_test - - -def run(gParameters): - - file_train = gParameters['train_data'] - file_test = gParameters['test_data'] - url = gParameters['data_url'] - - #train_file = candle.get_file(file_train, url + file_train, cache_subdir='Pilot1') - #test_file = candle.get_file(file_test, url + file_test, cache_subdir='Pilot1') - - model = Sequential() - - initial_epoch = 0 - best_metric_last = None - - #X_train, Y_train, X_test, Y_test = load_data(train_file, test_file, gParameters) - X_train, Y_train, X_test, Y_test = load_data_pickle(file_train, gParameters) - - print('X_train shape:', X_train.shape) - print('X_test shape:', X_test.shape) - - print('Y_train shape:', Y_train.shape) - print('Y_test shape:', Y_test.shape) - - x_train_len = X_train.shape[1] - - # this reshaping is critical for the Conv1D to work - - X_train = np.expand_dims(X_train, axis=2) - X_test = np.expand_dims(X_test, axis=2) - - print('X_train shape:', X_train.shape) - print('X_test shape:', X_test.shape) - - layer_list = list(range(0, len(gParameters['conv']), 3)) - for _, i in enumerate(layer_list): - filters = gParameters['conv'][i] - filter_len = gParameters['conv'][i + 1] - stride = gParameters['conv'][i + 2] - print(int(i / 3), filters, filter_len, stride) - if gParameters['pool']: - pool_list = gParameters['pool'] - if type(pool_list) != list: - pool_list = list(pool_list) - - if filters <= 0 or filter_len <= 0 or stride <= 0: - break - if 'locally_connected' in gParameters: - model.add(LocallyConnected1D(filters, filter_len, strides=stride, padding='valid', input_shape=(x_train_len, 1))) - else: - # input layer - if i == 0: - model.add(Conv1D(filters=filters, kernel_size=filter_len, strides=stride, padding='valid', input_shape=(x_train_len, 1))) - else: - model.add(Conv1D(filters=filters, kernel_size=filter_len, 
strides=stride, padding='valid')) - model.add(Activation(gParameters['activation'])) - if gParameters['pool']: - model.add(MaxPooling1D(pool_size=pool_list[int(i / 3)])) - - model.add(Flatten()) - - for layer in gParameters['dense']: - if layer: - model.add(Dense(layer)) - model.add(Activation(gParameters['activation'])) - if gParameters['dropout']: - model.add(Dropout(gParameters['dropout'])) - model.add(Dense(gParameters['classes'])) - model.add(Activation(gParameters['out_activation'])) - - J = candle.restart(gParameters, model) - if J is not None: - initial_epoch = J['epoch'] - best_metric_last = J['best_metric_last'] - gParameters['ckpt_best_metric_last'] = best_metric_last - print('initial_epoch: %i' % initial_epoch) - - ckpt = candle.CandleCheckpointCallback(gParameters, - verbose=False) - -# Reference case -# model.add(Conv1D(filters=128, kernel_size=20, strides=1, padding='valid', input_shape=(P, 1))) -# model.add(Activation('relu')) -# model.add(MaxPooling1D(pool_size=1)) -# model.add(Conv1D(filters=128, kernel_size=10, strides=1, padding='valid')) -# model.add(Activation('relu')) -# model.add(MaxPooling1D(pool_size=10)) -# model.add(Flatten()) -# model.add(Dense(200)) -# model.add(Activation('relu')) -# model.add(Dropout(0.1)) -# model.add(Dense(20)) -# model.add(Activation('relu')) -# model.add(Dropout(0.1)) -# model.add(Dense(CLASSES)) -# model.add(Activation('softmax')) - - kerasDefaults = candle.keras_default_config() - - # Define optimizer - optimizer = candle.build_optimizer(gParameters['optimizer'], - gParameters['learning_rate'], - kerasDefaults) - - model.summary() - model.compile(loss=gParameters['loss'], - optimizer=optimizer, - metrics=[gParameters['metrics']]) - - output_dir = gParameters['output_dir'] - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - # calculate trainable and non-trainable params - gParameters.update(candle.compute_trainable_params(model)) - - # set up a bunch of callbacks to do work during model 
training.. - model_name = gParameters['model_name'] - # path = '{}/{}.autosave.model.h5'.format(output_dir, model_name) - # checkpointer = ModelCheckpoint(filepath=path, verbose=1, save_weights_only=False, save_best_only=True) - csv_logger = CSVLogger('{}/training.log'.format(output_dir)) - reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=10, verbose=1, mode='auto', epsilon=0.0001, cooldown=0, min_lr=0) - candleRemoteMonitor = candle.CandleRemoteMonitor(params=gParameters) - timeoutMonitor = candle.TerminateOnTimeOut(gParameters['timeout']) - - history = model.fit(X_train, Y_train, - batch_size=gParameters['batch_size'], - epochs=gParameters['epochs'], - initial_epoch=initial_epoch, - verbose=1, - validation_data=(X_test, Y_test), - callbacks=[csv_logger, reduce_lr, candleRemoteMonitor, timeoutMonitor, - ckpt]) - - score = model.evaluate(X_test, Y_test, verbose=0) - - if False: - print('Test score:', score[0]) - print('Test accuracy:', score[1]) - # serialize model to JSON - model_json = model.to_json() - with open("{}/{}.model.json".format(output_dir, model_name), "w") as json_file: - json_file.write(model_json) - - # serialize model to YAML - model_yaml = model.to_yaml() - with open("{}/{}.model.yaml".format(output_dir, model_name), "w") as yaml_file: - yaml_file.write(model_yaml) - - # serialize weights to HDF5 - model.save_weights("{}/{}.weights.h5".format(output_dir, model_name)) - print("Saved model to disk") - - # load json and create model - json_file = open('{}/{}.model.json'.format(output_dir, model_name), 'r') - loaded_model_json = json_file.read() - json_file.close() - loaded_model_json = model_from_json(loaded_model_json) - - # load yaml and create model - yaml_file = open('{}/{}.model.yaml'.format(output_dir, model_name), 'r') - loaded_model_yaml = yaml_file.read() - yaml_file.close() - loaded_model_yaml = model_from_yaml(loaded_model_yaml) - - # load weights into new model - 
loaded_model_json.load_weights('{}/{}.weights.h5'.format(output_dir, model_name)) - print("Loaded json model from disk") - - # evaluate json loaded model on test data - loaded_model_json.compile(loss=gParameters['loss'], - optimizer=gParameters['optimizer'], - metrics=[gParameters['metrics']]) - score_json = loaded_model_json.evaluate(X_test, Y_test, verbose=0) - - print('json Test score:', score_json[0]) - print('json Test accuracy:', score_json[1]) - - print("json %s: %.2f%%" % (loaded_model_json.metrics_names[1], score_json[1] * 100)) - - # load weights into new model - loaded_model_yaml.load_weights('{}/{}.weights.h5'.format(output_dir, model_name)) - print("Loaded yaml model from disk") - - # evaluate loaded model on test data - loaded_model_yaml.compile(loss=gParameters['loss'], - optimizer=gParameters['optimizer'], - metrics=[gParameters['metrics']]) - score_yaml = loaded_model_yaml.evaluate(X_test, Y_test, verbose=0) - - print('yaml Test score:', score_yaml[0]) - print('yaml Test accuracy:', score_yaml[1]) - - print("yaml %s: %.2f%%" % (loaded_model_yaml.metrics_names[1], score_yaml[1] * 100)) - - model.save(path) - path = '{}/{}.autosave.data.h5'.format(output_dir, model_name) - pickle.dump( [X_train, Y_train, X_test, Y_test], open( path, "wb" ) ) - print(path) - return history - - -def main(): - gParameters = initialize_parameters() - run(gParameters) - - -if __name__ == '__main__': - main() - try: - K.clear_session() - except AttributeError: # theano does not have this function - pass diff --git a/Pilot1/NT3/nt3_cf/analyze.ipynb b/Pilot1/NT3/nt3_cf/analyze.ipynb deleted file mode 100644 index 99d570c1..00000000 --- a/Pilot1/NT3/nt3_cf/analyze.ipynb +++ /dev/null @@ -1,655 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [], - "source": [ - "import numpy as np\n", - "import pickle\n", - "import os\n", - "import matplotlib.pyplot as plt\n", - "from sklearn.cluster import KMeans\n", - "from 
sklearn.decomposition import PCA\n", - "from sklearn.manifold import TSNE" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "parent = '/Users/shah38/Desktop/xai-geom/nt3/'\n", - "# directory = parent + 'pickle_summit_35'\n", - "# counterfactuals = []\n", - "# count = 0\n", - "# for filename in os.listdir(directory):\n", - "# if filename.startswith(\"save\"):\n", - "# count+=1\n", - "# d = pickle.load(open(os.path.join(directory, filename), 'rb'))\n", - "# counterfactuals.append(d)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "test = [item for sublist in counterfactuals for item in sublist]" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1160\n" - ] - } - ], - "source": [ - "print(len(test))" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "29\n" - ] - } - ], - "source": [ - "print(count)" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "ename": "FileNotFoundError", - "evalue": "[Errno 2] No such file or directory: 'nt3.autosave.data.pkl'", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mFileNotFoundError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mdata\u001B[0m \u001B[0;34m=\u001B[0m 
\u001B[0mpickle\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mopen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m\"nt3.autosave.data.pkl\"\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;34m'rb'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0mprint\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mdata\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;36m0\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mshape\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mFileNotFoundError\u001B[0m: [Errno 2] No such file or directory: 'nt3.autosave.data.pkl'" - ] - } - ], - "source": [ - "data = pickle.load(open(\"nt3.autosave.data.pkl\",'rb'))\n", - "print(data[0].shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "pickle.dump(test, open(\"complete_save.pkl\", 'wb'))" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [], - "source": [ - "# Plot # of positive/negative indices per threshold value\n", - "# Pick a threshold\n", - "# Find genes that overlap the most\n", - "num_pos = []\n", - "num_neg = []\n", - "#threshold_values = [0.0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1.0]\n", - "threshold_values = [0.9]#, 0.925, 0.95, 0.975, 1.0]\n", - "for t in threshold_values:\n", - " thresholds = pickle.load(open('{}threshold_{}.pkl'.format(parent,t), 'rb'))\n", - " pos = thresholds['positive threshold indices']\n", - " num_pos.append([pos[i][0].shape[0] for i in range(len(pos))])\n", - " neg = thresholds['negative threshold indices']\n", - " num_neg.append([neg[i][0].shape[0] for i in range(len(neg))])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/png": "\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots()\n", - "pos = np.arange(len(num_pos)) + 1\n", - "total = [np.array(num_pos[i]) + np.array(num_neg[i]) for i in range(len(num_pos)) ]\n", - "bp = ax.boxplot(total, sym='k+', positions=pos)\n", - "\n", - "ax.set_xlabel('threshold value')\n", - "ax.set_ylabel('# indices')\n", - "#ax.set_xticks(np.arange(0,1.1,0.1))\n", - "plt.setp(bp['whiskers'], color='k', linestyle='-')\n", - "plt.setp(bp['fliers'], markersize=3.0)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots()\n", - "pos = np.arange(len(num_pos)) + 1\n", - "bp = ax.boxplot(num_pos, sym='k+', positions=pos)\n", - "\n", - "ax.set_xlabel('threshold value')\n", - "ax.set_xticklabels([0.9,0.925,0.95,0.975,1.0])\n", - "ax.set_ylabel('# indices')\n", - "plt.setp(bp['whiskers'], color='k', linestyle='-')\n", - "plt.setp(bp['fliers'], markersize=3.0)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOnElEQVR4nO3dXYyc1X3H8e+vONAqiWIDW8uynS5pLFXcBKwVdZUoakEh4FQ1lRJEVRWLWvINkYjSqnWai6ZSL6BSQ4tUIbkF1URpCMqLsAJt4hqiqBcQlgTMWykLBWHLYCcQkihKWpJ/L+a4mZhd7+w7e/b7kVZznnPOzJz/PuOfn33mLVWFJKkvv7TSC5AkLT7DXZI6ZLhLUocMd0nqkOEuSR1at9ILADj//PNrfHx8pZchSavKww8//J2qGptu7E0R7uPj40xOTq70MiRpVUnywkxjnpaRpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOrfpwH993D+P77lnpZUjSm8qqD3dJ0hsZ7pLUoZHCPcnzSR5L8kiSydZ3bpJDSZ5plxtaf5LckmQqyZEk25eyAEnSG83lyP13quqiqppo2/uAw1W1DTjctgGuBLa1n73ArYu1WEnSaBZyWmYXcKC1DwBXDfXfUQMPAOuTbFrA/UiS5mjUcC/ga0keTrK39W2squOt/RKwsbU3Ay8OXfdo65MkLZNRv6zjfVV1LMmvAoeS/OfwYFVVkprLHbf/JPYCvPOd75zLVSVJsxjpyL2qjrXLE8CXgUuAl0+dbmmXJ9r0Y8DWoatvaX2n3+b+qpqoqomxsWm/JUqSNE+zhnuStyZ5+6k2cDnwOHAQ2N2m7Qbubu2DwLXtVTM7gNeGTt9IkpbBKKdlNgJfTnJq/r9U1b8leQi4K8ke4AXg6jb/XmAnMAX8CLhu0VctSTqjWcO9qp4D3jNN/3eBy6bpL+D6RVmdJGlefIeqJHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdGjnck5yV5NtJvtK2L0jyYJKpJJ9PcnbrP6dtT7Xx8aVZuiRpJnM5cr8BeGpo+ybg5qp6N/AqsKf17wFebf03t3mSpGU0Urgn2QJ8CPinth3gUuALbcoB4KrW3tW2aeOXtfmSpGUy6pH73wF/BvysbZ8HfK+qXm/bR4HNrb0ZeBGgjb/W5kuSlsms4Z7kd4ETVfXwYt5xkr1JJpNMnjx5cjFvWpLWvFGO3N8L/F6S54E7GZyO+XtgfZJ1bc4W4FhrHwO2ArTxdwDfPf1Gq2p/VU1U1cTY2NiCipAk/aJZw72qPlFVW6pqHLgGuK+q/hC4H/hwm7YbuLu1D7Zt2vh9VVWLumpJ0hkt5HXufw58PMkUg3Pqt7X+24DzWv/HgX0LW6Ikaa7WzT7l56rq68DXW/s54JJp5vwY+MgirE2SNE++Q1WSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA7NGu5JfjnJN5M8muSJJH/V+i9I8mCSqSSfT3J26z+nbU+18fGlLUGSdLpRjtx/AlxaVe8BLgKuSLIDuAm4uareDbwK7Gnz9wCvtv6b2zxJ0jKaNdxr4Idt8y3tp4BLgS+0/gPAVa29q23Txi9LkkVbsSRpViOdc09yVpJHg
BPAIeBZ4HtV9XqbchTY3NqbgRcB2vhrwHmLuWhJ0pmNFO5V9dOqugjYAlwC/MZC7zjJ3iSTSSZPnjy50JuTJA2Z06tlqup7wP3AbwHrk6xrQ1uAY619DNgK0MbfAXx3mtvaX1UTVTUxNjY2z+VLkqYzyqtlxpKsb+1fAT4APMUg5D/cpu0G7m7tg22bNn5fVdViLlqSdGbrZp/CJuBAkrMY/GdwV1V9JcmTwJ1J/hr4NnBbm38b8JkkU8ArwDVLsG5J0hnMGu5VdQS4eJr+5xicfz+9/8fARxZldZKkefEdqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHeom3Mf33bPSS5CkN41uwl2S9HOGuyR1aNZwT7I1yf1JnkzyRJIbWv+5SQ4leaZdbmj9SXJLkqkkR5JsX+oiJEm/aJQj99eBP6mqC4EdwPVJLgT2AYerahtwuG0DXAlsaz97gVsXfdWSpDOaNdyr6nhVfau1fwA8BWwGdgEH2rQDwFWtvQu4owYeANYn2bToK5ckzWhO59yTjAMXAw8CG6vqeBt6CdjY2puBF4eudrT1nX5be5NMJpk8efLkHJctSTqTkcM9yduALwIfq6rvD49VVQE1lzuuqv1VNVFVE2NjY3O5qiRpFiOFe5K3MAj2z1bVl1r3y6dOt7TLE63/GLB16OpbWp8kaZmM8mqZALcBT1XVp4eGDgK7W3s3cPdQ/7XtVTM7gNeGTt9IkpbBuhHmvBf4I+CxJI+0vr8AbgTuSrIHeAG4uo3dC+wEpoAfAdct6oolSbOaNdyr6j+AzDB82TTzC7h+geuSJC2A71CVpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkd6ircx/fds9JLkKQ3ha7CXZI0sG6lF7AUznQE//yNH1rGlUjSyvDIXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVo1nBPcnuSE0keH+o7N8mhJM+0yw2tP0luSTKV5EiS7Uu5eEnS9EY5cv9n4IrT+vYBh6tqG3C4bQNcCWxrP3uBWxdnmZKkuZg13KvqG8Arp3XvAg609gHgqqH+O2rgAWB9kk2LtVhJ0mjme859Y1Udb+2XgI2tvRl4cWje0db3Bkn2JplMMnny5Ml5LkOSNJ0FP6FaVQXUPK63v6omqmpibGxsocuQJA2Zb7i/fOp0S7s80fqPAVuH5m1pfZKkZTTfcD8I7G7t3cDdQ/3XtlfN7ABeGzp9I0laJrN+KmSSzwG/DZyf5Cjwl8CNwF1J9gAvAFe36fcCO4Ep4EfAdUuwZknSLGYN96r6gxmGLptmbgHXL3RRkqSF8R2qktQhw12SOmS4S1KHDHdJ6pDhLkkd6i7cz/Tl2JK0VnQX7pIkw12SumS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA7N+pG/venhTU7P3/ihlV6CpDc5j9wlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR1aknBPckWSp5NMJdm3FPchSZrZon8TU5KzgH8APgAcBR5KcrCqnlzs+1qrevg2qblai98+5X7WQizF1+xdAkxV1XMASe4EdgGGu+ZtLQbdWrQW9/NS/Ye2FOG+GXhxaPso8JunT0qyF9jbNn+Y5Ol53t/5wHfmed3VyprXBmteA3LTgmr+tZkGVuwLsqtqP7B/obeTZLKqJhZhSauGNa8N1rw2LFXNS
/GE6jFg69D2ltYnSVomSxHuDwHbklyQ5GzgGuDgEtyPJGkGi35apqpeT/JR4KvAWcDtVfXEYt/PkAWf2lmFrHltsOa1YUlqTlUtxe1KklaQ71CVpA4Z7pLUoVUd7j1/zEGS55M8luSRJJOt79wkh5I80y43tP4kuaX9Ho4k2b6yqx9NktuTnEjy+FDfnGtMsrvNfybJ7pWoZRQz1PupJMfafn4kyc6hsU+0ep9O8sGh/lXzuE+yNcn9SZ5M8kSSG1p/z/t5ppqXd19X1ar8YfBk7bPAu4CzgUeBC1d6XYtY3/PA+af1/Q2wr7X3ATe19k7gX4EAO4AHV3r9I9b4fmA78Ph8awTOBZ5rlxtae8NK1zaHej8F/Ok0cy9sj+lzgAvaY/2s1fa4BzYB21v77cB/tdp63s8z1bys+3o1H7n//8ccVNX/AKc+5qBnu4ADrX0AuGqo/44aeABYn2TTSixwLqrqG8Arp3XPtcYPAoeq6pWqehU4BFyx9Kufuxnqncku4M6q+klV/TcwxeAxv6oe91V1vKq+1do/AJ5i8C72nvfzTDXPZEn29WoO9+k+5uBMv8DVpoCvJXm4fVQDwMaqOt7aLwEbW7un38Vca+yh9o+2UxC3nzo9QYf1JhkHLgYeZI3s59NqhmXc16s53Hv3vqraDlwJXJ/k/cODNfh7ruvXsa6FGoFbgV8HLgKOA3+7sstZGkneBnwR+FhVfX94rNf9PE3Ny7qvV3O4d/0xB1V1rF2eAL7M4E+0l0+dbmmXJ9r0nn4Xc61xVddeVS9X1U+r6mfAPzLYz9BRvUnewiDkPltVX2rdXe/n6Wpe7n29msO92485SPLWJG8/1QYuBx5nUN+pVwnsBu5u7YPAte2VBjuA14b+5F1t5lrjV4HLk2xof+Ze3vpWhdOeG/l9BvsZBvVek+ScJBcA24Bvssoe90kC3AY8VVWfHhrqdj/PVPOy7+uVfmZ5gc9K72TwTPSzwCdXej2LWNe7GDwz/ijwxKnagPOAw8AzwL8D57b+MPiClGeBx4CJla5hxDo/x+DP0/9lcD5xz3xqBP6YwZNQU8B1K13XHOv9TKvnSPuHu2lo/idbvU8DVw71r5rHPfA+BqdcjgCPtJ+dne/nmWpe1n3txw9IUodW82kZSdIMDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUof8Dhz0GWCovLmwAAAAASUVORK5CYII=\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.hist(num_pos[0], bins=[0,10,20,30,40,100,500,1000,1500,2000,2500])\n", - "num_pos_9 = np.array(num_pos[0])\n", - "indices_pos_9 = np.argwhere((num_pos_9 <= 20) & (num_pos_9 > 10))" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXcAAAD4CAYAAAAXUaZHAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAR50lEQVR4nO3dbYxcV33H8e+vTghVoU1CtpZrW90Arqq0Uk20TVOBKpoISExVBwlQUFUsGsmtFCRQn3Doi1KpkUJVSIvURjJNiqkoIeJBsUgomBCEeEHChhoTJ02zgFFsmXhLQgChpk3498Ucl8Hsw+zOPrDH3480mnPPOXfmHN/1b2fP3JmbqkKS1JefWu8BSJJWnuEuSR0y3CWpQ4a7JHXIcJekDp2z3gMAuOiii2pycnK9hyFJG8oDDzzwX1U1MVfbT0S4T05OMj09vd7DkKQNJck35mtzWUaSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjq04cN9ct9dTO67a72HIUk/UTZ8uEuSfpzhLkkdMtwlqUMjh3uSTUn+PcnH2/bFSe5LMpPkQ0me0+rPa9szrX1ydYYuSZrPUl65vwV4eGj7ncDNVfVi4EngulZ/HfBkq7+59ZMkraGRwj3JNuDVwD+17QBXAB9uXQ4A17Ty7rZNa7+y9ZckrZFRX7n/HfDnwA/a9guAb1fVM237OLC1lbcCjwG09qda/x+RZG+S6STTs7Ozyxy+JGkui4Z7kt8BTlXVAyv5xFW1v6qmqmpqYmLOq0RJkpZplMvsvRT43SS7gOcCPwv8PXB+knPaq/NtwInW/wSwHTie5Bzg54BvrfjIJUnzWvSVe1XdUFXbqmoSuBb4TFX9HnAv8NrWbQ9wZysfbNu09s9UVa3oqCVJCxrnPPe3AX+cZIbBmvqtrf5W4AWt/o+BfeMNUZK0VKMsy/y/qvos8NlW/hpw2Rx9/ht43QqMTZK0TH5CVZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUoVEukP3cJPcn+XKSo0n+qtW/L8nXkxxut52tPknek2QmyZEkl672JCRJP2qUKzE9DVxRVd9Lci7w+SSfaG1/VlUfPqP/1cCOdvsN4JZ2L0laI6NcILuq6ntt89x2W+iC17uB97f9vgCcn2TL+EOVJI1qpDX3JJuSHAZOAYeq6r7WdGNberk5yXmtbivw2NDux1vdmY+5N8l0kunZ2dkxpiBJOtNI4V5Vz1bVTmAbcFmSXwVuAH4Z+HXgQuBtS3niqtpfVVNVNTUxMbHEYUuSFrKks2Wq6tvAvcBVVXWyLb08DfwzcFnrdgLYPrTbtlYnSVojo5wtM5Hk/Fb+aeAVwH+cXkdPEuAa4MG2y0Hgje2smcuBp6rq5KqMXpI0p1HOltkCHEiyicEvgzuq6uNJPpNkAghwGPij1v9uYBcwA3wfeNPKD1uStJBFw72qjgAvmaP+inn6F3D9+EOTJC2Xn1CVpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVolMvsPTfJ/Um+nORokr9q9RcnuS/JTJIPJXlOqz+vbc+09snVnYIk6UyjvHJ/Griiqn4N2Alc1a6N+k7g5qp6MfAkcF3rfx3wZKu/ufWTJK2hRcO9Br7XNs9ttwKuAD7c6g8wuEg2wO62TWu/sl1EW5K0RkZac0+yKclh4BRwCPgq8O2qeqZ1OQ5sbeWtwGMArf0p4AVzPObeJNNJpmdnZ8ebhSTpR4wU7lX1bFXtBLYBlwG/PO4TV9X+qpqqqqmJiYlxH06SNGRJZ8tU1beBe4HfBM5Pck5r2gacaOUTwHaA1v5zw
LdWZLSSpJGMcrbMRJLzW/mngVcADzMI+de2bnuAO1v5YNumtX+mqmolBy1JWtg5i3dhC3AgySYGvwzuqKqPJ3kIuD3JXwP/Dtza+t8K/EuSGeAJ4NpVGLckaQGLhntVHQFeMkf91xisv59Z/9/A61ZkdJKkZfETqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDo1ymb3tSe5N8lCSo0ne0urfkeREksPttmtonxuSzCR5JMmrVnMCkqQfN8pl9p4B/qSqvpTk+cADSQ61tpur6m+HOye5hMGl9X4F+AXg00l+qaqeXcmBS5Lmt+gr96o6WVVfauXvMrg49tYFdtkN3F5VT1fV14EZ5rgcnyRp9SxpzT3JJIPrqd7Xqt6c5EiS25Jc0Oq2Ao8N7XacOX4ZJNmbZDrJ9Ozs7JIHfqbJfXeN/RiS1IuRwz3J84CPAG+tqu8AtwAvAnYCJ4F3LeWJq2p/VU1V1dTExMRSdpUkLWKkcE9yLoNg/0BVfRSgqh6vqmer6gfAe/nh0ssJYPvQ7ttanSRpjYxytkyAW4GHq+rdQ/Vbhrq9BniwlQ8C1yY5L8nFwA7g/pUbsiRpMaOcLfNS4PeBryQ53OreDrwhyU6ggGPAHwJU1dEkdwAPMTjT5nrPlJGktbVouFfV54HM0XT3AvvcCNw4xrgkSWPwE6qS1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA6Ncpm97UnuTfJQkqNJ3tLqL0xyKMmj7f6CVp8k70kyk+RIkktXexKSpB81yiv3Z4A/qapLgMuB65NcAuwD7qmqHcA9bRvgagbXTd0B7AVuWfFRS5IWtGi4V9XJqvpSK38XeBjYCuwGDrRuB4BrWnk38P4a+AJw/hkX05YkrbIlrbknmQReAtwHbK6qk63pm8DmVt4KPDa02/FWd+Zj7U0ynWR6dnZ2icOWJC1k5HBP8jzgI8Bbq+o7w21VVUAt5Ymran9VTVXV1MTExFJ2lSQtYqRwT3Iug2D/QFV9tFU/fnq5pd2favUngO1Du29rdZKkNTLK2TIBbgUerqp3DzUdBPa08h7gzqH6N7azZi4HnhpavpEkrYFzRujzUuD3ga8kOdzq3g7cBNyR5DrgG8DrW9vdwC5gBvg+8KYVHbEkaVGLhntVfR7IPM1XztG/gOvHHJckaQx+QlWSOmS4S1KHDHdJ6pDhLkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1KFRLrN3W5JTSR4cqntHkhNJDrfbrqG2G5LMJHkkyatWa+CSpPmN8sr9fcBVc9TfXFU72+1ugCSXANcCv9L2+cckm1ZqsJKk0Swa7lX1OeCJER9vN3B7VT1dVV9ncB3Vy8YYnyRpGcZZc39zkiNt2eaCVrcVeGyoz/FW92OS7E0ynWR6dnZ2jGH80OS+u1bkcSRpo1tuuN8CvAjYCZwE3rXUB6iq/VU1VVVTExMTyxyGJGkuywr3qnq8qp6tqh8A7+WHSy8ngO1DXbe1OknSGlpWuCfZMrT5GuD0mTQHgWuTnJfkYmAHcP94Q5QkLdU5i3VI8kHg5cBFSY4Dfwm8PMlOoIBjwB8CVNXRJHcADwHPANdX1bOrM/T5Lbb2fuymV6/RSCRpfSwa7lX1hjmqb12g/43AjeMMSpI0Hj+hKkkdMtwlqUOGuyR1yHCXpA4Z7pLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nq0KLhnuS2JKeSPDhUd2GSQ0kebfcXtPokeU+SmSRHkly6moOXJM1tlFfu7wOuOqNuH3BPVe0A7mnbAFczuG7qD
mAvcMvKDFOStBSLhntVfQ544ozq3cCBVj4AXDNU//4a+AJw/hkX05YkrYHlrrlvrqqTrfxNYHMrbwUeG+p3vNX9mCR7k0wnmZ6dnV3mMCRJcxn7DdWqKqCWsd/+qpqqqqmJiYlxhyFJGrLccH/89HJLuz/V6k8A24f6bWt1a2Zy311r+XSS9BNpueF+ENjTynuAO4fq39jOmrkceGpo+UaStEbOWaxDkg8CLwcuSnIc+EvgJuCOJNcB3wBe37rfDewCZoDvA29ahTFLkhaxaLhX1Rvmabpyjr4FXD/uoCRJ4/ETqpLUIcNdkjpkuEtShwx3SeqQ4S5JHTLcJalDhrskdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHVo0e9z79FGvxTfsZtevd5DkPQTbqxwT3IM+C7wLPBMVU0luRD4EDAJHANeX1VPjjdMSdJSrMSyzG9X1c6qmmrb+4B7qmoHcE/bliStodVYc98NHGjlA8A1q/AckqQFjBvuBXwqyQNJ9ra6zVV1spW/CWyea8cke5NMJ5menZ0dcxiSpGHjvqH6sqo6keTngUNJ/mO4saoqSc21Y1XtB/YDTE1NzdlHkrQ8Y71yr6oT7f4U8DHgMuDxJFsA2v2pcQcpSVqaZYd7kp9J8vzTZeCVwIPAQWBP67YHuHPcQUqSlmacZZnNwMeSnH6cf62qf0vyReCOJNcB3wBeP/4wJUlLsexwr6qvAb82R/23gCvHGZQkaTx+/YAkdchwl6QOGe6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ4a7JHXIcJekDhnuktQhw12SOmS4S1KHDHdJ6pDhLkkdMtwlqUPjXiB7XkmuAv4e2AT8U1XdtFrPdbaZ3HfXeg9hzR276dXrPQRpQ1mVcE+yCfgH4BXAceCLSQ5W1UOr8Xzq39n4C+1s5C/xlbNar9wvA2bapfhIcjuwGzDcJc3rbPwlvlq/0FYr3LcCjw1tHwd+Y7hDkr3A3rb5vSSPLPO5LgL+a5n7blTO+ezgnM8CeedYc/7F+RpWbc19MVW1H9g/7uMkma6qqRUY0obhnM8OzvnssFpzXq2zZU4A24e2t7U6SdIaWK1w/yKwI8nFSZ4DXAscXKXnkiSdYVWWZarqmSRvBj7J4FTI26rq6Go8FyuwtLMBOeezg3M+O6zKnFNVq/G4kqR15CdUJalDhrskdWhDh3uSq5I8kmQmyb71Hs9KSnIsyVeSHE4y3eouTHIoyaPt/oJWnyTvaf8OR5Jcur6jH02S25KcSvLgUN2S55hkT+v/aJI96zGXUc0z53ckOdGO9eEku4babmhzfiTJq4bqN8TPfpLtSe5N8lCSo0ne0uq7Pc4LzHltj3NVbcgbgzdqvwq8EHgO8GXgkvUe1wrO7xhw0Rl1fwPsa+V9wDtbeRfwCSDA5cB96z3+Eef4W8ClwIPLnSNwIfC1dn9BK1+w3nNb4pzfAfzpHH0vaT/X5wEXt5/3TRvpZx/YAlzays8H/rPNq9vjvMCc1/Q4b+RX7v//FQdV9T/A6a846Nlu4EArHwCuGap/fw18ATg/yZb1GOBSVNXngCfOqF7qHF8FHKqqJ6rqSeAQcNXqj3555pnzfHYDt1fV01X1dWCGwc/9hvnZr6qTVfWlVv4u8DCDT7B3e5wXmPN8VuU4b+Rwn+srDhb6B9xoCvhUkgfaVzUAbK6qk638TWBzK/f0b7HUOfYy9ze3ZYjbTi9R0Nmck0wCLwHu4yw5zmfMGdbwOG/kcO/dy6rqUuBq4PokvzXcWIO/57o+j/VsmGNzC/AiYCdwEnjX+g5n5SV5HvAR4K1V9Z3htl6P8xxzXtPjvJHDveuvOKiqE+3+FPAxBn+iPX56uaXdn2rde/q3WOocN/zcq+rxqnq2qn4AvJfBsYZO5pzkXAYh94Gq+mir7vo4zzXntT7OGzncu/2KgyQ/k+T5p8vAK4EHGczv9FkCe4A7W/kg8MZ2p
sHlwFNDf/JuNEud4yeBVya5oP2Z+8pWt2Gc8f7IaxgcaxjM+dok5yW5GNgB3M8G+tlPEuBW4OGqevdQU7fHeb45r/lxXu93lsd8V3oXg3eivwr8xXqPZwXn9UIG74x/GTh6em7AC4B7gEeBTwMXtvowuDjKV4GvAFPrPYcR5/lBBn+e/i+D9cTrljNH4A8YvAk1A7xpvee1jDn/S5vTkfafd8tQ/79oc34EuHqofkP87AMvY7DkcgQ43G67ej7OC8x5TY+zXz8gSR3ayMsykqR5GO6S1CHDXZI6ZLhLUocMd0nqkOEuSR0y3CWpQ/8H2vJR9Ye8H/sAAAAASUVORK5CYII=\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "plt.hist(num_neg[0], bins=[0,10,20,30,40,100,500,1000,1500,2000,2500])\n", - "num_neg_9 = np.array(num_neg[0])\n", - "indices_neg_9 = np.argwhere((num_neg_9 <= 20) & (num_neg_9 > 10))" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[642, 1027, 258, 134, 394, 1046, 279, 536, 791, 923, 798, 676, 682, 811, 556, 812, 432, 694, 55, 314, 957, 1093, 197, 591, 93, 735, 357, 1017, 634, 638, 767]\n" - ] - } - ], - "source": [ - "# These are the counterfactuals that have between 10 and 20 pos/neg perturbations > 0.9*max\n", - "overlap = list(set(np.squeeze(indices_pos_9)).intersection(set( np.squeeze(indices_neg_9))))\n", - "print(overlap)" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "ename": "NameError", - "evalue": "name 'pickle' is not defined", - "output_type": "error", - "traceback": [ - "\u001B[0;31m---------------------------------------------------------------------------\u001B[0m", - "\u001B[0;31mNameError\u001B[0m Traceback (most recent call last)", - "\u001B[0;32m\u001B[0m in \u001B[0;36m\u001B[0;34m\u001B[0m\n\u001B[0;32m----> 1\u001B[0;31m \u001B[0mthresholds_9\u001B[0m \u001B[0;34m=\u001B[0m 
\u001B[0mpickle\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mload\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mopen\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0;34m'{}threshold_{}.pkl'\u001B[0m\u001B[0;34m.\u001B[0m\u001B[0mformat\u001B[0m\u001B[0;34m(\u001B[0m\u001B[0mparent\u001B[0m\u001B[0;34m,\u001B[0m\u001B[0;36m0.9\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m,\u001B[0m \u001B[0;34m'rb'\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m)\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[0m\u001B[1;32m 2\u001B[0m \u001B[0mgenes_pos\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 3\u001B[0m \u001B[0mgenes_neg\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 4\u001B[0m \u001B[0mperturb_vector\u001B[0m\u001B[0;34m=\u001B[0m\u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n\u001B[1;32m 5\u001B[0m \u001B[0mcf_class\u001B[0m \u001B[0;34m=\u001B[0m \u001B[0;34m[\u001B[0m\u001B[0;34m]\u001B[0m\u001B[0;34m\u001B[0m\u001B[0;34m\u001B[0m\u001B[0m\n", - "\u001B[0;31mNameError\u001B[0m: name 'pickle' is not defined" - ] - } - ], - "source": [ - "thresholds_9 = pickle.load(open('{}threshold_{}.pkl'.format(parent,0.9), 'rb'))\n", - "genes_pos = []\n", - "genes_neg = []\n", - "perturb_vector=[]\n", - "cf_class = []\n", - "for i in overlap:\n", - " genes_pos.append(thresholds_9['positive threshold indices'][i])\n", - " genes_neg.append(thresholds_9['negative threshold indices'][i])\n", - " perturb_vector.append(thresholds_9['perturbation vector'][i])\n", - " cf_class.append(thresholds_9['counterfactual class'][i])" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "18 13\n" - ] - } - ], - "source": [ - "# These are the set of genes (indices) that have 
been perturbed more than 0.9*max perturbation\n", - "# in the counterfactual example \n", - "\n", - "# split by class\n", - "genes_pos_0 = []\n", - "genes_neg_0 = []\n", - "perturb_vector_0=[]\n", - "genes_pos_1 = []\n", - "genes_neg_1 = []\n", - "perturb_vector_1=[]\n", - "for i,j,k,l in zip(genes_pos, genes_neg, perturb_vector, cf_class):\n", - " if l==0:\n", - " genes_pos_0.append(i)\n", - " genes_neg_0.append(j)\n", - " perturb_vector_0.append(k)\n", - " else:\n", - " genes_pos_1.append(i)\n", - " genes_neg_1.append(j)\n", - " perturb_vector_1.append(k)\n", - "print(len(genes_neg_0), len(genes_neg_1))" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "metadata": {}, - "outputs": [], - "source": [ - "# cluster the counterfactual examples\n", - "num_clusters = 1\n", - "kmeans_0 = KMeans(n_clusters=num_clusters, random_state=0).fit(perturb_vector_0)\n", - "kmeans_1 = KMeans(n_clusters=num_clusters, random_state=0).fit(perturb_vector_1)" - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "outputs": [], - "source": [ - "for i in range(len(kmeans_0.cluster_centers_)):\n", - " diff_1=kmeans_1.cluster_centers_[i]\n", - " max_value = np.max(np.abs(diff_1))\n", - " ind_pos = np.where(diff_1 > 0.9*max_value)\n", - " ind_neg = np.where(diff_1 < -0.9*max_value)\n", - " pickle.dump([ind_pos, ind_neg], open(\"{}cf_class_1.pkl\".format(parent), \"wb\"))\n", - " diff_0=kmeans_1.cluster_centers_[i]\n", - " max_value = np.max(np.abs(diff_0))\n", - " ind_pos = np.where(diff_1 > 0.9*max_value)\n", - " ind_neg = np.where(diff_1 < -0.9*max_value)\n", - " pickle.dump([ind_pos, ind_neg],open(\"{}cf_class_0.pkl\".format(parent), \"wb\"))" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 39, - "outputs": [ - { - "data": { - "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAPSklEQVR4nO3dX4hc533G8edZuwlsmoItLapqe3cUowSUQpUwiBbcEFM3kQ1FcUpAZii+CGwKNvTfjcNe1DeCUuqGXKRux0TYNFObQmssGhMnNiWmUOqMUkVexaiWHa0soUir+KIpW5za++vFOasdrWalnZ05c8555/uB4cx5z+68P46PHr96z6szjggBANI0VXYBAIDiEPIAkDBCHgASRsgDQMIIeQBI2K1lF9Br586d0Wg0yi4DAGrl+PHjVyJipt+xSoV8o9FQt9stuwwAqBXbS5sdY7oGABJGyANAwgh5AEgYIQ8ACSPkASBhhDxQV52O1GhIU1PZttMpuyJUUKWWUALYok5Hmp+XVlay/aWlbF+SWq3y6kLlMJIH6mhhYT3g16ysZO1AD0IeqKNz5wZrx8Qi5IE6mp0drB0Ti5AH6ujIEWl6+tq26emsHehByAN11GpJ7bY0NyfZ2bbd5qYrrsPqGqCuWi1CHTfFSB4AEkbIA0DCCHkASNhIQt72UduXbS/2tD1u+4LtE/nrgVH0BQDYulGN5J+WdLBP+9ciYn/+enFEfQEAtmgkIR8Rr0p6dxSfBQAYnaLn5B+1fTKfzrmt3w/Ynrfdtd1dXl4uuBwAmCxFhvyTku6WtF/SRUlP9PuhiGhHRDMimjMzfb9sHACwTYWFfERciogPImJV0lOSDhTVFwCgv8JC3vbunt0HJS1u9rMAgGKM5LEGtp+V9FlJO22fl/Tnkj5re7+kkHRW0ldG0RcAYOtGEvIR8VCf5m+O4rMBANvHv3gFgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQB4CEEfIAkDBCHgASRsgDQMIIeQBIGCEPAAkj5AEgYYQ8ACSMkAeAhBHyAJAwQh4AEkbIA0DCCHkASBghDwAJI+QBIGGEPAAkbCQhb/uo7cu2F3vabrf9Pdtv5tvbRtEXAGDrRjWSf1rSwQ1tj0l6JSL2Snol3wcAjNFIQj4iXpX07obmQ5Keyd8/I+kLo+gLALB1Rc7J74qIi/n7n0ra1e+HbM/b7truLi8vF1gOAEyesdx4jYiQFJsca0dEMyKaMzMz4ygHACZGkSF/yfZuScq3lwvsCwDQR5Ehf0zSw/n7hyW9UGBfAIA+RrWE8llJ/y7pE7bP2/6ypL+Q9Lu235R0X74PABijW0fxIRHx0CaHfmcUnw8A2B7+xSsAJIyQB4CEEfIAkDBCHgASRsgDQMIIeQBIGCEPAAkj5AEgYYQ8ACSMkAeAhBHyAJAwQh6om05HajSkqals2+mUXREqbCQPKAMwJp2OND8vraxk+0tL2b4ktVrl1YXKYiQP1MnCwnrAr1lZydqBPgh5oE7OnRusHROPkAfqZHZ2sHZMPEIeqJMjR6Tp6WvbpqezdqAPQh6ok1ZLareluTnJzrbtNjddsSlW1wB102oR6tgyRvKoF9aIAwNhJI/6YI04MDBG8qgP1ogDAyPkUR+sEQcGRsijPlgjDgyMkEd9sEYcGFjhN15tn5X0c0kfSHo/IppF94lErd1cXVjIpmhmZ7OA56YrsKlxra65NyKujKkvpIw14sBAmK4BgISNI+RD0ndtH7c9v/Gg7XnbXdvd5eXlMZQDAJNjHCF/T0R8WtL9kh6x/ZnegxHRjohmRDRnZmbGUA4ATI7CQz4iLuTby5Kel3Sg6D4BAJlCQ972R2x/dO29pM9JWiyyTwDAuqJH8rsk/ZvtH0l6TdK3I+I7BfeZFh7IBWAIhS6hjIi3Jf1GkX0kjQdyARgSSyirjAdyA
RgSIV9lPJALwJDqH/Ipz1nzQC4AQ6p3yK/NWS8tSRHrc9apBD0P5AIwpHqHfOpz1nxpM4AhOSLKruGqZrMZ3W53678wNZWN4DeypdXV0RUGABVm+/hmT/it90ieOWsAuKF6hzxz1kDaiw8wtHqHPHPWmHSpLz7A0Oo9Jw9MukYjC/aN5uaks2fHXQ1Kku6cPDDp+AdzuAlCHqgzFh/gJgh5oM5YfICbIOSBOmPxAW6i0EcNAxiDVotQx6YYyQNAwgh5AEgYIQ8ACSPkASBhhDwAJIyQB4CEEfIAkDBCHgASRsgDQMIKD3nbB22ftn3G9mNF9wcAWFdoyNu+RdI3JN0vaZ+kh2zvK7JPAMC6okfyBySdiYi3I+IXkp6TdKjgPgEAuaJD/g5J7/Tsn8/brrI9b7tru7u8vFxwOQAwWUq/8RoR7YhoRkRzZmam7HIAIClFh/wFSXf17N+ZtwEAxqDokP+BpL2299j+kKTDko4V3CcAIFfol4ZExPu2H5X0kqRbJB2NiFNF9gkAWFf4N0NFxIuSXiy6HwDA9Uq/8QoAKA4hDwAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQR/E6HanRkKamsm2nU3ZFwMQo/CmUmHCdjjQ/L62sZPtLS9m+JLVa5dUFTAhG8ijWwsJ6wK9ZWcnaARSOkEexzp0brB3ASBHyKNbs7GDtAEaKkEexjhyRpqevbZueztoBFI6QR7FaLandlubmJDvbttvp33RlRREqwhFRdg1XNZvN6Ha7ZZcBDGfjiqI1O3ZIX/96+v+Dw9jZPh4RzX7HGMkDo9ZvRZEk/exnWfgzqscYEfLAqN1o5RDLRzFmhDwwajdbOcTyUYwRIY96q+INzn4rinqxfBRjRMijvtZucC4tSRHrj0woO+jXVhTt2HH9MZaPYqOCByqFhbztx21fsH0ifz1QVF+YUFV+ZEKrJV25In3rW5O3fBRbN4aBSmFLKG0/Lul/IuKvtvo7LKHEQKamsj8YG9nS6ur46wEG1Whkwb7R3Jx09uyWP4YllEgTj0xA3Y3h2U5Fh/yjtk/aPmr7tn4/YHvedtd2d3l5ueBykBQemYC6G8NAZaiQt/2y7cU+r0OSnpR0t6T9ki5KeqLfZ0REOyKaEdGcmZkZphxMmkl9ZALSMYaBylAhHxH3RcSv93m9EBGXIuKDiFiV9JSkA6MpGZKquXSwDK1WNne5upptCXjUyRgGKoV9M5Tt3RFxMd99UNJiUX1NHL5tCUhHq1Xon9si5+T/0vbrtk9KulfSnxTY12Sp8tJBAJVS2Eg+Iv6gqM+eeHzbEoAtYgllHbF0EMAWEfJ1xNJBAFtEyNcRSwcBbFFhc/IoWMF35AGkgZE8ACSMkAeAhBHyAJAwQh4AEkbIA0DCCHkASBghDwAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQB4CEEfIAkDBCHgASNlTI2/6S7VO2V203Nxz7qu0ztk/b/vxwZQIAtuPWIX9/UdIXJf1db6PtfZIOS/qkpF+T9LLtj0fEB0P2BwAYwFAj+Yh4IyJO9zl0SNJzEfFeRPxE0hlJB4bpCwAwuKLm5O+Q9E7P/vm87Tq25213bXeXl5cLKgcAJtNNp2tsvyzpV/scWoiIF4YtICLaktqS1Gw2Y9jPAwCsu2nIR8R92/jcC5Lu6tm/M28DAIxRUdM1xyQdtv1h23sk7ZX0WkF9AQA2MewSygdtn5f0W5K+bfslSYqIU5L+UdKPJX1H0iOsrAGA8RtqCWVEPC/p+U2OHZF0ZJjPBwAMh3/xCgAJI+QBIGGEPAAkjJAHgIQR8gCQMEIeABJGyANAwgh5AEgYIQ8ACSPkASBhhDwAJIyQx2TqdKRGQ5qayradTtkVAYUY9jtegfrpdKT5eWllJdtfWsr2JanVKq8uoACM5DF5FhbWA37Ny
krWDiSGkMfkOXdusHagxgh5TJ7Z2cHagRoj5DF5jhyRpqevbZueztqBxBDymDytltRuS3Nzkp1t221uuiJJrK7BZGq1CHVMBEbyAJAwQh4AEkbIA0DCCHkASBghDwAJc0SUXcNVtpclLZVdRx87JV0pu4hNVLk2qdr1Udv2UNv2FFnbXETM9DtQqZCvKtvdiGiWXUc/Va5NqnZ91LY91LY9ZdXGdA0AJIyQB4CEEfJb0y67gBuocm1Steujtu2htu0ppTbm5AEgYYzkASBhhDwAJIyQvwHbX7J9yvaq7WZPe8P2/9o+kb/+tiq15ce+avuM7dO2Pz/u2jbU8rjtCz3n6oEy68lrOpifmzO2Hyu7nl62z9p+PT9X3QrUc9T2ZduLPW232/6e7Tfz7W0Vqq0S15vtu2z/q+0f539O/yhvH/u5I+RvbFHSFyW92ufYWxGxP3/94ZjrkjapzfY+SYclfVLSQUl/Y/uW8Zd3ja/1nKsXyywkPxffkHS/pH2SHsrPWZXcm5+rKqz3flrZddTrMUmvRMReSa/k+2V4WtfXJlXjentf0p9FxD5Jvynpkfw6G/u5I+RvICLeiIjTZdfRzw1qOyTpuYh4LyJ+IumMpAPjra7SDkg6ExFvR8QvJD2n7Jyhj4h4VdK7G5oPSXomf/+MpC+MtajcJrVVQkRcjIgf5u9/LukNSXeohHNHyG/fHtv/afv7tn+77GJ63CHpnZ7983lbmR61fTL/63Upf7XvUcXz0yskfdf2cdvzZReziV0RcTF//1NJu8ospo8qXW+y3ZD0KUn/oRLO3cSHvO2XbS/2ed1odHdR0mxEfErSn0r6B9u/UpHaxu4mdT4p6W5J+5WdtydKLbb67omITyubTnrE9mfKLuhGIluDXaV12JW63mz/sqR/kvTHEfHfvcfGde4m/uv/IuK+bfzOe5Ley98ft/2WpI9LGumNsu3UJumCpLt69u/M2wqz1TptPyXpX4qsZQvGfn4GEREX8u1l288rm17qd0+oTJds746Ii7Z3S7pcdkFrIuLS2vuyrzfbv6Qs4DsR8c9589jP3cSP5LfD9szazUzbH5O0V9Lb5VZ11TFJh21/2PYeZbW9VlYx+YW85kFlN4zL9ANJe23vsf0hZTepj5VckyTJ9kdsf3TtvaTPqfzz1c8xSQ/n7x+W9EKJtVyjKtebbUv6pqQ3IuKvew6N/9xFBK9NXsoukvPKRu2XJL2Ut/++pFOSTkj6oaTfq0pt+bEFSW9JOi3p/pLP4d9Lel3SSWUX+O4K/Hd9QNJ/5edooex6eur6mKQf5a9TVahN0rPKpj3+L7/evixph7KVIW9KelnS7RWqrRLXm6R7lE3FnMxz4kR+3Y393PFYAwBIGNM1AJAwQh4AEkbIA0DCCHkASBghDwAJI+QBIGGEPAAk7P8B5yRZLr0XHV8AAAAASUVORK5CYII=\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAOC0lEQVR4nO3dT4gc553G8eeRTAITcnDQrCJsz4wJ2gXnsCY0Oi1Lwjqx1hfFgSwKzWLYwORg791hDgkEQQgJWQjJQgeEfZiN8cVYxCEb2xexsIvdAm9WchAWzows4VhjcgkMJNj57aF6rNGkezR/6q3q/tX3A0N1vTXq922X+6Hmfd96yxEhAEBOR9puAACgHEIeABIj5AEgMUIeABIj5AEgsXvabsB2x44di6WlpbabAQAz5dKlS+9HxPy4Y1MV8ktLSxoOh203AwBmiu31ScforgGAxAh5AEiMkAeAxAh5AEiMkAeAxAh5AGjT6qq0tCQdOVJtV1drffupmkIJAJ2yuiotL0ubm9X++nq1L0n9fi1VcCUPAG1ZWbkd8Fs2N6vymhDyANCW69f3V34AhDwAtGVhYX/lB0DIA0Bbzp2T5ubuLJubq8prQsgDQFv6fWkwkBYXJbvaDga1DbpKzK4BgHb1+7WG+k5cyQNAYoQ8ACRGyANAYoQ8ACRGyANAYoQ8ACRGyANAYoQ8ACRGyANAYoQ8ACRGyAO4U+EnFaFZrF0D4LYGnlSEZnElD+C2Bp5UhGYR8gBua+BJRWhWLSFv+7ztW7Yvbyv7lO2Xbb812t5bR10ACmrgSUVoVl1X8s9IOr2j7GlJr0bESUmvjvYBTLMGnlSEZtUS8hFxUdLvdxSfkfTs6PWzkr5cR10ACmrgSUVoVsnZNccj4t3R699JOl6wLgB1KfykIjSrkYHXiAhJMe6Y7WXbQ9vDjY2NJpoDAJ1RMuTfs31CkkbbW+N+KSIGEdGLiN78/HzB5gBA95QM+QuSnhi9fkLSiwXrAgCMUdcUyp9J+m9Jf2P7hu2vS/qupC/afkvSI6N9AECDahl4jYivTTj0D3W8PwDgYLjjFQASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBIDFCHgASI+QBILF7Sldge03SHyR9KOmDiOiVrhMAUCke8iNfiIj3G6oLADDS7e6a1VVpaUk6cqTarq623SIAqFUTIR+SfmX7ku3lnQdtL9se2h5ubGw00JyR1VVpeVlaX5ciqu3yMkEPIBVHRNkK7Psi4qbtv5L0sqR/jYiL43631+vFcDgs2p6PLC1Vwb7T4qK0ttZMGwCgBrYvTRrvLH4lHxE3R9tbkl6QdKp0nXty/fr+yqcZ3U4AJiga8rY/YfuTW68lfUnS5ZJ17tnCwv7KpxXdTgB2UfpK/rik/7L9v5Jek/RSRPyycJ17c+6cNDd3Z9ncXFU+S1ZWpM3NO8s2N6tyAJ1XdAplRLwt6W9L1nFg/X61XVmpumgWFqqA3yqfFZm6nQDUrql58tOp35+9UN9pYWH8APKsdTsBKKLb8+QzyNLtBKAIQn7W9fvSYFBN/bSr7WCwt79QmJUDpNft7posDtLttDUrZ2vQdmtWztb7AUiBK/muYlYO0AmEfFcxKwfoBEK+q7LcDAZgV4R8VzErB+gEQr6rDjMrB8DMYHZNl2W4GQzArnJcyTPfGwDGmv0reeZ7A8BEs38lz3xvAJho9kOe+d4AMNHshzzzvQFgotkPeeZ7A8BEsx/yzPcGgIlmf3aNxHxvAJhg9q/k68A8+27gPKODclzJHwbz7LuB84yOckS03YaP9Hq9GA6HzVa6tDT+GamLi9LaWrNtQ
TmcZyRm+1JE9MYdo7uGefbdwHlGRxHyzLPvBs4zOoqQZ559N3Ce0VGEPPPsu4HzjI5i4BUAZhwDrwDQUYQ8ACRGyANAYoQ8ACRWPORtn7Z91fY120+Xrg/ALli/p3OKrl1j+6ikH0v6oqQbkl63fSEi3ixZL4AxWL+nk0pfyZ+SdC0i3o6IP0l6TtKZwnUCGIfnIXdS6ZC/T9I72/ZvjMo+YnvZ9tD2cGNjo3BzgA5j/Z5Oan3gNSIGEdGLiN78/HzbzQHyYv2eTiod8jclPbBt//5RGWZJE4N1DAiWx/o9nVQ65F+XdNL2g7Y/JumspAuF60Sdtgbr1teliNuDdXWGcBN1gPV7Oqr42jW2H5P0b5KOSjofERMvG1i7Zgo18bANHugBHEqra9dExC8i4q8j4jO7BTxqUKLLo4nBOgYEgWJaH3hFTUp1eTQxWMeAIFAMIZ9FqTnQTQzWMSAIFEPIZ1Gqy6OJwToGBIFieGhIFgxeAp3FQ0O6YJq6PHYbAGY+PNCooguUoUFbXRsrK1UXzcJCFfBNd3nstgiWxAJZQMPorkG9dus2kuhSAgrYrbuGK3nU6yADwMyHB4qhTx712m3OO/PhgcYR8qjXbgPA0zQ4DHQE3TWo114GgNseHAY6hIFXAJhxzJMHgI4i5AEgMUIeABIj5AEgMUIeABIj5AEgMUIeABIj5DEZywIDM487XjHebksGc4cqMDO4ksd4pZ4ZC6BRhDzGK/XMWACNIuQxHssCAykQ8hiPZYGBFAh5jNfvS4NB9Wg+u9oOBgy6AjOG2TWYrN8n1IEZx5U8ACRGyANAYoQ8ACRWLORtf9v2TdtvjH4eK1UXAGC80gOvP4yI7xeuAwAwAd01AJBY6ZB/yvavbZ+3fe+4X7C9bHtoe7ixsVG4OQDQLY6Ig/9j+xVJnx5zaEXS/0h6X1JI+o6kExHxL7u9X6/Xi+FweOD2AEAX2b4UEb1xxw7VJx8Rj+yxAT+V9PPD1AUA2L+Ss2tObNt9XNLlUnUBAMYrObvme7YfVtVdsybpGwXrAgCMUSzkI+KfS703AGBvmEIJAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8gCQGCEPAIkR8sCsWV2VlpakI0eq7epq2y3CFLun7QYA2IfVVWl5WdrcrPbX16t9Ser322sXphZX8sAsWVm5HfBbNjercmCMQ4W87a/avmL7z7Z7O4590/Y121dtP3q4ZgKQJF2/vr9ydN5hr+QvS/qKpIvbC20/JOmspM9KOi3pJ7aPHrIuAAsL+ytH5x0q5CPiNxFxdcyhM5Kei4g/RsRvJV2TdOowdQGQdO6cNDd3Z9ncXFUOjFGqT/4+Se9s278xKvsLtpdtD20PNzY2CjUHSKLflwYDaXFRsqvtYMCgKya66+wa269I+vSYQysR8eJhGxARA0kDSer1enHY9wPS6/cJdezZXUM+Ih45wPvelPTAtv37R2UAgAaV6q65IOms7Y/bflDSSUmvFaoL04KbdICpc6iboWw/LulHkuYlvWT7jYh4NCKu2H5e0puSPpD0ZER8ePjmYmpxkw4wlRwxPd3gvV4vhsNh283AQSwtVcG+0+KitLbWdGuATrF9KSJ6445xxyvqwU06wFQi5FEPbtIBphIhj3pwkw4wlQh51IObdICpxFLDqA836QBThyt5AEiMkAeAxAh5AEiMkAeAxAh5AEiMkAeAxAh5AHmxMirz5AEkxcqokriSB5DVysrtgN+yuVmVdwghDyAnVkaVRMgDyIqVUSUR8gCyYmVUSYQ8gKxYGVUSs2sAZMbKqFzJA0BmhDwAJEbIA0BihDwAJEbIA0Bijoi22/AR2xuS1vf468ckvV+wOdOKz90tfO5uOejnXoyI+XEHp
irk98P2MCJ6bbejaXzubuFzd0uJz013DQAkRsgDQGKzHPKDthvQEj53t/C5u6X2zz2zffIAgLub5St5AMBdEPIAkNhMhbztr9q+YvvPtns7jn3T9jXbV20/2lYbm2D727Zv2n5j9PNY220qxfbp0Tm9ZvvpttvTFNtrtv9vdH6HbbenJNvnbd+yfXlb2adsv2z7rdH23jbbWMKEz137d3umQl7SZUlfkXRxe6HthySdlfRZSacl/cT20eab16gfRsTDo59ftN2YEkbn8MeS/lHSQ5K+NjrXXfGF0fnNPl/8GVXf2+2elvRqRJyU9OpoP5tn9JefW6r5uz1TIR8Rv4mIq2MOnZH0XET8MSJ+K+mapFPNtg4FnJJ0LSLejog/SXpO1blGIhFxUdLvdxSfkfTs6PWzkr7caKMaMOFz126mQn4X90l6Z9v+jVFZZk/Z/vXoT750f8qOdPG8bglJv7J9yfZy241pwfGIeHf0+neSjrfZmIbV+t2eupC3/Yrty2N+OnUFd5f/Dv8u6TOSHpb0rqQftNpYlPB3EfE5VV1VT9r++7Yb1Jao5nl3Za537d/tqXv8X0Q8coB/dlPSA9v27x+Vzay9/new/VNJPy/cnLakO697FRE3R9tbtl9Q1XV1cfd/lcp7tk9ExLu2T0i61XaDmhAR7229ruu7PXVX8gd0QdJZ2x+3/aCkk5Jea7lNxYz+p9/yuKoB6Yxel3TS9oO2P6ZqcP1Cy20qzvYnbH9y67WkLynvOZ7kgqQnRq+fkPRii21pTInv9tRdye/G9uOSfiRpXtJLtt+IiEcj4ort5yW9KekDSU9GxIdttrWw79l+WNWfsGuSvtFuc8qIiA9sPyXpPyUdlXQ+Iq603KwmHJf0gm2p+o7+R0T8st0mlWP7Z5I+L+mY7RuSviXpu5Ket/11VcuP/1N7LSxjwuf+fN3fbZY1AIDEsnTXAADGIOQBIDFCHgASI+QBIDFCHgASI+QBIDFCHgAS+3+scocs/QSQcwAAAABJRU5ErkJggg==\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAX8AAAD8CAYAAACfF6SlAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAASp0lEQVR4nO3df4hl533f8fdn15bSISGyrUEW+7tkSbsJbSwGRaGhGMutV66xnBKHNQNWEsFQkMEhBXeV+aMUulA3ELfGTtqhMlXKYEXkB7skThVZtnHyh2SP8kOxpCgeyx7tLrK1jm2lZYjStb79456VrlazuzNz79x75z7vF1zuOd/n7L3Pc7X6zNnnnHluqgpJUlv2jLsDkqTRM/wlqUGGvyQ1yPCXpAYZ/pLUIMNfkho0tPBPsjfJnyX5/W7/SJLHkqwm+a0k13X167v91a798LD6IEnanGGe+X8YeLpv/6PAx6rqR4DvAnd39buB73b1j3XHSZJGaCjhn2Q/8K+A/9HtB3gH8NvdIfcD7+u27+z26dpv746XJI3IG4b0Ov8F+AjwQ93+W4DvVdXFbv8csK/b3gecBaiqi0le7I7/9pVe/MYbb6zDhw8PqauS1IbHH3/821U1u1HbwOGf5D3AC1X1eJK3D/p6fa+7ACwAHDx4kJWVlWG9tCQ1IcnaldqGMe3zz4D3JvkG8AC96Z7/CtyQ5NIPl/3A+W77PHCg69gbgB8G/ubyF62qpaqaq6q52dkNf3BJkrZp4PCvqnuran9VHQZOAJ+rqnng88DPdofdBZzuts90+3TtnytXl5OkkdrJ+/z/HfDLSVbpzenf19XvA97S1X8ZOLmDfZAkbWBYF3wBqKovAF/otp8Fbt3gmL8D3j/M95UkbY2/4StJDTL8tX3Ly3D4MOzZ03teXh53jyRt0lCnfdSQ5WVYWID19d7+2lpvH2B+fnz9krQpnvlrexYXXw3+S9bXe3VJE8/w1/Y899zW6pImiuGv7Tl4cGt1SRPF8Nf2nDoFMzOvrc3M9OqSJp7hr+2Zn4elJTh0CJLe89KSF3ulXcK7fbR98/OGvbRLeeYvSQ0y/CWpQYa/JDXI8JekBhn+ktQgw1/S5rmY39TwVk9Jm+NiflPFM39Jm+NiflPF8Je0OS7mN1UMf0mb42J+U8Xwl7Q5LuY3VQx/SZvjYn5Txbt9JG2ei/lNjYHP/JP8QJIvJfmLJE8m+Q9d/UiSx5KsJvmtJNd19eu7/dWu/fCgfZAkbc0wpn1eAt5RVf8U+AngeJLbgI8CH6uqHwG+C9zdHX838N2u/rHuOEnSCA0c/tXzf7vdN3aPAt4B/HZXvx94X7d9Z7dP1357kgzaD0nS5g3lgm+SvUn+HHgBeBj4GvC9qrrYHXIO2Ndt7wPOAnTtLwJvGUY/JEmbM5Twr6rvV9VPAPuBW4F/NOhrJllIspJk5cKFCwP3UZL0qqHe6llV3wM+D/wUcEOSS3cT7QfOd9vngQMAXfsPA3+zwWstVdVcVc3Nzs4Os5uS1Lxh3O0zm+SGbvsfAP8CeJreD4Gf7Q67CzjdbZ/p9unaP1dVNWg/JEmbN4z7/G8G7k+yl94Pkwer6veTPAU8kOQ/An8G3Ncdfx/wv5KsAt8BTgyhD5KkLRg4/KvqCeBtG9SfpTf/f3n974D3D/q+kqTtc3kHSWqQ4S9JDTL8JalBhr8kNcjwl6QGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0y/CWpQYa/JDXI8JekBhn+ktQgw1+SGmT4S1KDDH9JapDhr+mzvAyHD8OePb3n5eVx90iaOMP4Dl9pciwvw8ICrK/39tfWevsA8/Pj65c0YTzz13RZXHw1+C9ZX+/VJb3C8Nd0ee65rdWlSbXD05eGv6bLwYNbq0uT6NL05doaVL06fTnEHwADh3+SA0k+n+SpJE8m+XBXf3OSh5N8tXt+U1dPko8nW
U3yRJJbBu3D0HnBcPc6dQpmZl5bm5np1aXdYgTTl8M4878I/NuqOgbcBtyT5BhwEnikqo4Cj3T7AHcAR7vHAvAbQ+jD8IzgJ6520Pw8LC3BoUOQ9J6XlrzYq91lBNOXqaqhvRhAktPAJ7rH26vq+SQ3A1+oqh9N8t+77U93xz9z6bgrvebc3FytrKwMtZ9XdPhwL/Avd+gQfOMbo+mDpLYNKYeSPF5Vcxu1DXXOP8lh4G3AY8BNfYH+TeCmbnsfcLbvj53rapPBC4aSxm0E05dDC/8kPwj8DvBLVfW3/W3V++fFlv6JkWQhyUqSlQsXLgyrm9fmBUNJ4zaC6cuhhH+SN9IL/uWq+t2u/K1uuofu+YWufh440PfH93e116iqpaqaq6q52dnZYXRzc7xgKGkSzM/3pnhefrn3POTrVsO42yfAfcDTVfVrfU1ngLu67buA0331D3Z3/dwGvHi1+f6R84KhpAYMfME3yU8Dfwz8JfByV/4VevP+DwIHgTXg56rqO90Pi08Ax4F14Beq6qpXc0d6wVeSpsTVLvgOvLZPVf0JkCs0377B8QXcM+j7SpK2z9/wlaQGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0y/CWpQYa/JDXI8JekBhn+ktQgw1+SGmT4S1KDDH9JG1te7n2d4J49vWe/x3qqDLyqp6QptLwMCwuwvt7bX1vr7YPfbTElPPOX9HqLi68G/yXr6726poLhL+n1nntua3XtOoa/pNc7eHBrde06hr+k1zt1CmZmXlubmenVNRUMf0mvNz8PS0tw6BAkveelJS/2ThHv9pG0sfl5w36KeeYvSQ0y/CWpQUMJ/ySfSvJCkq/01d6c5OEkX+2e39TVk+TjSVaTPJHklmH0QZK0ecM68/+fwPHLaieBR6rqKPBItw9wB3C0eywAvzGkPkiSNmko4V9VXwS+c1n5TuD+bvt+4H199d+snkeBG5LcPIx+SJI2Zyfn/G+qque77W8CN3Xb+4Czfced62qSpBEZyQXfqiqgtvJnkiwkWUmycuHChR3qmSS1aSfD/1uXpnO65xe6+nngQN9x+7vaa1TVUlXNVdXc7OzsDnZTktqzk+F/Brir274LON1X/2B3189twIt900OSpBEYym/4Jvk08HbgxiTngH8P/CfgwSR3A2vAz3WHfwZ4N7AKrAO/MIw+SJI2byjhX1UfuELT7RscW8A9w3hfSdL2+Bu+ktQgw1+SGmT4a/T8YnBp7FzSWaPlF4NLE8Ezf42WXwwuTQTDX6PlF4NLE8Hw12j5xeDSRDD8NVp+Mbg0EQx/jZZfDC5NBO/20ej5xeDS2HnmL0kNMvwlqUGGvyQ1yPCXpAYZ/pLUIMNfkhpk+EtSgwx/SRqHMS9tPt3h77rxkibRpaXN19ag6tWlzUeYUdMb/hPw4UrShiZgafPpDf8J+HAlaUMTsLT59Ib/BHy4krShCVjafGzhn+R4kmeSrCY5OfQ3mIAPV5I2NAFLm48l/JPsBT4J3AEcAz6Q5NhQ32QCPlxJ2tAELG0+riWdbwVWq+pZgCQPAHcCTw3tHS59iIuLvamegwd7we9SwpImwZiXNh9X+O8DzvbtnwN+cujv4rrxkrShib3gm2QhyUqSlQsXLoy7O5I0VcYV/ueBA337+7vaK6pqqarmqmpudnZ2pJ2TpGk3rvD/MnA0yZEk1wEngDNj6oskNWcsc/5VdTHJh4CHgL3Ap6rqyXH0RZJaNLYvcK+qzwCfGdf7S1LLJvaCryRp5xj+0rRwFVttwdimfSQN0aVVbC8tZnhpFVvwd120Ic/8pWngKrbaIsNfmgauYqstMvylaeAqttoiw1+aBq5iqy0y/KVpMAFLBGt38W4faVq4iq22wDN/SWqQ4S9JDTL8JalBhr8kNcjwl6QGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0y/CWpQYa/JDXI8JekBg0U/knen+TJJC8nmbus7d4kq0meSfKuvvrxrraa5
OQg7y9J2p5Bz/y/Avxr4Iv9xSTHgBPAjwHHgV9PsjfJXuCTwB3AMeAD3bGSpBEaaD3/qnoaIMnlTXcCD1TVS8DXk6wCt3Ztq1X1bPfnHuiOfWqQfkiStman5vz3AWf79s91tSvVJUkjdM0z/ySfBd66QdNiVZ0efpdeed8FYAHgoF9CLUlDdc0z/6p6Z1X9+AaPqwX/eeBA3/7+rnal+kbvu1RVc1U1Nzs7e+2RaPSWl+HwYdizp/e8vDzuHknapJ2a9jkDnEhyfZIjwFHgS8CXgaNJjiS5jt5F4TM71AftpOVlWFiAtTWo6j0vLPgDQNolBr3V82eSnAN+CviDJA8BVNWTwIP0LuT+b+Ceqvp+VV0EPgQ8BDwNPNgdq91mcRHW119bW1/v1SVNvFTVuPtwTXNzc7WysjLubqjfnj29M/7LJfDyy6Pvj6TXSfJ4Vc1t1OZv+Gp7rnQR3ovz0q5g+Gt7Tp2CmZnX1mZmenVJE8/w1/bMz8PSEhw61JvqOXSotz8/P+6eSdqEgX7DV42bnzfspV3KM39JapDhL0kNMvwlqUGGvyQ1yPCXpAYZ/pLUIMNfkhpk+EtSgwx/SWqQ4S9JDTL8JalBhr8kNcjwl6QGGf6S1CDDX5IaZPhLUoMMf0lqkOEvSQ0aKPyT/GqSv0ryRJLfS3JDX9u9SVaTPJPkXX31411tNcnJQd5fkrQ9g575Pwz8eFX9E+CvgXsBkhwDTgA/BhwHfj3J3iR7gU8CdwDHgA90x0qSRmig8K+qP6qqi93uo8D+bvtO4IGqeqmqvg6sArd2j9Wqeraq/h54oDtWkjRCw5zz/0XgD7vtfcDZvrZzXe1KdUnSCL3hWgck+Szw1g2aFqvqdHfMInARWB5Wx5IsAAsABw8eHNbLSpLYRPhX1Tuv1p7k54H3ALdXVXXl88CBvsP2dzWuUr/8fZeAJYC5ubna6BhJ0vYMerfPceAjwHurar2v6QxwIsn1SY4AR4EvAV8GjiY5kuQ6eheFzwzSB0nS1l3zzP8aPgFcDzycBODRqvo3VfVkkgeBp+hNB91TVd8HSPIh4CFgL/CpqnpywD5IkrYor87UTK65ublaWVkZdzckaVdJ8nhVzW3U5m/4SlKDDH9JapDhL0kNMvwlqUGGvyQ1yPCX1K7lZTh8GPbs6T0vD22Rgok36H3+krQ7LS/DwgKsd7+furbW2weYnx9fv0bEM39JbVpcfDX4L1lf79UbYPhLatNzz22tPmUMf0ltutJqwY2sImz4S2rTqVMwM/Pa2sxMr94Aw19Sm+bnYWkJDh2CpPe8tNTExV7wbh9JLZufbybsL+eZv6Sd0/B99JPOM39JO6Px++gnnWf+knZG4/fRTzrDX9LOaPw++kln+EvaGY3fRz/pDH9JO6Px++gnneEvaWc0fh/9pPNuH0k7p+H76CedZ/6S1CDDX5IaZPhLUoMMf0lqkOEvSQ1KVY27D9eU5AKwNu5+bMGNwLfH3YkxcextcuyT6VBVzW7UsCvCf7dJslJVc+Puxzg4dsfemt06dqd9JKlBhr8kNcjw3xlL4+7AGDn2Njn2XcY5f0lqkGf+ktQgw38ASX41yV8leSLJ7yW5oa/t3iSrSZ5J8q6++vGutprk5Hh6Prgk70/yZJKXk8xd1jbVY7/ctI7rkiSfSvJCkq/01d6c5OEkX+2e39TVk+Tj3WfxRJJbxtfzwSU5kOTzSZ7q/r5/uKvv/vFXlY9tPoB/Cbyh2/4o8NFu+xjwF8D1wBHga8De7vE14B8C13XHHBv3OLY59n8M/CjwBWCurz71Y7/sc5jKcV02xn8O3AJ8pa/2n4GT3fbJvr/77wb+EAhwG/DYuPs/4NhvBm7ptn8I+Ovu7/iuH79n/gOoqj+qqovd7qPA/m77TuCBqnqpqr4OrAK3do/Vqnq2qv4eeKA7dtepqqer6pkNmqZ+7JeZ1nG9oqq+CHznsvKdwP3d9v3A+/rqv1k9jwI3JLl5ND0dvqp6v
qr+tNv+P8DTwD6mYPyG//D8Ir2f+ND7y3G2r+1cV7tSfZq0NvZpHde13FRVz3fb3wRu6ran9vNIchh4G/AYUzB+v8zlGpJ8FnjrBk2LVXW6O2YRuAgsj7JvO20zY5eqqpJM9W2DSX4Q+B3gl6rqb5O80rZbx2/4X0NVvfNq7Ul+HngPcHt1k37AeeBA32H7uxpXqU+ca439CqZi7FtwtfFOs28lubmqnu+mNV7o6lP3eSR5I73gX66q3+3Ku378TvsMIMlx4CPAe6tqva/pDHAiyfVJjgBHgS8BXwaOJjmS5DrgRHfsNGlt7NM6rms5A9zVbd8FnO6rf7C76+U24MW+6ZFdJ71T/PuAp6vq1/qadv/4x33FeTc/6F3MPAv8eff4b31ti/TuAnkGuKOv/m56dwx8jd70ydjHsc2x/wy9+cyXgG8BD7Uy9g0+i6kcV9/4Pg08D/y/7r/53cBbgEeArwKfBd7cHRvgk91n8Zf03Qm2Gx/ATwMFPNH3//m7p2H8/oavJDXIaR9JapDhL0kNMvwlqUGGvyQ1yPCXpAYZ/pLUIMNfkhpk+EtSg/4/XIFY5kNG/sYAAAAASUVORK5CYII=\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - }, - { - "data": { - "text/plain": "
", - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAXkAAAD4CAYAAAAJmJb0AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjEsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy8QZhcZAAAUPklEQVR4nO3db4xdd33n8fcnSZPKbcGkGdLIjjNGG7oK2xbSIUpF/0BCS0gR5kGFIs0uKWVrKUoRULRsgqWV9kEk/qk0aFvYEWkVtLObZgMlFqKUJIVKfRCHcUhC/hDiQhw7OGSQFlqttYnSfPfBOa5vkrHNzP137pn3Sxrdc37nzpyvz9gfn3vu73xvqgpJUj+dNu0CJEnjY8hLUo8Z8pLUY4a8JPWYIS9JPXbGtAsYdM4559T8/Py0y5CkmbJ///4fVtXcWts6FfLz8/OsrKxMuwxJmilJDp5om5drJKnHDHlJ6jFDXpJ6zJCXpB4z5CWpxwx5bS7LyzA/D6ed1jwuL0+7ImmsOjWFUhqr5WXYvRuOHm3WDx5s1gEWF6dXlzRGnslr89iz53jAH3P0aDMu9dRIQj7J1iS3Jfl2kkeS/FqSs5PckeSx9vEVo9iXtGFPPLG+cakHRnUmfyPwlar6t8CvAI8A1wF3VdWFwF3tujQ9O3asb1zqgaFDPsnLgd8EbgKoqmer6kfALuDm9mk3A+8Ydl/SUG64AbZseeHYli3NuNRToziT3wmsAn+Z5JtJPpvkZ4Bzq+pI+5yngHPX+uYku5OsJFlZXV0dQTnSCSwuwtISXHABJM3j0pJvuqrXRhHyZwAXA5+uqtcB/5cXXZqp5oNk1/ww2apaqqqFqlqYm1uziZo0OouL8Pjj8PzzzaMBPx5OVe2MUYT8YeBwVe1r12+jCf0fJDkPoH18egT7ktR1x6aqHjwIVcenqhr0UzF0yFfVU8ChJL/YDl0OPAzsBa5ux64Gbh92X5JmgFNVO2VUN0O9F1hOcibwXeDdNP+B3JrkPcBB4J0j2pekLnOqaqeMJOSr6j5gYY1Nl4/i50uaITt2NJdo1hrXxHnHq6TRcqpqpxjykkbLqaqdYoMySaO3uGiod4Rn8pLUY4b8LPDGEkkb5OWarrMHuqQheCbfdd5YImkIhnzXeWOJpCEY8l1nD3RJQzDku84bSyQNwZDvOm8skTQEZ9fMAm8skbRBnslLUo8Z8pLUY4a8JPWYIS9JPWbIS1KPjSzkk5ye5JtJvtSu70yyL8mBJH/VfjSgJGmCRnkm/z7gkYH1jwKfrKp/A/wf4D0j3Jck6ScwkpBPsh34XeCz7XqAy4Db2qfcDLxjFPuSJP3kRnUm/6fAh4Dn2/WfB35UVc+164eBbWt9Y5LdSVaSrKyuro6oHEkSjCDkk7wNeLqq9m/k+6tqqaoWqmphbm5u2HIkSQNG0dbgDcDbk1wJ/DTwMuBGYGuSM9qz+e3AkyPYlyRpHYY+k6+q66tqe1XNA1cBf1dVi8DXgN9rn3Y1cPuw+5Ikrc8458n/Z+CPkxyguUZ/0xj3JUlaw0i7UFbV14Gvt8vfBS4Z5c+XJK2Pd7xKUo8Z8pLUY4a8JPWYIS9JPWbIS1KPGfKS1GOGvCT1mCEvST1myEtSjxnyUhcsL8P8PJx2WvO4vDztirQRHfw9jrStgaQNWF6G3bvh6NFm/eDBZh1gcXF6dWl9Ovp7TFVNbecvtrCwUCsrK9MuQ5qs+fkmEF7sggvg8ccnXY02aoq/xyT7q2phrW1erpGm7Ykn1jeuburo79GQl6Ztx471jaubOvp7NOSlabvhBtiy5YVjW7Y045odHf09GvLStC0uwtJSc+02aR6XlnzTddZ09PfoG6+SNOPG+sZrkvOTfC3Jw0keSvK+dvzsJHckeax9fMWw+5Ikrc8oLtc8B3ywqi4CLgWuTXIRcB1wV1VdCNzVrkuSJmjokK+qI1V1b7v8z
8AjwDZgF3Bz+7SbgXcMuy9J0vqM9I3XJPPA64B9wLlVdaTd9BRw7gm+Z3eSlSQrq6uroyxHkja9kYV8kp8FPg+8v6r+aXBbNe/urvkOb1UtVdVCVS3Mzc2Nqhx1RQd7eUibyUh61yT5KZqAX66qL7TDP0hyXlUdSXIe8PQo9qUZ0tFeHtJmMorZNQFuAh6pqj8Z2LQXuLpdvhq4fdh9acbs2XM84I85erQZlzQRoziTfwPwH4BvJbmvHfsw8BHg1iTvAQ4C7xzBvjRLOtrLQ9pMhg75qvoHICfYfPmwP18zbMeOtbvy2ZNFmhjbGmh8OtrLQ9pMDHmNT0d7eUibiZ8MpfFaXDTUpSnyTP7FnNctqUc8kx/kvG5JPeOZ/CDndUvqGUN+kPO6JfWMIT+oo5/RKEkbZcgPcl63pJ4x5Ac5r1tSzzi75sWc1y2pRzyTl6QeM+QlqccMeUnqMUNeknrMkJekHjPkJanHxh7ySa5I8miSA0muG/f+JHWEHV07Yazz5JOcDvwZ8NvAYeAbSfZW1cPj3K+kKbOja2eM+0z+EuBAVX23qp4FbgF2jXmfkqbNjq6dMe6Q3wYcGlg/3I79qyS7k6wkWVldXR1zOZImwo6unTH1N16raqmqFqpqYW5ubtrlSBoFO7p2xrhD/kng/IH17e2YpD6zo2tnjDvkvwFcmGRnkjOBq4C9Y96npGmzo2tnjHV2TVU9l+SPgL8FTgf+oqoeGuc+JXWEHV07Yeythqvqy8CXx70fSdJLTf2NV0nS+BjyktRjhrwk9ZghL0k9ZshLUo8Z8pLUY4a8tF620NUMGfs8ealXbKGrGeOZvLQettDVjDHkpfWwha5mjCEvrYctdDVjDHlpPWyhqxljyEvrYQtdzRhn10jrZQtdzRDP5CWpxwx5SZqmMd9c5+UaSZqWCdxcN9SZfJKPJ/l2kgeS/HWSrQPbrk9yIMmjSd4yfKmS1DMTuLlu2Ms1dwD/rqp+GfgOcD1AkotoPrT7NcAVwJ8nOX3IfWnc7MkiTdYEbq4bKuSr6qtV9Vy7ejewvV3eBdxSVc9U1feAA8Alw+xLY3bsZePBg1B1/GWjQS+NzwRurhvlG69/APxNu7wNODSw7XA7pq6yJ4s0eRO4ue6UIZ/kziQPrvG1a+A5e4DngHWf9iXZnWQlycrq6up6v12jYk8WafImcHPdKWfXVNWbT7Y9ye8DbwMur6pqh58Ezh942vZ2bK2fvwQsASwsLNRaz9EE7NjRXKJZa1zS+Iz55rphZ9dcAXwIeHtVDb7W3wtcleSsJDuBC4F7htmXxsyeLFIvDXtN/r8BPwfckeS+JJ8BqKqHgFuBh4GvANdW1b8MuS+Nkz1ZpF7K8Sss07ewsFArKyvTLkOSZkqS/VW1sNY22xpIUo8Z8pLUY4a8JPWYIS9JPWbIS1KPzX7I21RLkk5otvvJT6AXsyTNstk+k7epliSd1GyHvE21JOmkZjvkJ9CLWZJm2WyHvE21JOmkZjvkbaolSSc127NrYOy9mCVpls32mbyk8fEelF6Y/TN5SaPnPSi94Zm8pJfyHpTeMOQlvZT3oPSGIS/ppbwHpTdGEvJJPpikkpzTrifJp5IcSPJAkotHsR9JE+I9KL0xdMgnOR/4HWDwddxbgQvbr93Ap4fdj6QJ8h6U3hjF7JpPAh8Cbh8Y2wV8rppPCb87ydYk51XVkRHsT9IkeA9KLwx1Jp9kF/BkVd3/ok3bgEMD64fbsbV+xu4kK0lWVldXhylHkvQipzyTT3In8AtrbNoDfJjmUs2GVdUSsASwsLBQw/wsSdILnTLkq+rNa40n+SVgJ3B/EoDtwL1JLgGeBM4fePr2dkySNEEbvlxTVd+qqldW1XxVzdNckrm4qp4C9gLvamfZXAr82OvxkjR542pr8GXgSuAAcBR495j2I0k6iZGFfHs2f2y5gGtH9bMlSRvjHa+S1GOGvDYPW+dqE7LVsDYHW+dqk/JMXpuDrXO1SRny2
hxsnatNypDX5mDrXG1Shrw2B1vnapMy5LU52DpXm5Sza7R52DpXm5Bn8pLUY4a8JPWYIS9JPWbIS1KPGfKS1GOGvCT1mCEvST1myEvSqcxwm+qhQz7Je5N8O8lDST42MH59kgNJHk3ylmH3I0lTcaxN9cGDUHW8TfWMBP1QIZ/kTcAu4Feq6jXAJ9rxi4CrgNcAVwB/nuT0IWuVpMmb8TbVw57JXwN8pKqeAaiqp9vxXcAtVfVMVX2P5gO9LxlyX5vbDL9clGbajLepHjbkXw38RpJ9Sf4+yevb8W3AoYHnHW7HXiLJ7iQrSVZWV1eHLKenZvzlojTTZrxN9SlDPsmdSR5c42sXTYOzs4FLgf8E3Jok6ymgqpaqaqGqFubm5jb0h+i9GX+5KM20GW9TfcoulFX15hNtS3IN8IWqKuCeJM8D5wBPAucPPHV7O6aNmPGXi9JMO9a5dM+e5t/cjh1NwM9IR9NhL9d8EXgTQJJXA2cCPwT2AlclOSvJTuBC4J4h97V5zfjLRWnmLS7C44/D8883jzMS8DB8yP8F8KokDwK3AFdX4yHgVuBh4CvAtVX1L0Pua/Oa8ZeLkqZnqA8NqapngX9/gm03AKbQKMz4y0VJ0+MnQ80KP9VI0gbY1kCSesyQl6QeM+QlqccMeUnqMUNeknrMkJekHjPkJanHDHlJ6jFDXpJ6zJCXpB4z5CWpxwx5SeoxQ16SesyQl6QeM+SlzWR5Gebn4bTTmkc/DL737CcvbRbLy7B79/EPhT94sFkHP6ugx4Y6k0/y2iR3J7kvyUqSS9rxJPlUkgNJHkhy8WjKlbRhe/YcD/hjjh5txtVbw16u+RjwX6vqtcB/adcB3krz4d0XAruBTw+5H0nDeuKJ9Y2rF4YN+QJe1i6/HPh+u7wL+Fz7od53A1uTnDfkviQNY8eO9Y2rF4YN+fcDH09yCPgEcH07vg04NPC8w+3YSyTZ3V7qWVldXR2yHEkndMMNsGXLC8e2bGnG1VunDPkkdyZ5cI2vXcA1wAeq6nzgA8BN6y2gqpaqaqGqFubm5tb/J5D0k1lchKUluOACSJrHpSXfdO25VNXGvzn5MbC1qipJgB9X1cuS/Hfg61X1v9rnPQq8saqOnOznLSws1MrKyobrkaTNKMn+qlpYa9uwl2u+D/xWu3wZ8Fi7vBd4VzvL5lKa8D9pwEuSRm/YefJ/CNyY5Azg/9HMpAH4MnAlcAA4Crx7yP1IkjZgqJCvqn8AfnWN8QKuHeZnS5KGZ1sDSeoxQ16SesyQl6QeM+Q1G+yeKG2IXSjVfXZPlDbMM3l1n90TpQ0z5NV9dk+UNsyQV/fZPVHaMENe3Wf3RGnDDHl1n90TpQ1zdo1mw+KioS5tgGfyktRjhrwk9ZghL0k9ZshLUo8Z8pLUY0N9xuuoJVkFDk6xhHOAH05x/6fS9fqg+zVa3/C6XuNmrO+Cqppba0OnQn7akqyc6MNwu6Dr9UH3a7S+4XW9Rut7IS/XSFKPGfKS1GOG/AstTbuAU+h6fdD9Gq1veF2v0foGeE1eknrMM3lJ6jFDXpJ6zJAHkrw2yd1J7kuykuSSdjxJPpXkQJIHklw8xRrfm+TbSR5K8rGB8evb+h5N8pZp1dfW8sEkleScdr1Lx+/j7fF7IMlfJ9k6sK0TxzDJFW0NB5JcN606Buo5P8nXkjzc/r17Xzt+dpI7kjzWPr5iynWenuSbSb7Uru9Msq89jn+V5Mwp17c1yW3t379HkvzaRI9hVW36L+CrwFvb5SuBrw8s/w0Q4FJg35TqexNwJ3BWu/7K9vEi4H7gLGAn8I/A6VOq8Xzgb2luZjunS8evreV3gDPa5Y8CH+3SMQROb/f9KuDMtqaLpnW82prOAy5ul38O+E57vD4GXNeOX3fsWE6xzj8G/ifwpXb9VuCqdvkzwDVTru9m4D+2y2cCWyd5DD2TbxTwsnb55cD32+VdwOeqcTewNcl5U6jvGuAjVfUMQFU9PVDfLVX1T
FV9DzgAXDKF+gA+CXyI5lge05XjR1V9taqea1fvBrYP1NiFY3gJcKCqvltVzwK3tLVNTVUdqap72+V/Bh4BtrV13dw+7WbgHdOpEJJsB34X+Gy7HuAy4Lb2KdOu7+XAbwI3AVTVs1X1IyZ4DA35xvuBjyc5BHwCuL4d3wYcGnje4XZs0l4N/Eb7EvTvk7y+He9EfUl2AU9W1f0v2tSJ+tbwBzSvMKA7NXaljjUlmQdeB+wDzq2qI+2mp4Bzp1QWwJ/SnFw8367/PPCjgf/Qp30cdwKrwF+2l5Q+m+RnmOAx3DSfDJXkTuAX1ti0B7gc+EBVfT7JO2n+131zh+o7Azib5pLH64Fbk7xqguWdqr4P01wOmaqT1VhVt7fP2QM8ByxPsrZZluRngc8D76+qf2pOlhtVVUmmMg87yduAp6tqf5I3TqOGn8AZwMXAe6tqX5IbaS7P/KtxH8NNE/JVdcLQTvI54H3t6v+mfekHPElzrfmY7e3YpOu7BvhCNRfw7knyPE2To6nXl+SXaM5W7m//8W8H7m3fvJ5YfSer8Zgkvw+8Dbi8PZYw4RpPoit1vECSn6IJ+OWq+kI7/IMk51XVkfby29Mn/glj9Qbg7UmuBH6a5pLrjTSXBc9oz+anfRwPA4eral+7fhtNyE/sGHq5pvF94Lfa5cuAx9rlvcC72lkilwI/HniJNUlfpHnzlSSvpnnz5odtfVclOSvJTuBC4J5JFlZV36qqV1bVfFXN0/ylvriqnqI7x48kV9C8rH97VR0d2DT1Y9j6BnBhOzPkTOCqtrapaa9v3wQ8UlV/MrBpL3B1u3w1cPukawOoquuranv79+4q4O+qahH4GvB7064PoP13cCjJL7ZDlwMPM8ljOM13nbvyBfw6sJ9mRsM+4Ffb8QB/RjPr4VvAwpTqOxP4H8CDwL3AZQPb9rT1PUo7Q2jKx/Jxjs+u6cTxa2s5QHPN+7726zNdO4Y0s5G+09aypwO/y1+neSP9gYHjdiXNde+7aE6G7gTO7kCtb+T47JpX0fxHfYDmlflZU67ttcBKexy/CLxiksfQtgaS1GNerpGkHjPkJanHDHlJ6jFDXpJ6zJCXpB4z5CWpxwx5Seqx/w/1zAjqr4SiiAAAAABJRU5ErkJggg==\n" - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "# PCA/tsne for high dimensional data (60k dims) does not really work well...\n", - "pca = PCA(2)\n", - "tsne = TSNE(n_components=2)\n", - "tsne_fit_1 = tsne.fit_transform(perturb_vector_1)\n", - "tsne_fit_0 = tsne.fit_transform(perturb_vector_0)\n", - "pca_fit_1 = pca.fit_transform(perturb_vector_1)\n", - "pca_fit_0 = pca.fit_transform(perturb_vector_0)\n", - "label_1 = kmeans_1.predict(perturb_vector_1)\n", - "label_0 = kmeans_0.predict(perturb_vector_0)\n", - "colors = ['r','g','b','c','y']\n", - "for i in np.arange(num_clusters):\n", - " plt.scatter(pca_fit_1[label_1==i][:,0], pca_fit_1[label_1==i][:,1], color=colors[i])\n", - "plt.show()\n", - "for i in np.arange(num_clusters):\n", - " 
plt.scatter(pca_fit_0[label_0==i][:,0], pca_fit_0[label_0==i][:,1], color=colors[i])\n", - "plt.show()\n", - "for i in np.arange(num_clusters):\n", - " plt.scatter(tsne_fit_1[label_1==i][:,0], tsne_fit_1[label_1==i][:,1], color=colors[i])\n", - "plt.show()\n", - "for i in np.arange(num_clusters):\n", - " plt.scatter(tsne_fit_0[label_0==i][:,0], tsne_fit_0[label_0==i][:,1], color=colors[i])\n", - "plt.show()" - ], - "metadata": { - "collapsed": false, - "pycharm": { - "name": "#%%\n" - } - } - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[ 1981 3535 5008 8364 11864 13916 15069 15253 21367 24170 28987 33326\n", - " 34237 37660 39210 40684 43098 45503]\n" - ] - } - ], - "source": [ - "print(genes_pos_0[0][0])" - ] - }, - { - "cell_type": "code", - "execution_count": 129, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "CF=0, Number of overlapping positive genes for cluster 0: 0\n", - "CF=0, Number of overlapping negative genes for cluster 0: 0\n", - "CF=0, Number of overlapping positive genes for cluster 1: 0\n", - "CF=0, Number of overlapping negative genes for cluster 1: 0\n", - "CF=0, Number of overlapping positive genes for cluster 2: 0\n", - "CF=0, Number of overlapping negative genes for cluster 2: 0\n", - "CF=0, Number of overlapping positive genes for cluster 3: 0\n", - "CF=0, Number of overlapping negative genes for cluster 3: 0\n", - "CF=0, Number of overlapping positive genes for cluster 4: 0\n", - "CF=0, Number of overlapping negative genes for cluster 4: 0\n", - "CF=1, Number of overlapping positive genes for cluster 0: 2\n", - "CF=1, Number of overlapping negative genes for cluster 0: 0\n", - "CF=1, Number of overlapping positive genes for cluster 1: 0\n", - "CF=1, Number of overlapping negative genes for cluster 1: 0\n", - "CF=1, Number of overlapping positive genes for cluster 2: 0\n", - "CF=1, 
Number of overlapping negative genes for cluster 2: 0\n", - "CF=1, Number of overlapping positive genes for cluster 3: 0\n", - "CF=1, Number of overlapping negative genes for cluster 3: 0\n", - "CF=1, Number of overlapping positive genes for cluster 4: 0\n", - "CF=1, Number of overlapping negative genes for cluster 4: 0\n" - ] - } - ], - "source": [ - "# within each cluster determine the overlapping gene indices that have been perturbed\n", - "for i in np.arange(num_clusters):\n", - " in_cluster_0 = np.argwhere(kmeans_0.labels_==i).squeeze()\n", - " genes_pos_0_sets = []\n", - " genes_neg_0_sets = []\n", - " for j in in_cluster_0:\n", - " genes_pos_0_sets.append(set(genes_pos_0[j][0]))\n", - " genes_neg_0_sets.append(set(genes_neg_0[j][0]))\n", - " overlap_pos_0 = list(genes_pos_0_sets[0].intersection(*genes_pos_0_sets))\n", - " overlap_neg_0 = list(genes_neg_0_sets[0].intersection(*genes_neg_0_sets))\n", - " print(\"CF=0, Number of overlapping positive genes for cluster {}: {}\".format(i,len(overlap_pos_0)))\n", - " print(\"CF=0, Number of overlapping negative genes for cluster {}: {}\".format(i,len(overlap_neg_0)))\n", - "\n", - "for i in np.arange(num_clusters):\n", - " in_cluster_1 = np.argwhere(kmeans_1.labels_==i).squeeze()\n", - " genes_pos_1_sets = []\n", - " genes_neg_1_sets = []\n", - " for j in in_cluster_1:\n", - " genes_pos_1_sets.append(set(genes_pos_1[j][0]))\n", - " genes_neg_1_sets.append(set(genes_neg_1[j][0]))\n", - " overlap_pos_1 = list(genes_pos_1_sets[0].intersection(*genes_pos_1_sets))\n", - " overlap_neg_1 = list(genes_neg_1_sets[0].intersection(*genes_neg_1_sets))\n", - " print(\"CF=1, Number of overlapping positive genes for cluster {}: {}\".format(i,len(overlap_pos_1)))\n", - " print(\"CF=1, Number of overlapping negative genes for cluster {}: {}\".format(i,len(overlap_neg_1)))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Not many overlapping genes which means that 
clustering doesnt give <> correlation to \n", - "# highly perturbed genes, but there may be some statistical methods to find which gene perturbations\n", - "# are <> in each cluster" - ] - }, - { - "cell_type": "code", - "execution_count": 126, - "metadata": {}, - "outputs": [ - { - "data": { - "image/png": "\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "fig, ax = plt.subplots()\n", - "pos = np.arange(len(num_pos)) + 1\n", - "bp = ax.boxplot(num_neg, sym='k+', positions=pos)\n", - "\n", - "ax.set_xlabel('threshold value')\n", - "ax.set_ylabel('# indices')\n", - "plt.setp(bp['whiskers'], color='k', linestyle='-')\n", - "plt.setp(bp['fliers'], markersize=3.0)\n", - "plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 60, - "metadata": {}, - "outputs": [], - "source": [ - "thresholds = pickle.load(open(\"threshold.complete.pkl\",'rb'))" - ] - }, - { - "cell_type": "code", - "execution_count": 61, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "683\n" - ] - } - ], - "source": [ - "print(thresholds['positive threshold indices'][0][0].shape[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "1120\n" - ] - } - ], - "source": [ - "diff = thresholds['perturbation vector']\n", - "print(len(diff))" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": 16, - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "array([0, 1, 1, ..., 0, 1, 0], dtype=int32)" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "kmeans.labels_ " - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - 
"pygments_lexer": "ipython3", - "version": "3.7.1" - } - }, - "nbformat": 4, - "nbformat_minor": 4 -} \ No newline at end of file diff --git a/Pilot1/NT3/nt3_cf/nt3.ipynb b/Pilot1/NT3/nt3_cf/nt3.ipynb deleted file mode 100644 index 6e5cd05c..00000000 --- a/Pilot1/NT3/nt3_cf/nt3.ipynb +++ /dev/null @@ -1,426 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "id": "hazardous-tokyo", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "TF version: 2.2.0\n", - "Eager execution enabled: False\n" - ] - } - ], - "source": [ - "import tensorflow as tf\n", - "tf.get_logger().setLevel(40) # suppress deprecation messages\n", - "tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs \n", - "from tensorflow.keras.models import Model, load_model\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import os\n", - "# os.environ[\"CUDA_VISIBLE_DEVICES\"]=\"1\"\n", - "from time import time\n", - "from alibi.explainers import CounterFactual, CounterFactualProto\n", - "print('TF version: ', tf.__version__)\n", - "print('Eager execution enabled: ', tf.executing_eagerly()) # False\n", - "import pickle" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "polar-netherlands", - "metadata": {}, - "outputs": [], - "source": [ - "model_nt3 = tf.keras.models.load_model('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.model')\n", - "with open('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.data.pkl', 'rb') as pickle_file:\n", - " X_train,Y_train,X_test,Y_test = pickle.load(pickle_file)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "satellite-passage", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1, 60483, 1)\n" - ] - }, - { - "ename": "UnknownError", - "evalue": "2 root error(s) found.\n (0) Unknown: Failed to get convolution algorithm. 
This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n\t [[activation_4/Softmax/_83]]\n (1) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n0 successful operations.\n0 derived errors ignored.", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mUnknownError\u001b[0m Traceback (most recent call last)", - "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[0mfeature_range\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 11\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 12\u001b[0;31m cf = CounterFactual(model_nt3, shape=shape_cf, target_proba=target_proba, tol=tol,\n\u001b[0m\u001b[1;32m 13\u001b[0m \u001b[0mtarget_class\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtarget_class\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmax_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlam_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlam_init\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0mmax_lam_steps\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmax_lam_steps\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlearning_rate_init\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mlearning_rate_init\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/alibi/explainers/counterfactual.py\u001b[0m in \u001b[0;36m__init__\u001b[0;34m(self, 
predict_fn, shape, distance_fn, target_proba, target_class, max_iter, early_stop, lam_init, max_lam_steps, tol, learning_rate_init, feature_range, eps, init, decay, write_dir, debug, sess)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 178\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_classes\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpredict_fn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzeros\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0;31m# flag to keep track if explainer is fit or not\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_v1.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)\u001b[0m\n\u001b[1;32m 955\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 956\u001b[0m \u001b[0mfunc\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_select_training_loop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 957\u001b[0;31m return func.predict(\n\u001b[0m\u001b[1;32m 958\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 959\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mpredict\u001b[0;34m(self, model, x, batch_size, verbose, steps, callbacks, **kwargs)\u001b[0m\n\u001b[1;32m 706\u001b[0m x, _, _ = model._standardize_user_data(\n\u001b[1;32m 707\u001b[0m x, check_steps=True, steps_name='steps', steps=steps)\n\u001b[0;32m--> 708\u001b[0;31m return predict_loop(\n\u001b[0m\u001b[1;32m 709\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 710\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/engine/training_arrays.py\u001b[0m in \u001b[0;36mmodel_iteration\u001b[0;34m(model, inputs, targets, sample_weights, batch_size, epochs, verbose, callbacks, val_inputs, val_targets, val_sample_weights, shuffle, initial_epoch, steps_per_epoch, validation_steps, validation_freq, mode, validation_in_fit, prepared_feed_values_from_dataset, steps_name, **kwargs)\u001b[0m\n\u001b[1;32m 384\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 385\u001b[0m \u001b[0;31m# Get outputs.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 386\u001b[0;31m \u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mf\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mins_batch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 387\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m 
\u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 388\u001b[0m \u001b[0mbatch_outs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mbatch_outs\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/keras/backend.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, inputs)\u001b[0m\n\u001b[1;32m 3629\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_make_callable\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfeed_arrays\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfeed_symbols\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msymbol_vals\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msession\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 3630\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 3631\u001b[0;31m fetched = self._callable_fn(*array_vals,\n\u001b[0m\u001b[1;32m 3632\u001b[0m run_metadata=self.run_metadata)\n\u001b[1;32m 3633\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_call_fetch_callbacks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mfetched\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fetches\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", - "\u001b[0;32m/vol/ml/shahashka/anaconda3/envs/xai-geom-tf/lib/python3.8/site-packages/tensorflow/python/client/session.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1468\u001b[0m 
\u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1469\u001b[0m \u001b[0mrun_metadata_ptr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf_session\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mTF_NewBuffer\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mrun_metadata\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1470\u001b[0;31m ret = tf_session.TF_SessionRunCallable(self._session._session,\n\u001b[0m\u001b[1;32m 1471\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_handle\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1472\u001b[0m run_metadata_ptr)\n", - "\u001b[0;31mUnknownError\u001b[0m: 2 root error(s) found.\n (0) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n\t [[activation_4/Softmax/_83]]\n (1) Unknown: Failed to get convolution algorithm. This is probably because cuDNN failed to initialize, so try looking to see if a warning log message was printed above.\n\t [[{{node conv1d/conv1d}}]]\n0 successful operations.\n0 derived errors ignored." 
- ] - } - ], - "source": [ - "shape_cf = (1,) + X_train.shape[1:] \n", - "print(shape_cf)\n", - "target_proba = 0.9\n", - "tol = 0.1 # want counterfactuals with p(class)>0.90\n", - "target_class = 'other' # any class other than will do\n", - "max_iter = 1000\n", - "lam_init = 1e-1\n", - "max_lam_steps = 20\n", - "learning_rate_init = 0.1\n", - "feature_range = (0,1)\n", - "\n", - "cf = CounterFactual(model_nt3, shape=shape_cf, target_proba=target_proba, tol=tol,\n", - " target_class=target_class, max_iter=max_iter, lam_init=lam_init,\n", - " max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init,\n", - " feature_range=feature_range)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "closing-quarterly", - "metadata": {}, - "outputs": [], - "source": [ - "shape = X_train[0].shape[0]\n", - "results=[]\n", - "for i in np.arange(0,5):\n", - " x_sample=X_train[i:i+1]\n", - " print(x_sample.shape)\n", - " start = time()\n", - " try:\n", - " explanation = cf.explain(x_sample)\n", - " print('Counterfactual prediction: {}, {}'.format(explanation.cf['class'], explanation.cf['proba']))\n", - " print(\"Actual prediction: {}\".format(model_nt3.predict(x_sample)))\n", - " results.append([explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']])\n", - " # if counterfactual not found make a dummy array \n", - " except IndexError:\n", - " dummy = np.empty(x_sample.shape)\n", - " dummy[:] = np.nan\n", - " results.append(dummy)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "dirty-hebrew", - "metadata": {}, - "outputs": [], - "source": [ - "print(X_train.shape, X_test.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "reliable-video", - "metadata": {}, - "outputs": [], - "source": [ - "pickle.dump(results, open(\"small_cf_test.pkl\", \"wb\"))" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "rapid-beauty", - "metadata": {}, - "outputs": [], - "source": [ - "for i 
in range(len(results)):\n", - " plt.figure(figsize=(20, 4))\n", - " sample = X_train[i].flatten()\n", - " y = results[i][0].flatten()\n", - " x = np.arange(y.shape[0])\n", - " plt.plot(x,y,alpha=0.5, label='counterfactual')\n", - " plt.plot(x,sample,alpha=0.5, label='input')\n", - " plt.plot(x,sample-y, label='diff')\n", - " props = dict(boxstyle='round', facecolor='wheat', alpha=1)\n", - " prediction = model_nt3.predict(X_test[i:i+1])\n", - " plt.text(0.05, 0.95, \"original input: {} {} \\n counterfactual: {} {}\".format(np.argmax(prediction), \n", - " prediction,results[i][1] ,results[i][2]), \n", - " fontsize=16,\n", - " verticalalignment='top', bbox=props)\n", - " plt.legend()\n", - " plt.show()" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "defensive-seeker", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.4542566392434886\n", - "0.4086305624547959\n", - "\n", - "\n", - "0.5043837824693964\n", - "0.46664387832620924\n", - "\n", - "\n", - "0.6335310691943951\n", - "0.6069955054804324\n", - "\n", - "\n", - "0.6563886318117149\n", - "0.5876853023454891\n", - "\n", - "\n", - "0.16572694427639978\n", - "0.1858917752190248\n", - "\n", - "\n" - ] - } - ], - "source": [ - "from scipy.stats import pearsonr\n", - "Y_flag = np.argmax(Y_test,axis=1)\n", - "\n", - "for r in range(len(results)):\n", - " pearson_0 = []\n", - " pearson_1 = []\n", - " for i in range(len(Y_flag)):\n", - " if Y_flag[i]==0:\n", - " pearson_0.append(pearsonr(results[r][0].flatten(), X_test[i].flatten())[0])\n", - " else:\n", - " pearson_1.append(pearsonr(results[r][0].flatten(), X_test[i].flatten())[0])\n", - "\n", - " print(np.average(pearson_0))\n", - " print(np.average(pearson_1))\n", - " print(\"\\n\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "corporate-future", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[0.36569104 0.634309 ]\n" - ] - 
} - ], - "source": [ - "Y_predict = model_nt3.predict(X_test)\n", - "noise = np.random.uniform(-0.3, 0.3, X_test.shape)\n", - "Y_predict_noise = model_nt3.predict(X_test+noise)\n", - "print(Y_predict_noise[0])" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "subtle-blood", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "Text(0.5, 1.0, 'Class 0 predictions with uniform random noise [-0.2,0.2]')" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - }, - { - "data": { - "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlcAAAG5CAYAAACjnRHrAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8QVMy6AAAACXBIWXMAAAsTAAALEwEAmpwYAAAvdElEQVR4nO3deZhcZZn38e9tEogDYQ0uEKQRCWFLArZISAjIMsoiigODCIhcDmEZBgd9fYm+zoAOKgIDCCIMCiaARAXRAUQFZEtYxA6ELYRAIEBkX5OAQYj3+8c5HStNL5XkVFd3+H6uq67uOvXUOfepU8uvnvPUOZGZSJIkqRrvanYBkiRJKxPDlSRJUoUMV5IkSRUyXEmSJFXIcCVJklQhw5UkSVKFDFdquIg4MSIuaXYdzVC77hHxgYhYGBEDlmM+X4+IH1df4YqpZ50iIiPiQ71Qy44R8VDN9c0i4u6IWBARxzZ6+b0tIiZFxEnNrqM7jXjeluv914iYW+V8myEivhkRr5WvkYHNrkfVMVypEhHxuYhoKz9on46I30bEuCbV0hIRN0bE6xExKyJ2a0YdHWXmE5m5emYu7q5dROwcEfM63Pc7mfkvja1w2XVcp4i4KSKaUmdmTs3MzWom/V/gpswckplnNaOmd7oGPm9PycyW7hqU70mPl+Hl1xGxThft3hMRUyLiqYh4NSJujYiPdjPfiIjvRcSL5eWUiIgu2m4fEddFxEsR8XxEXBYR72+/PTNPALasb5XVnxiutMIi4svAmcB3gPcCHwB+CHyqSSVNAe4G1gX+H3B5RKy3ojP1m2W/sxHwwPLcsept7XOnd0XElsD/AIdQvCe9TvGe1JnVgT8BHwbWASYDv4mI1btoPwH4NDAKGAnsDRzRRdu1gfOBForn4wLgJ8u0MuqfMtOLl+W+AGsCC4H9u2lzInBJzfXLgGeAV4FbgC1rbtsTmEnxJvRn4P+U04cCVwOvAC8BU4F3dbKs4cAbwJCaaVOBI7up7XLg5+Uy7wJG1dw+FzgeuLec70Bge+C2spZ7gJ1r2m8M3FzO6zrgB+3rTvEGm8DA8vo6FG+0TwEvA78GVgP+AvytfFwXAut38hjuQxEcXgFuAjbvUPP/KWt+tVy3wcv4OH4TOLv8fxDwGkVvAcC7gUUUHxxL1gn4NrC4vG0h8IOyfQJHAg+X63kOEF1sj0nASTXXdwbm1bluS9oCN3SoZTjFc/Ui4HngceAb7esOfAG4FTijfFxOKmv5IfDbch63Au+j+CLxMjAL2Kab530C/1qu92PltO8DTwLzgenAjh2ei78oa1xQbt/Wmtu3oXh+LijX+2cdHqvDgUfK+q8E1u9Qy9FlLQuA/wI2AW4va/kFsEoX6/EFYBpwWrnejwF71Ny+frm8l8rlH97Zax8YDFwCvEjx/PsT8N6a95ELgKcpX
vcnAQPqeY500eY7wKU11zcB/krN+0IP958PfLiL224DJtRc/yJwR53z3RZY0GFaCzXvC15Wjos9V1pRYyjeNH+1DPf5LbAp8B6KD4uf1tx2AXBEZg4BtqL4kAT4CjAPWI/im+jXKd6QOtoSeDQzF9RMu4fuu94/RRH41gEuBX4dEYNqbj8Q2AtYq1z2byje/Neh+KD/ZU3P2KUUH5pDKT7ADu1muRcD/1DW9h7gjMx8DdgDeCqL3W2rZ+ZTtXeKiOEUvXP/TvF4XANcFRGr1DT7Z+ATFGFvJMUHJNT/ON5MEVYAPkIRhncqr48BHsrMl2vvkJn/jyKsHVPWfUzNzXuX8xlV1vbxbh6XnnS1brW17NKhltnA2RQf4h8s1+XzwGE1d/so8CjFtvh2zbK+QbE936AII3eV1y8HTu+h1k+X892ivP4nYDR/f65dFhGDa9rvQxGa1qIILD8AKLftrymeM+tQPF//qf1OEbEL8N2y3vdThMefdajlExS9M9tT7DI9HzgI2JDitXZgN+vxUeChcr1PAS6o2RU2heI5tT6wH/CdiNi1k3kcSvH4b0jRq3wkxRcJKHqL3gI+RBEi/xFYkd2JW1K87gHIzDkU4Wp4T3eMiNHAKhRBscd50/P7S63xLGdvqvoXw5VW1LrAC5n5Vr13yMwLM3NBZr5B8c12VESsWd78JrBFRKyRmS9n5l01098PbJSZb2YxvqazULA6RY9GrVeBId2UND0zL8/MNyk+LAdTfAC1Oyszn8zMvwAHA9dk5jWZ+bfMvA5oA/aMiA9QBIj/yMw3MvMW4KrOFliOu9iDokft5XKdbu6mxloHAL/JzOvKmk+j6E3aoUPNT2XmS2UNo8vp9T6OtwObRsS6FB8IFwAblLtKdqIIX8vi5Mx8JTOfAG6sqWd5dLVuXSoH3B8AfK187s0F/ptit1G7pzLz7Mx8q9zWAL/KzOmZuYjiC8SizLwoizFmP6cIAt35bma+1D6/zLwkM18sl/HfwKpA7TixaeVzazFFkBpVTt+eogfxzHK7XU4R1NodBFyYmXeVr6uvAWMioqWmzfcyc35mPgDcD1ybmY9m5qsUX3i6W5fHM/NHZV2TKZ5D742IDYFxwPGZuSgzZwA/ZunHtd2bFO8XH8rMxeXjOj8i3kvxWvj3zHwtM5+j6EH8bDf19GR53geIiDUoHvdvlo9LPfN+FVi9q3FXNfMeCfwn8NXu2mnlYLjSinoRGFrvmJKIGBARJ0fEnIiYT7GbB4pvxFB8G98TeDwibo6IMeX0Uym+SV4bEY9GxMQuFrEQWKPDtDUodoV05cn2fzLzb/z9W/jbbqcYN7F/RLzSfqH4cHl/eZ+Xy96ndo93scwNgZc69v7Uaf3a+ZY1PwlsUNPmmZr/X6f4QIA6H8cyDLRRBKnxFGHqNmAsyxeuuqpneSzPvIZS9EbUbo/HWfoxe5K3e7bm/790cr2nZS81z4j4SkQ8WA6cfoWiJ2doTZOO6za4fG2tD/y5QxCuXZeOz4mFFK/N2vVbkXVZUldmvl7+u3q53Jc69BR3fFzbXQz8HvhZOXj8lLKHeCOK4Ph0zWvqfyh6EHtU/kp0YXlp7xVa5veBiHg3RVi/IzO/280iO857DWBhF19S2uf9IYoA+6XMnNrNvLWSMFxpRd1OMa7l03W2/xzFbrjdKD5YWsrpAZCZf8rMT1G8sf6aYiwIZW/DVzLzg8AngS93sevhAeCDEVH7DXUU3XfFb9j+T0S8CxhGMQ6qXe2b5pPAxZm5Vs1ltcw8mWK8yNoRsVpN+w90scwngXUiYq1ObuvyTbr0FMUHUnvNUa7Dn3u437I8jlAEqF0oejT+VF7/OLAdxVi5ThfRUw09eI1iV2m7963g/Nq9QNFzslHNtA+w9GO2orV3Zsk8I2JHivF7/wysnZlrUfR6dNvjUXqaouewtm3tc6vjc2I1il6iHp8TK+gpiudx7eut4+MKQNnj9s3M3IKil3Vvil2zT1Lsc
h1a85paIzPr2tVW9r6270Jvv88D/L3Xj4j4IEUv4ezO5hERq1K83/yZrgent1tq3vTw/hIRGwHXA/+VmRf3MG+tJAxXWiFl1/l/AudExKcj4h8iYlBE7BERp3RylyEUb6QvUnyIfqf9hohYJSIOiog1y91d8ykGJRMRe0fEh8oPl/bpbzukQTm2ZgZwQkQMjoh9Kcbl/LKb1fhwRHym7CH497K+O7poewnwyYj4eNkLNziKQycMy8zHKXp7vlmuyziKANPZ4/Y0xTfZH0bE2uVjNr68+Vlg3ZpdpR39AtgrInYtv/l/paz5tm7WEaj/cSzdTPHhNzMz/0oxcP5fKAZnP9/FfZ6lGNO0vGZQ7GJdJyLeR7E9Vli5O+sXwLcjYkj5gfdliu3ZW4ZQjCt6HhgYEf/J23tXunJ7ed9jI2JgRHyGIuS2uxQ4LCJGl0HhO8Afy92fDZOZT1I8775bvhZGUgzw/mnHthHxsYjYutxFO58i7C4uXwvXAv8dEWtExLsiYpOI2KnjPJbBTylepzuWQfNbwBUdetja6xpEMX7uL8Dny57g2ttbojgOVUs56SKKLyUbRMT6FK+/SZ0VEREbUIwbPSczz1uB9VE/Y7jSCsvM0yk+qL5B8cHxJHAMxTfBji6i2G3wZ4pfBXYMMYcAc8tdhkdSjHGCYgD89RRd8rcDP8zMm7oo6bNAK8Uvm04G9usmDAD8L8V4nJfL5X+mDHedreuTFD1vX69Z16/y99fS5ygG/74EnFCub1cOofiAmQU8RxkkMnMWxSDhR8vdJLW7KMnMhygel7MpemQ+CXyyDEA9WZbH8TaKsVztvVQzKXopu+q1guLXcPtFxMsRsTzHlrqYYoDwXIoP3J8vxzy68m8UPWOPUvz67VLgwgrn35PfUwTq2RSvgUV0vivybcpt+xmKwfsvUzxfr6i5/Q/Af1B8iXia4tdxKzJmaVkcSNED/RTFuLQTyrGIHb2PIsTMBx6kCO/t4fbzFLttZ1Ks3+UUu9qXSzmu7EiKkPUcRbA9uv32iDgvItrDTnsv2j8Cr9TsYtyxvH1D/v6eBcUuy6uA+yjGrv2mnNY+7wci4qDy6r9QfNk4oWa+C5d3vdR/RDe7iaWVXkScSDHA9uCe2kpqvoj4EUWgezYzN+mF5X0DeD4z/6fHxss+7xMovpiuCqyWPRxgWP2H4UrvaIYrSVLV3C0oSZJUIXuuJEmSKmTPlSRJUoX61MlEhw4dmi0tLc0uQ5IkqUfTp09/ITPX6zi9T4WrlpYW2traml2GJElSjyKi07NwuFtQkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqtDAZhdQ69n5izjjutnNLkOSJPVTx+0+vNkl2HMlSZJUJcOVJElShQxXkiRJFTJcSZIkVchwJUmSVKGGhquI+EREPBQRj0TExEYuS5IkqS9oWLiKiAHAOcAewBbAgRGxRaOWJ0mS1Bc0sudqO+CRzHw0M/8K/Az4VAOXJ0mS1HSNDFcbAE/WXJ9XTltKREyIiLaIaHvt1ZcbWI4kSVLjNTJcRSfT8m0TMs/PzNbMbF1tzbUbWI4kSVLjNTJczQM2rLk+DHiqgcuTJElqukaGqz8Bm0bExhGxCvBZ4MoGLk+SJKnpGnbi5sx8KyKOAX4PDAAuzMwHGrU8SZKkvqBh4QogM68BrmnkMiRJkvoSj9AuSZJUIcOVJElShQxXkiRJFTJcSZIkVaihA9qX1XsXPshxt36k2WVIzXXiq82uQJK0Auy5kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4
UqSJKlCA5tdwFLW3wZObGt2FZIkScvNnitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKDWx2AbWenb+IM66b3ewyVirH7T682SVIkvSOYs+VJElShQxXkiRJFTJcSZIkVchwJUmSVCHDlSRJUoUMV5IkSRVqWLiKiAsj4rmIuL9Ry5AkSeprGtlzNQn4RAPnL0mS1Oc0LFxl5i3AS42avyRJUl/U9DFXETEhItoiou21V19udjmSJEkrpOnhKjPPz8zWzGxdbc21m12OJEnSCml6uJIkSVqZGK4kSZIq1MhDMUwBbgc2i4h5EfHFRi1LkiSprxjYqBln5oGNmrckSVJf5W5BSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKRmc2uYYnW9Qdk21OLm12GJElSjyJiema2dpxuz5UkSVKFDFeSJEkVMlxJkiRVyHAlSZJUIcOVJElShQxXkiRJFTJcSZIkVchwJUmSVCHDlSRJUoUMV5IkSRUyXEmSJFXIcCVJklQhw5UkSVKFDFeSJEkVMlxJkiRVyHAlSZJUIcOVJElShQxXkiRJFepb4Wr9bZpdgSRJ0grpW+FKkiSpnzNcSZIkVchwJUmSVCHDlSRJUoUMV5IkSRUyXEmSJFXIcCVJklQhw5UkSVKFDFeSJEkVMlxJkiRVyHAlSZJUIcOVJElShQxXkiRJFTJcSZIkVchwJUmSVCHDlSRJUoUMV5IkSRUyXEmSJFXIcCVJklQhw5UkSVKFDFeSJEkVMlxJkiRVyHAlSZJUIcOVJElShQxXkiRJFTJcSZIkVchwJUmSVCHDlSRJUoUMV5IkSRUyXEmSJFXIcCVJklShHsNVRFwcEWvWXN8oIv7Q2LIkSZL6p3p6rqYBf4yIPSPicOA64MyGViVJktRPDeypQWb+T0Q8ANwIvABsk5nPNLwySZKkfqie3YKHABcCnwcmAddExKgG1yVJktQv9dhzBfwTMC4znwOmRMSvgMnA6EYWJkmS1B/Vs1vw0wARsVpmvpaZd0bEdg2vTJIkqR/qMVxFxBjgAmB14APlLsEjgKOrLubZ+Ys447rZVc9WkiStZI7bfXizS+hSPb8WPBP4OPAiQGbeA4xvYE2SJEn9Vl0HEc3MJztMWtyAWiRJkvq9ega0PxkROwAZEasAxwIPNrYsSZKk/qmenqsjgX8FNgDmUfxK8F8bWJMkSVK/Vc+vBV8ADuqFWiRJkvq9LsNVRJwNZFe3Z+axDalIkiSpH+tut2AbMB0YDGwLPFxeRlPHgPaI2DAiboyIByPigYj4UgX1SpIk9Wld9lxl5mSAiPgC8LHMfLO8fh5wbR3zfgv4SmbeFRFDgOkRcV1mzlzxsiVJkvqmega0rw8Mqbm+ejmtW5n5dGbeVf6/gOIXhhssT5GSJEn9RT2HYjgZuDsibiyv7wScuCwLiYgWYBvgj53cNgGYALD2e3rMbJIkSX1aPb8W/ElE/Bb4aDlpYmY+U+8CImJ14JfAv2fm/E7mfz5wPsCGw7fqcgC9JElSf1DXEdqBAcDzwMvA8Iio6/Q3ETGIIlj9NDOvWL4SJUmS+o96Ttz8PeAA4AHgb+XkBG7p4X5BccLnBzPz9BWsU5IkqV+oZ8zVp4HNMvONZZz3WOAQ4L6ImFFO+3pmXrOM85EkSeo36glXjwKDgGUKV5k5DYjlKUqSJKm/qidcvQ7MiIg/UBOwPEK7JEnS29UTrq4sL5IkSepBP
YdimNwbhUiSJK0Mujtx8y8y858j4j46OYFzZo5saGWSJEn9UGR2ftzOiHh/Zj4dERt1dntmPl51Ma3rD8i2CasvPfHEV6tejCRJ0gqLiOmZ2dpxencnbn66/Ft5iJIkSVpZ1XuEdkmSJNXBcCVJklShHsNVRHypnmmSJEmqr+fq0E6mfaHiOiRJklYK3R2K4UDgc8DGEVF7ENEhwIuNLkySJKk/6u4gorcBTwNDgf+umb4AuLeRRUmSJPVX3R2K4XHgcWBM75UjSZLUv/V4+puIWMDfj9C+CjAIeC0z12hkYZIkSf1RPecWHFJ7PSI+DWzXqIIkSZL6s2U+zlVm/hrYpfpSJEmS+r96dgt+pubqu4BWOjmRsyRJkuoIV8Ana/5/C5gLfKoh1UiSJPVz9Yy5Oqw3CpEkSVoZ1HP6mw9GxFUR8XxEPBcR/xsRH+yN4iRJkvqbega0Xwr8Ang/sD5wGTClkUVJkiT1V/WEq8jMizPzrfJyCQ5olyRJ6lQ9A9pvjIiJwM8oQtUBwG8iYh2AzHypgfVJkiT1K5HZfSdURDzWzc2ZmZWNv2ptbc22traqZidJktQwETE9M1s7Tq/n14IbN6YkSZKklU89uwWJiB2Altr2mXlRg2qSJEnqt+o5QvvFwCbADGBxOTkBw5UkSVIH9fRctQJbZE+DsyRJklTXoRjuB97X6EIkSZJWBvX0XA0FZkbEncAb7RMzc5+GVSVJktRP1ROuTmx0EZIkSSuLeg7FcHNvFCJJkrQy6DJcRcQCOj/NTVAcPHSNhlUlSZLUT3UZrjJzSG8WIkmStDKo59eCkiRJqpPhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoNbHYBtZ6dv4gzrpvd7DIkSVI/ddzuw5tdgj1XkiRJVTJcSZIkVchwJUmSVCHDlSRJUoUMV5IkSRUyXEmSJFWoYeEqIgZHxJ0RcU9EPBAR32zUsiRJkvqKRh7n6g1gl8xcGBGDgGkR8dvMvKOBy5QkSWqqhoWrzExgYXl1UHnJRi1PkiSpL2jomKuIGBARM4DngOsy84+dtJkQEW0R0fbaqy83shxJkqSGa2i4yszFmTkaGAZsFxFbddLm/MxszczW1dZcu5HlSJIkNVyv/FowM18BbgI+0RvLkyRJapZG/lpwvYhYq/z/3cBuwKxGLU+SJKkvaOSvBd8PTI6IARQh7heZeXUDlydJktR0jfy14L3ANo2avyRJUl/kEdolSZIqZLiSJEmqkOFKkiSpQo0c0L7M3rvwQY679SPNLkMnvtrsCiRJ6rfsuZIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkCg1sdgFLWX8bOLGt2VVIkiQtN3uuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJkipkuJIkSaqQ4UqSJKlChitJkqQKGa4kSZIqZLiSJEmqkOFKkiSpQoYrSZKkChmuJEmSKmS4kiRJqpDhSpIkqUKGK0mSpAoZriRJk
ipkuJIkSaqQ4UqSJKlChitJkqQKNTxcRcSAiLg7Iq5u9LIkSZKarTd6rr4EPNgLy5EkSWq6hoariBgG7AX8uJHLkSRJ6isa3XN1JvB/gb911SAiJkREW0S0Pf/88w0uR5IkqbEaFq4iYm/gucyc3l27zDw/M1szs3W99dZrVDmSJEm9opE9V2OBfSJiLvAzYJeIuKSBy5MkSWq6hoWrzPxaZg7LzBbgs8ANmXlwo5YnSZLUF3icK0mSpAoN7I2FZOZNwE29sSxJkqRmsudKkiSpQoYrSZKkCvXKbsEV8eabbzJv3jwWLVrU7FLUCwYPHsywYcMYNGhQs0uRJGm59PlwNW/ePIYMGUJLSwsR0exy1ECZyYsvvsi8efPYeOONm12OJEnLpc/vFly0aBHrrruuweodICJYd9117aWUJPVrfT5cAQardxC3tSSpv+sX4UqSJKm/6PNjrjo647rZlc7vuN2HVzq/7tx0002cdtppXH311Vx55ZXMnDmTiRMndtr2lVde4dJLL+Xoo4/utfokSdKKs+eqAosXL17m++yzzz5dBisowtUPf/jDFSlLkiQ1geGqB3PnzmXEiBEceuihjBw5kv3224/XX3+dlpYWvvWtbzFu3Dguu+wyrr32WsaMGcO2227L/vvvz8KFCwH43e9+x4gRIxg3bhxXXHHFkvlOmjSJY445BoBnn32Wfffdl1GjRjFq1Chuu+02Jk6cyJw5cxg9ejRf/epXm7LukiRp2Rmu6vDQQw8xYcIE7r33XtZYY40lPUqDBw9m2rRp7Lbbbpx00klcf/313HXXXbS2tnL66aezaNEiDj/8cK666iqmTp3KM8880+n8jz32WHbaaSfuuece7rrrLrbccktOPvlkNtlkE2bMmMGpp57am6srSZJWgOGqDhtuuCFjx44F4OCDD2batGkAHHDAAQDccccdzJw5k7FjxzJ69GgmT57M448/zqxZs9h4443ZdNNNiQgOPvjgTud/ww03cNRRRwEwYMAA1lxzzV5YK0mS1Aj9bkB7M3Q8PED79dVWWw0oDn65++67M2XKlKXazZgxw0MLSJL0DmPPVR2eeOIJbr/9dgCmTJnCuHHjlrp9++2359Zbb+WRRx4B4PXXX2f27NmMGDGCxx57jDlz5iy5b2d23XVXzj33XKAYHD9//nyGDBnCggULGrVKkiSpQfpdz1VvHjqh3eabb87kyZM54ogj2HTTTTnqqKM4++yzl9y+3nrrMWnSJA488EDeeOMNAE466SSGDx/O+eefz1577cXQoUMZN24c999//9vm//3vf58JEyZwwQUXMGDAAM4991zGjBnD2LFj2Wqrrdhjjz0cdyVJUj8RmdnsGpZobW3Ntra2paY9+OCDbL755k2qqPi14N57791pKFJjNHubS5JUj4iYnpmtHae7W1CSJKlChqsetLS02GslSZLqZriSJEmqkOFKkiSpQoYrSZKkChmuJEmSKtTvjnPFiRWfGubEV6udXw9uuukmTjvtNK6++mquvPJKZs6cycSJEztt+8orr3DppZdy9NFHL9ey9txzTy699FKApeZTW0OV2trauOiiizjrrLMqna8kSf2JPVcVWbx48TLfZ5999ukyWEERrtpPEr08rrnmGtZaa60Vnk+9WltbDVaSpHc8w1UP5s6dy4gRIzj00EMZOXIk++23H6+//jpQHKbhW9/6FuPGjeOyyy7j2muvZcyYMWy77bbsv//+LFy4EIDf/e53jBgxgnHjxnHFFVcsmfekSZM45phjAHj22WfZd999GTVqFKNGjeK2225j4sSJzJkzh9GjR/PVr351qbpOOeWUJUHmuOOOY5dddgHgD3/4w5ITRLe0tPDCCy90Op+FCxey3377MWLECA466CA6O5jszjvvzPHHH892223H8OHDmTp1KgCLFi3isMMOY+utt2abbbbhxhtvBIoesb333huAm2++mdGjRzN69Gi22WabJafyOfXUU/nIRz7Cy
JEjOeGEE1Z080iS1OcYrurw0EMPMWHCBO69917WWGONpXqBBg8ezLRp09htt9046aSTuP7667nrrrtobW3l9NNPZ9GiRRx++OFcddVVTJ06lWeeeabTZRx77LHstNNO3HPPPdx1111sueWWnHzyyWyyySbMmDHjbae/GT9+/JKw09bWxsKFC3nzzTeZNm0aO+6441JtO5vP3XffzZlnnsnMmTN59NFHufXWWzut66233uLOO+/kzDPP5Jvf/CYA55xzDgD33XcfU6ZM4dBDD2XRokVL3e+0007jnHPOYcaMGUydOpV3v/vdXHvttTz88MPceeedzJgxg+nTp3PLLbfUuxkkSeoXDFd12HDDDRk7diwABx98MNOmTVty2wEHHADAHXfcwcyZMxk7diyjR49m8uTJPP7448yaNYuNN96YTTfdlIhY0qvU0Q033MBRRx0FwIABA1hzze7Hln34wx9m+vTpLFiwgFVXXZUxY8bQ1tbG1KlT3xauOrPddtsxbNgw3vWudzF69Gjmzp3babvPfOYzS5bX3mbatGkccsghAIwYMYKNNtqI2bNnL3W/sWPH8uUvf5mzzjqLV155hYEDB3Lttddy7bXXss0227Dtttsya9YsHn744R5rlSSpP+l/A9qbICK6vL7aaqsBkJnsvvvuTJkyZam2M2bMeNv9qzBo0CBaWlr4yU9+wg477MDIkSO58cYbmTNnTl3n5Vt11VWX/D9gwADeeuutbtvVtqnnfJQTJ05kr7324pprrmH77bfn+uuvJzP52te+xhFHHFHPKkqS1C/Zc1WHJ554gttvvx2AKVOmMG7cuLe12X777bn11lt55JFHAHj99deZPXs2I0aM4LHHHmPOnDlL7t+ZXXfdlXPPPRcoBsfPnz+fIUOGLBmr1Jnx48dz2mmnMX78eHbccUfOO+88Ro8e/bYw19N8ltX48eP56U9/CsDs2bN54okn2GyzzZZqM2fOHLbeemuOP/54WltbmTVrFh//+Me58MILl4xF+/Of/8xzzz1XWV2SJPUF/a/nqpcPnQCw+eabM3nyZI444gg23XTTJbvvaq233npMmjSJAw88kDfeeAOAk046ieHDh3P++eez1157MXToUMaNG9fpuQq///3vM2HCBC644AIGDBjAueeey5gxYxg7dixbbbUVe+yxx9vGXe244458+9vfZsyYMay22moMHjy4012C66677lLz2WuvvVbo8Tj66KM58sgj2XrrrRk4cCCTJk1aqicM4Mwzz+TGG29kwIABbLHFFuyxxx6suuqqPPjgg4wZMwaA1VdfnUsuuYT3vOc9K1SPJEl9SdSzi6e3tLa2Zltb21LTHnzwwbp2czXK3Llz2XvvvT15cy9q9jaXJKkeETE9M1s7Tne3oCRJUoUMVz1oaWmx10qSJNWtX4SrvrTrUo3ltpYk9Xd9PlwNHjyYF1980Q/dd4DM5MUXX2Tw4MHNLkWSpOXW538tOGzYMObNm8fzzz/f7FLUCwYPHsywYcOaXYYkScutz4erQYMGsfHGGze7DEmSpLr0+d2CkiRJ/YnhSpIkqUKGK0mSpAr1qSO0R8QC4KFm16EeDQVeaHYR6pHbqf9wW/UPbqf+oTe300aZuV7HiX1tQPtDnR1GXn1LRLS5nfo+t1P/4bbqH9xO/UNf2E7uFpQkSaqQ4UqSJKlCfS1cnd/sAlQXt1P/4HbqP9xW/YPbqX9o+nbqUwPaJUmS+ru+1nMlSZLUrxmuJEmSKtTr4SoiPhERD0XEIxExsZPbIyLOKm+/NyK27e0aVahjWx1UbqN7I+K2iBjVjDrf6XraTjXtPhIRiyNiv96sT4V6tlNE7BwRMyLigYi4ubdrVKGO9741I+KqiLin3FaHNaPOd7KIuDAinouI+7u4vblZIjN77QIMAOYAHwRWAe4BtujQZk/gt0AA2wN/7M0avSzTttoBWLv8fw+3Vd/cTjXtbgCuAfZrdt3vtEudr6e1gJnAB8rr7
2l23e/ES53b6uvA98r/1wNeAlZpdu3vpAswHtgWuL+L25uaJXq752o74JHMfDQz/wr8DPhUhzafAi7Kwh3AWhHx/l6uU3Vsq8y8LTNfLq/eAQzr5RpV32sK4N+AXwLP9WZxWqKe7fQ54IrMfAIgM91WzVHPtkpgSEQEsDpFuHqrd8t8Z8vMWyge9640NUv0drjaAHiy5vq8ctqytlHjLet2+CLFtwT1rh63U0RsAOwLnNeLdWlp9byehgNrR8RNETE9Ij7fa9WpVj3b6gfA5sBTwH3AlzLzb71TnurU1CzR26e/iU6mdTwWRD1t1Hh1b4eI+BhFuBrX0IrUmXq205nA8Zm5uPiirSaoZzsNBD4M7Aq8G7g9Iu7IzNmNLk5LqWdbfRyYAewCbAJcFxFTM3N+g2tT/ZqaJXo7XM0DNqy5Powi+S9rGzVeXdshIkYCPwb2yMwXe6k2/V0926kV+FkZrIYCe0bEW5n5616pUFD/e98Lmfka8FpE3AKMAgxXvauebXUYcHIWg3seiYjHgBHAnb1TourQ1CzR27sF/wRsGhEbR8QqwGeBKzu0uRL4fDnSf3vg1cx8upfrVB3bKiI+AFwBHOK366bpcTtl5saZ2ZKZLcDlwNEGq15Xz3vf/wI7RsTAiPgH4KPAg71cp+rbVk9Q9DASEe8FNgMe7dUq1ZOmZole7bnKzLci4hjg9xS/yLgwMx+IiCPL28+j+DXTnsAjwOsU3xDUy+rcVv8JrAv8sOwVeSs9Y3yvqnM7qcnq2U6Z+WBE/A64F/gb8OPM7PRn5mqcOl9T/wVMioj7KHY/HZ+ZLzSt6HegiJgC7AwMjYh5wAnAIOgbWcLT30iSJFXII7RLkiRVyHAlSZJUIcOVJElShQxXkiRJFTJcSZIkVchwJa0EIuJ9EfGziJgTETMj4pqIGB4RLV2dNb6CZa4aET8vzzr/x4hoacRyOixz54i4uvx/n4iY2E3btSLi6Jrr60fE5Y2usR7lKW7qPmxJRHwhIn7QxW23lX+XbOuIaI2Is8r/d46IHaqoW1J9DFdSP1eePPZXwE2ZuUlmbgF8HXhvgxf9ReDlzPwQcAbwveWdUUQMWNb7ZOaVmXlyN03WApaEq8x8KjP3W47ylktE9MpxBDPzbcEpM9sy89jy6s6A4UrqRYYrqf/7GPBm7QFDM3NGZk6tbVT2bEyNiLvKyw7l9PdHxC0RMSMi7o+IHSNiQERMKq/fFxHHdbLcTwGTy/8vB3aNDicvLJc5KyImR8S9EXF5efRxImJuRPxnREwD9o+If4yI28vaLouI1ct2nyjnMQ34TM28l/TmRMR7I+JXEXFPedkBOBnYpFyvUzv07AyOiJ+U63Z3FOfHbJ/nFRHxu4h4OCJO6ewBL2v/XkTcWV4+VE6fFBGnR8SNwPciYnRE3FGu+68iYu2a2RwcEbeVj/F25f23K6fdXf7drKb9hmVdD0XECTW1LOykvp0j4uqyN/FI4LjycdgxIh6LiEFluzXKdRnU2XpKWj6GK6n/2wqYXke754DdM3Nb4ADgrHL654DfZ+ZoinPZzQBGAxtk5laZuTXwk07mt+Ss85n5FvAqxRH7O9oMOD8zRwLzqelNAhZl5jjgeuAbwG5lfW3AlyNiMPAj4JPAjsD7uli3s4CbM3MUsC3wADARmJOZozPzqx3a/2tZ99bAgcDkclmU634AsDVwQERsSOfmZ+Z2wA8oTo7dbni5Hl8BLqI4evdI4D6Ko0i3W63sdToauLCcNgsYn5nbUJwB4Ts17bcDDirr27+e3YqZORc4DzijfBymAjcBe5VNPgv8MjPf7GlekupnuJLeOQYBP4rilB2XAVuU0/8EHBYRJwJbZ+YCivOkfTAizo6IT1CEoo7qPev8k5l5a/n/JcC4mtt+Xv7dvqzn1oiYARwKbERxMtzHMvPh8iS5l3SxbrsA5wJk5uLMfLWLdu3GAReX7WcBj1OEIoA/ZOarmbkImFnW0ZkpNX/H1Ey/L
DMXR8SawFqZeXM5fTIwvuP9M/MWYI2IWAtYE7is7GE7A9iypv11mfliZv6F4pyetY/jsvgxfz8VyGF0HpwlrQDDldT/PQB8uI52xwHPUvROtQKrwJIP9/HAn4GLI+Lzmfly2e4mil6eH3cyvyVnnS/HF60JvNRJu46Bq/b6a+XfoAgPo8vLFpn5xS7uX4XOgmG7N2r+X0zX52DNLv5/rWPDOu7ffv2/gBszcyuK3rrBPbRfZmXQbYmInYABnr9Qqp7hSur/bgBWjYjD2ydExEfKD89aawJPZ+bfgEMoTkpLRGwEPJeZPwIuALaNiKHAuzLzl8B/UOxq6+hKih4mgP2AG7Lzk5V+ICLae3YOBKZ10uYOYGzN2KV/iIjhFLvJNo6ITWru35k/AEeV9x0QEWsAC4AhXbS/hWIXG+VyPgA81EXbrhxQ8/f2jjeWvWcvR8SO5aRDgJtrmhxQLn8c8GrZfk2KkAvwhQ6z3D0i1omIdwOfBm6lPp09DhdR9JzZayU1gOFK6ufKQLMvxYfvnIh4ADgReKpD0x8Ch0bEHRS7wNp7WHYGZkTE3cA/Ad+nGE91U7mLbhLwtU4WfQGwbkQ8AnyZYoxTZx4sl3svsA7l7rsO6/A8RZiYUra7AxhR7pqbAPymHND+eBfL+BLwsXKX53Rgy8x8kWI34/0RcWonj8WAsv3PgS9k5hssm1Uj4o/lsjsb8A9F+Dy1XKfRwLdqbns5isMonEfxy0uAU4DvRsStlOG3xjSKXZkzKMZJtdVZ51XAvu0D2stpPwXW5u+7NiVVKDr/oilJK678tdrV5W6ulUZEzAVaM/OFZteyPCJiP+BTmXlIs2uRVka9chwWSVLfEBFnA3sAeza7FmllZc+VJElShRxzJUmSVCHDlSRJUoUMV5IkSRUyXEmSJFXIcCVJklSh/w8p365hWQSF0wAAAABJRU5ErkJggg==\n", - "text/plain": [ - "
" - ] - }, - "metadata": { - "needs_background": "light" - }, - "output_type": "display_data" - } - ], - "source": [ - "Y_flag_predict = np.argmax(Y_predict,axis=1)\n", - "Y_flag_predict_noise = np.argmax(Y_predict_noise,axis=1)\n", - "fig, ax = plt.subplots(figsize=(10,7))\n", - "# Example data\n", - "num_samples=5\n", - "y_pos = np.arange(num_samples)\n", - "y_pos_off = y_pos+0.25\n", - "ax.barh(y_pos, Y_predict[0:num_samples][:,0], height=0.25,align='center', label=\"predict\", alpha=0.5)\n", - "ax.barh(y_pos_off, Y_predict_noise[0:num_samples][:,0],height=0.25, align='center', label=\"predict with noise\")\n", - "\n", - "ax.set_yticks(y_pos)\n", - "ax.invert_yaxis() # labels read top-to-bottom\n", - "plt.legend()\n", - "plt.xlabel(\"Class 0 prediction probability\")\n", - "plt.ylabel(\"Input index\")\n", - "plt.title(\"Class 0 predictions with uniform random noise [-0.2,0.2]\")" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "id": "oriented-factor", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "0.0858112 0.2\n", - "2986\n", - "3631\n", - "1\n", - "[0.01942994 0.01748677 0.01767595 ... 
0.05898305 0.01972752 0.01901901]\n" - ] - } - ], - "source": [ - "test_y = X_test[1].flatten()\n", - "test_cf = results[1][0].flatten()\n", - "diff = test_y-test_cf\n", - "threshold=0.2\n", - "max_value = np.max(np.abs(diff))\n", - "print(max_value ,threshold)\n", - "ind_pos = np.where(diff > threshold*max_value)\n", - "ind_neg = np.where(diff < -threshold*max_value)\n", - "print(len(ind_pos[0]))\n", - "print(len(ind_neg[0]))\n", - "cf_class = np.abs(1-np.argmax(Y_test[0]))\n", - "print(cf_class)\n", - "print(diff[ind_pos[0]])" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "behind-associate", - "metadata": {}, - "outputs": [], - "source": [ - "test = np.concatenate([X_train,X_test])" - ] - }, - { - "cell_type": "code", - "execution_count": 25, - "id": "flexible-sussex", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1400, 60483, 1)\n" - ] - } - ], - "source": [ - "print(test.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 27, - "id": "specific-director", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1120, 60483, 1) (280, 60483, 1)\n" - ] - } - ], - "source": [ - "print(X_train.shape, X_test.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 28, - "id": "muslim-mortgage", - "metadata": {}, - "outputs": [], - "source": [ - "test = np.concatenate([Y_train,Y_test])" - ] - }, - { - "cell_type": "code", - "execution_count": 29, - "id": "alternate-wrestling", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "(1400, 2)\n" - ] - } - ], - "source": [ - "print(test.shape)" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "sweet-venezuela", - "metadata": {}, - "outputs": [], - "source": [ - "with open('small_threshold.pkl', 'rb') as pickle_file:\n", - " t_results = pickle.load(pickle_file)" - ] - } - ], - "metadata": { - "kernelspec": { - 
"display_name": "xai-geom-tf", - "language": "python", - "name": "xai-geom-tf" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.8.6" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From 31e1712b682142b1a40419655a435762a50bc09b Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Tue, 5 Oct 2021 12:28:55 -0500 Subject: [PATCH 05/12] updated all cf and noise generation scripts to work together --- Pilot1/NT3/nt3_cf/cf_nb.py | 36 +++++--------- Pilot1/NT3/nt3_cf/cf_script.py | 65 ------------------------ Pilot1/NT3/nt3_cf/gen_clusters.py | 8 +-- Pilot1/NT3/nt3_cf/inject_noise.py | 82 +++++++++++++++++++------------ Pilot1/NT3/nt3_cf/threshold.py | 8 +-- 5 files changed, 70 insertions(+), 129 deletions(-) delete mode 100644 Pilot1/NT3/nt3_cf/cf_script.py diff --git a/Pilot1/NT3/nt3_cf/cf_nb.py b/Pilot1/NT3/nt3_cf/cf_nb.py index d30b3613..d4ac60db 100644 --- a/Pilot1/NT3/nt3_cf/cf_nb.py +++ b/Pilot1/NT3/nt3_cf/cf_nb.py @@ -12,9 +12,9 @@ print('Eager execution enabled: ', tf.executing_eagerly()) # False print(tf.test.is_gpu_available()) import pickle -model_nt3 = tf.keras.models.load_model('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.model') -with open('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.data.pkl', 'rb') as pickle_file: - X_train,Y_train,X_test,Y_test = pickle.load(pickle_file) +model_nt3 = tf.keras.models.load_model('../nt3.autosave.model') +with open('../nt3.autosave.data.pkl', 'rb') as pickle_file: + X_train,X_test,Y_train,Y_test = pickle.load(pickle_file) shape_cf = (1,) + X_train.shape[1:] print(shape_cf) @@ -32,9 +32,10 @@ feature_range=feature_range) shape = X_train[0].shape[0] results=[] +failed_inds = [] X = np.concatenate([X_train,X_test]) -for i in np.arange(902,903): +for i in np.arange(0,10):#X.shape[0]): print(i) x_sample=X[i:i+1] print(x_sample.shape) 
@@ -43,28 +44,13 @@ explanation = cf.explain(x_sample) print('Counterfactual prediction: {}, {}'.format(explanation.cf['class'], explanation.cf['proba'])) print("Actual prediction: {}".format(model_nt3.predict(x_sample))) - results.append([explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']]) + results.append([i, explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']]) test = model_nt3.predict(explanation.cf['X']) print(test, explanation.cf['proba'], explanation.cf['class']) except: print("Failed cf generation") - results.append([None, None, None]) - #if i %100==0: -pickle.dump(results, open("redo_cf_rest.pkl", "wb")) - # results = [] -#for i in range(len(results)): -# plt.figure(figsize=(20, 4)) -# sample = X_train[i].flatten() -# y = results[i][0].flatten() -# x = np.arange(y.shape[0]) -# plt.plot(x,y,alpha=0.5, label='counterfactual') -# plt.plot(x,sample,alpha=0.5, label='input') -# plt.plot(x,sample-y, label='diff') -# props = dict(boxstyle='round', facecolor='wheat', alpha=1) -# prediction = model_nt3.predict(X_test[i:i+1]) -# plt.text(0.05, 0.95, "original input: {} {} \n counterfactual: {} {}".format(np.argmax(prediction), -# prediction,results[i][1] ,results[i][2]), -# fontsize=16, -# verticalalignment='top', bbox=props) -# plt.legend() -# plt.savefig("fig_{}.png".format(i)) + failed_inds.append(i) + #if i%100 == 0: +pickle.dump(results, open("cf_all.pkl", "wb")) + #results = [] +pickle.dump([failed_inds], open("cf_failed_inds.pkl", "wb")) diff --git a/Pilot1/NT3/nt3_cf/cf_script.py b/Pilot1/NT3/nt3_cf/cf_script.py deleted file mode 100644 index 0b74a2d5..00000000 --- a/Pilot1/NT3/nt3_cf/cf_script.py +++ /dev/null @@ -1,65 +0,0 @@ -import tensorflow as tf -tf.get_logger().setLevel(40) # suppress deprecation messages -tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs -from tensorflow.keras.models import Model, load_model -import matplotlib.pyplot as plt -import numpy as np 
-import os -from time import time -from alibi.explainers import CounterFactual, CounterFactualProto -#print('TF version: ', tf.__version__) -#print('Eager execution enabled: ', tf.executing_eagerly()) # False -import pickle - -model_nt3 = tf.keras.models.load_model('./nt3.autosave.model') -with open('./nt3.autosave.data.pkl', 'rb') as pickle_file: - X_train,Y_train,X_test,Y_test = pickle.load(pickle_file) - - -shape_cf = (1,) + X_train.shape[1:] -print(shape_cf) -target_proba = 0.9 -tol = 0.1 # want counterfactuals with p(class)>0.90 -target_class = 'other' # any class other than will do -max_iter = 1000 -lam_init = 1e-1 -max_lam_steps = 20 -learning_rate_init = 0.1 -feature_range = (0,1) - - -cf = CounterFactual(model_nt3, shape=shape_cf, target_proba=target_proba, tol=tol, - target_class=target_class, max_iter=max_iter, lam_init=lam_init, - max_lam_steps=max_lam_steps, learning_rate_init=learning_rate_init, - feature_range=feature_range) - -shape = X_train[0].shape[0] -results=[] - -X = np.concatenate([X_train,X_test]) -#X=X_test -print(X.shape[0], "-x shape 0") -for i in np.arange(0,X.shape[0]): -# for i in range(4): - - x_sample=X[i:i+1] - print(x_sample.shape) - start = time() - explanation = cf.explain(x_sample) - iter_cf = 0 - n = len(explanation['all'][iter_cf]) - - print('Counterfactual prediction: {}, {}'.format(explanation.cf['class'], explanation.cf['proba'])) - print("Actual prediction: {}".format(model_nt3.predict(x_sample))) - - results.append([i, explanation.cf['X'],explanation.cf['class'], explanation.cf['proba']]) - - if ((i+1)%2 == 0): - print("saving i=", i) - filename = "save.p" + str(i) - pickle.dump(results, open(filename, "wb")) - results=[] - - if i==1: - print("before exit", i) - break diff --git a/Pilot1/NT3/nt3_cf/gen_clusters.py b/Pilot1/NT3/nt3_cf/gen_clusters.py index cf1aa43a..d49470bb 100644 --- a/Pilot1/NT3/nt3_cf/gen_clusters.py +++ b/Pilot1/NT3/nt3_cf/gen_clusters.py @@ -40,10 +40,12 @@ def get_args(): indices_0 = 
np.array(indices_0) indices_1 = np.array(indices_1) sil = [] - kmax = 10 + print(len(perturb_vector_0), len(perturb_vector_1)) + kmax = np.min([len(perturb_vector_0), len(perturb_vector_1),10]) # dissimilarity would not be defined for a single cluster, thus, minimum number of clusters should be 2 for k in range(2, kmax + 1): + print(k) kmeans = KMeans(n_clusters=k).fit(perturb_vector_0) labels = kmeans.labels_ sil.append(silhouette_score(perturb_vector_0, labels, metric='euclidean')) @@ -51,7 +53,7 @@ def get_args(): plt.title("Silhouette scores to determine optimal k") plt.xlabel("k") plt.show() - k = np.argmax(sil) + 2 + k = np.argmax(sil) + 2 if len(sil) > 0 else kmax print(k) data_2D = PCA(2).fit_transform(perturb_vector_0) kmeans_0 = KMeans(n_clusters=k).fit(perturb_vector_0) @@ -71,7 +73,7 @@ def get_args(): plt.title("Silhouette scores to determine optimal k") plt.xlabel("k") plt.show() - k = np.argmax(sil) + 2 + k = np.argmax(sil) + 2 if len(sil) > 0 else kmax print(k) data_2D = PCA(2).fit_transform(perturb_vector_1) kmeans_1 = KMeans(n_clusters=k).fit(perturb_vector_1) diff --git a/Pilot1/NT3/nt3_cf/inject_noise.py b/Pilot1/NT3/nt3_cf/inject_noise.py index 870096f4..1e678b4f 100644 --- a/Pilot1/NT3/nt3_cf/inject_noise.py +++ b/Pilot1/NT3/nt3_cf/inject_noise.py @@ -2,88 +2,106 @@ import numpy as np import copy import argparse - +import os def get_args(): parser = argparse.ArgumentParser() parser.add_argument("-t", type=str, help="threshold pickle file") - parser.add_argument("-c1", type=list, help="cluster 1") - parser.add_argument("-c2", type=list, help="cluster 2") + parser.add_argument("-c1", type=str, help="cluster 1 file") + parser.add_argument("-c2", type=str, help="cluster 2 file") parser.add_argument("-scale", type=float, help="scale factor for noise injection") parser.add_argument("-r", type=bool, help="flag to add random noise") + parser.add_argument("-o", type=str, help="folder for output files") + parser.add_argument("-d", type=str, help="nt3 
data file") + parser.add_argument("-f", type=str, help="pickle file containing failed cf indices") args = parser.parse_args() return args -def random_noise(c1,c2,scale,size, cluster_inds): - X_train, y_train, X_test, y_test = pickle.load(open("nt3.autosave.data.pkl", 'rb')) + +# Choose a random set of indices to inject cf noise into +def random_noise(s,scale,size, cluster_inds, args): + X_train, X_test, y_train, y_test = pickle.load(open(args.d, 'rb')) X_data = np.concatenate([X_train, X_test]) y_data = np.concatenate([y_train, y_test]) genes = np.random.choice(np.arange(X_data.shape[0]), replace=False, size=size) noise = np.random.normal(0,1,size) X_data_noise = copy.deepcopy(X_data) - print(c1,c2) + s, _ = s.split(".") + cluster_name = s[3:] for p in np.arange(0.1,1.0, 0.1): for i in cluster_inds: for j in range(size): X_data_noise[i][genes[j]]+=noise[j] - pickle.dump([X_data_noise, y_data, [], cluster_inds], open("nt3.data.random.scale_{}_cluster_{}_{}.noise_{}.pkl".format(scale,c1,c2,round(p,1)), "wb")) + pickle.dump([X_data_noise, y_data, [], cluster_inds], open("{}/nt3.data.random.scale_{}_{}.noise_{}.pkl".format(args.o,scale,cluster_name,round(p,1)), "wb")) def main(): args = get_args() + isExist = os.path.exists(args.o) + if not isExist: + os.makedirs(args.o) # For 2 clusters (with sparse injection feature vector) add CF noise to x% of samples - X_train, y_train, X_test, y_test = pickle.load(open("nt3.autosave.data.pkl", 'rb')) + X_train, X_test, y_train, y_test = pickle.load(open(args.d, 'rb')) threshold_dataset = pickle.load(open(args.t, 'rb')) perturb_dataset = threshold_dataset['perturbation vector'] - #failed index - perturb_dataset.insert(919, np.zeros(X_train.shape[1])) - perturb_dataset = np.array(perturb_dataset) + + + #combine for easier indexing later X_data = np.concatenate([X_train, X_test]) y_data = np.concatenate([y_train, y_test]) - clusters = [(0,1),(1,1)] - cluster_files = [] - for c in clusters: - 
cluster_files.append(pickle.load(open("clusters_0911_0.5/cf_class_{}_cluster{}.pkl".format(c[0], c[1]), 'rb'))) + + #account for failed indices + failed_indices = pickle.load(open(args.f, 'rb'))[0] + print(failed_indices) + for i in failed_indices: + perturb_dataset.insert(i, np.zeros(X_data.shape[1])) + perturb_dataset = np.array(perturb_dataset) + + cluster_files = [args.c1, args.c2] for i in range(len(cluster_files)): - d=cluster_files[i] + print(cluster_files[i]) + d = pickle.load(open(cluster_files[i], "rb")) cluster_inds = d['sample indices in this cluster'] - random_noise(clusters[i][0],clusters[i][1],args.scale,20, cluster_inds) - #return + if args.r: + random_noise(cluster_files[i],args.scale,20, cluster_inds, args) + + # Sweep through percentages for p in np.arange(0.1,1.0, 0.1): print("p={}".format(p)) X_data_noise = copy.deepcopy(X_data) - # Full cf injection + + #Full cf injection # Choose x% of the indices to be perturbed selector = np.random.choice(a=cluster_inds, replace=False, size = (int)(p*len(cluster_inds))) - #print(perturb_dataset[selector]) X_data_noise[selector]-= args.scale*perturb_dataset[selector][:,:,None] - #print(np.sum(X_data_noise - X_data)) - # Now split back into train test + # Now split back into train test for output X_train = X_data[0:(int)(0.8*X_data.shape[0])] - X_test = X_data[0.8*X_data.shape[0]:] + X_test = X_data[(int)(0.8*X_data.shape[0]):] y_train = y_data[0:(int)(0.8*y_data.shape[0])] - y_test = y_data[0.8*y_data.shape[0]:] - - pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("nt3.data.scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale,clusters[i][0], clusters[i][1], round(p,1)), "wb")) + y_test = y_data[(int)(0.8*y_data.shape[0]):] + s,_ = cluster_files[i].split(".") + cluster_name = s[3:] + pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale,cluster_name, round(p,1)), "wb")) - # Threshold cf 
injection + # Add cf noise only to those indices that passed the threshold value (instead of the full cf profile) inds = [] - print(d) for j in d['positive threshold indices'][0]: inds.append(j) for j in d['negative threshold indices'][0]: inds.append(j) - print(len(inds)) X_data_noise_2 = copy.deepcopy(X_data) + for j in inds: perturb_dataset[:,j]=0 X_data_noise_2[selector]-= args.scale*perturb_dataset[selector][:,:,None] + # Now split back into train test X_train = X_data[0:(int)(0.8*X_data.shape[0])] - X_test = X_data[0.8*X_data.shape[0]:] + X_test = X_data[(int)(0.8*X_data.shape[0]):] y_train = y_data[0:(int)(0.8*y_data.shape[0])] - y_test = y_data[0.8*y_data.shape[0]:] - pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("nt3.data.threshold_scale_{}_cluster_{}_{}.noise_{}.pkl".format(args.scale, clusters[i][0], clusters[i][1], round(p,1)), "wb")) + y_test = y_data[(int)(0.8*y_data.shape[0]):] + + pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.threshold.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale, cluster_name, round(p,1)), "wb")) if __name__ == "__main__": main() diff --git a/Pilot1/NT3/nt3_cf/threshold.py b/Pilot1/NT3/nt3_cf/threshold.py index 92841f12..cdd339e6 100644 --- a/Pilot1/NT3/nt3_cf/threshold.py +++ b/Pilot1/NT3/nt3_cf/threshold.py @@ -25,7 +25,7 @@ def threshold(t_value, X, y, cf): diffs = [] for i in range(len(cf)): test_y = X[i].flatten() - test_cf = cf[i][0].flatten() + test_cf = cf[i][1].flatten() diff = test_y-test_cf max_value = np.max(np.abs(diff)) @@ -38,7 +38,7 @@ def threshold(t_value, X, y, cf): pos.append(ind_pos) neg.append(ind_neg) cf_classes.append(cf_class) - inds.append(i) + inds.append(cf[i][0]) diffs.append(diff) return pos,neg,cf_classes,inds, diffs @@ -46,7 +46,7 @@ def threshold(t_value, X, y, cf): def main(): args = get_args() with open(args.d, 'rb') as pickle_file: - X_train,Y_train,X_test,Y_test = pickle.load(pickle_file) + X_train,X_test, 
Y_train,Y_test = pickle.load(pickle_file) with open(args.c, 'rb') as pickle_file: cf = pickle.load(pickle_file) @@ -67,4 +67,4 @@ def main(): if __name__ == "__main__": - main() \ No newline at end of file + main() From 2b66e74d7c98e11418a067bdf97e065bd2482327 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Tue, 5 Oct 2021 12:35:14 -0500 Subject: [PATCH 06/12] update readme --- Pilot1/NT3/nt3_cf/README.md | 17 ++++++++++++++++- Pilot1/NT3/nt3_cf/cf_nb.py | 8 ++++---- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/Pilot1/NT3/nt3_cf/README.md b/Pilot1/NT3/nt3_cf/README.md index 3866b586..29f35acc 100644 --- a/Pilot1/NT3/nt3_cf/README.md +++ b/Pilot1/NT3/nt3_cf/README.md @@ -3,9 +3,24 @@ Code to generate counterfactual examples given an input model and dataset in pkl Clusters and thresholds counterfactuals, injects noise into dataset \ Workflow: 1) Generate counterfactuals using cf_nb.py +''' +python cf_nb.py> +''' + 2) Create threshold pickle files using threshold.py (provide a threshold value between 0 and 1, see --help) -3) Cluster threshold files using gen_clusters.py +''' +python threshold.py -d ../nt3.autosave.data.pkl -c cf_redo_all_reformat.pkl -t 0.9 -o threshold_0.9.pkl> +''' + +3) Cluster threshold files using gen_clusters.py +''' +python gen_clusters.py -t_value 0.9 -t threshold_0.9.pkl> +''' + 4) Inject noise into dataset using inject_noise.py (provide a scale value to modify the amplitude of the noise, see --help) +''' +python inject_noise.py -t threshold_0.9.pkl -c1 cf_class_0_cluster0.pkl -c2 cf_class_1_cluster0.pkl -scale 1.0 -r True -d ../nt3.autosave.data.pkl -f cf_failed_inds.pkl -o noise_data> +''' Abstention with counterfactuals: Code located in abstention/ diff --git a/Pilot1/NT3/nt3_cf/cf_nb.py b/Pilot1/NT3/nt3_cf/cf_nb.py index d4ac60db..736d2847 100644 --- a/Pilot1/NT3/nt3_cf/cf_nb.py +++ b/Pilot1/NT3/nt3_cf/cf_nb.py @@ -35,7 +35,7 @@ failed_inds = [] X = np.concatenate([X_train,X_test]) -for i in np.arange(0,10):#X.shape[0]): 
+for i in np.arange(0,X.shape[0]): print(i) x_sample=X[i:i+1] print(x_sample.shape) @@ -50,7 +50,7 @@ except: print("Failed cf generation") failed_inds.append(i) - #if i%100 == 0: -pickle.dump(results, open("cf_all.pkl", "wb")) - #results = [] + if i%100 == 0 and i is not 0: + pickle.dump(results, open("cf_{}.pkl".format(i), "wb")) + results = [] pickle.dump([failed_inds], open("cf_failed_inds.pkl", "wb")) From c7a902a1562148c266ea975ea794574f98dacb3a Mon Sep 17 00:00:00 2001 From: shahashka <45748060+shahashka@users.noreply.github.com> Date: Tue, 5 Oct 2021 12:37:16 -0500 Subject: [PATCH 07/12] Update README.md fix code blocks --- Pilot1/NT3/nt3_cf/README.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/Pilot1/NT3/nt3_cf/README.md b/Pilot1/NT3/nt3_cf/README.md index 29f35acc..284a1bb3 100644 --- a/Pilot1/NT3/nt3_cf/README.md +++ b/Pilot1/NT3/nt3_cf/README.md @@ -3,28 +3,28 @@ Code to generate counterfactual examples given an input model and dataset in pkl Clusters and thresholds counterfactuals, injects noise into dataset \ Workflow: 1) Generate counterfactuals using cf_nb.py -''' -python cf_nb.py> -''' +``` +python cf_nb.py +``` 2) Create threshold pickle files using threshold.py (provide a threshold value between 0 and 1, see --help) -''' -python threshold.py -d ../nt3.autosave.data.pkl -c cf_redo_all_reformat.pkl -t 0.9 -o threshold_0.9.pkl> -''' +``` +python threshold.py -d ../nt3.autosave.data.pkl -c cf_redo_all_reformat.pkl -t 0.9 -o threshold_0.9.pkl +``` 3) Cluster threshold files using gen_clusters.py -''' -python gen_clusters.py -t_value 0.9 -t threshold_0.9.pkl> -''' +``` +python gen_clusters.py -t_value 0.9 -t threshold_0.9.pkl +``` 4) Inject noise into dataset using inject_noise.py (provide a scale value to modify the amplitude of the noise, see --help) -''' -python inject_noise.py -t threshold_0.9.pkl -c1 cf_class_0_cluster0.pkl -c2 cf_class_1_cluster0.pkl -scale 1.0 -r True -d ../nt3.autosave.data.pkl 
-f cf_failed_inds.pkl -o noise_data> -''' +``` +python inject_noise.py -t threshold_0.9.pkl -c1 cf_class_0_cluster0.pkl -c2 cf_class_1_cluster0.pkl -scale 1.0 -r True -d ../nt3.autosave.data.pkl -f cf_failed_inds.pkl -o noise_data +``` Abstention with counterfactuals: Code located in abstention/ Workflow: 1) Run abstention model with nt3_abstention_keras2_cf.py, pass in a pickle file with X (with noise), y (this is the output of 4) above) 2) For a sweep use run_abstention_sweep.sh -3) To collect metrics (abstention, cluster abstention) run make_csv.py \ No newline at end of file +3) To collect metrics (abstention, cluster abstention) run make_csv.py From 955d879ceffbb3b7e25ea84f87c4599171e62814 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Tue, 26 Oct 2021 18:40:29 -0500 Subject: [PATCH 08/12] test_cf_accuracy fixed to work with current version of data --- Pilot1/NT3/nt3_cf/inject_noise.py | 27 +++---- Pilot1/NT3/nt3_cf/test_cf_accuracy.py | 108 +++++++++++++------------- 2 files changed, 67 insertions(+), 68 deletions(-) diff --git a/Pilot1/NT3/nt3_cf/inject_noise.py b/Pilot1/NT3/nt3_cf/inject_noise.py index 1e678b4f..2670d5f1 100644 --- a/Pilot1/NT3/nt3_cf/inject_noise.py +++ b/Pilot1/NT3/nt3_cf/inject_noise.py @@ -19,8 +19,8 @@ def get_args(): # Choose a random set of indices to inject cf noise into def random_noise(s,scale,size, cluster_inds, args): X_train, X_test, y_train, y_test = pickle.load(open(args.d, 'rb')) + #X_data, y_data = pickle.load(open(args.d, 'rb')) X_data = np.concatenate([X_train, X_test]) - y_data = np.concatenate([y_train, y_test]) genes = np.random.choice(np.arange(X_data.shape[0]), replace=False, size=size) noise = np.random.normal(0,1,size) X_data_noise = copy.deepcopy(X_data) @@ -30,7 +30,10 @@ def random_noise(s,scale,size, cluster_inds, args): for i in cluster_inds: for j in range(size): X_data_noise[i][genes[j]]+=noise[j] - pickle.dump([X_data_noise, y_data, [], cluster_inds], 
open("{}/nt3.data.random.scale_{}_{}.noise_{}.pkl".format(args.o,scale,cluster_name,round(p,1)), "wb")) + # Now split back into train test for output + X_train = X_data_noise[0:(int)(0.8*X_data.shape[0])] + X_test = X_data_noise[(int)(0.8*X_data.shape[0]):] + pickle.dump([X_train, X_test, y_train, y_test, [], cluster_inds], open("{}/nt3.data.random.scale_{}_{}.noise_{}.pkl".format(args.o,scale,cluster_name,round(p,1)), "wb")) def main(): args = get_args() @@ -39,13 +42,14 @@ def main(): os.makedirs(args.o) # For 2 clusters (with sparse injection feature vector) add CF noise to x% of samples X_train, X_test, y_train, y_test = pickle.load(open(args.d, 'rb')) + print(X_train.shape, X_test.shape, y_train.shape, y_test.shape) + #X_data, y_data = pickle.load(open(args.d, 'rb')) threshold_dataset = pickle.load(open(args.t, 'rb')) perturb_dataset = threshold_dataset['perturbation vector'] #combine for easier indexing later X_data = np.concatenate([X_train, X_test]) - y_data = np.concatenate([y_train, y_test]) #account for failed indices failed_indices = pickle.load(open(args.f, 'rb'))[0] @@ -54,7 +58,9 @@ def main(): perturb_dataset.insert(i, np.zeros(X_data.shape[1])) perturb_dataset = np.array(perturb_dataset) - cluster_files = [args.c1, args.c2] + _, cf1 = os.path.split(args.c1) + _, cf2 = os.path.split(args.c2) + cluster_files = [cf1, cf2] for i in range(len(cluster_files)): print(cluster_files[i]) d = pickle.load(open(cluster_files[i], "rb")) @@ -73,11 +79,9 @@ def main(): X_data_noise[selector]-= args.scale*perturb_dataset[selector][:,:,None] # Now split back into train test for output - X_train = X_data[0:(int)(0.8*X_data.shape[0])] - X_test = X_data[(int)(0.8*X_data.shape[0]):] + X_train = X_data_noise[0:(int)(0.8*X_data.shape[0])] + X_test = X_data_noise[(int)(0.8*X_data.shape[0]):] - y_train = y_data[0:(int)(0.8*y_data.shape[0])] - y_test = y_data[(int)(0.8*y_data.shape[0]):] s,_ = cluster_files[i].split(".") cluster_name = s[3:] pickle.dump([X_train, X_test, 
y_train, y_test, selector, cluster_inds], open("{}/nt3.data.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale,cluster_name, round(p,1)), "wb")) @@ -95,11 +99,8 @@ def main(): X_data_noise_2[selector]-= args.scale*perturb_dataset[selector][:,:,None] # Now split back into train test - X_train = X_data[0:(int)(0.8*X_data.shape[0])] - X_test = X_data[(int)(0.8*X_data.shape[0]):] - - y_train = y_data[0:(int)(0.8*y_data.shape[0])] - y_test = y_data[(int)(0.8*y_data.shape[0]):] + X_train = X_data_noise_2[0:(int)(0.8*X_data.shape[0])] + X_test = X_data_noise_2[(int)(0.8*X_data.shape[0]):] pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.threshold.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale, cluster_name, round(p,1)), "wb")) diff --git a/Pilot1/NT3/nt3_cf/test_cf_accuracy.py b/Pilot1/NT3/nt3_cf/test_cf_accuracy.py index d90a834c..4816197b 100644 --- a/Pilot1/NT3/nt3_cf/test_cf_accuracy.py +++ b/Pilot1/NT3/nt3_cf/test_cf_accuracy.py @@ -1,66 +1,64 @@ import tensorflow as tf -tf.get_logger().setLevel(40) # suppress deprecation messages -tf.compat.v1.disable_v2_behavior() # disable TF2 behaviour as alibi code still relies on TF1 constructs from tensorflow.keras.models import Model, load_model import matplotlib.pyplot as plt import numpy as np import os -os.environ["CUDA_VISIBLE_DEVICES"]="1" -from time import time -from alibi.explainers import CounterFactual, CounterFactualProto -print('TF version: ', tf.__version__) -print('Eager execution enabled: ', tf.executing_eagerly()) # False -print(tf.test.is_gpu_available()) import pickle +import argparse +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-m", type=str, help="model file") + parser.add_argument("-prefix", type=str, help="noise file prefix") + parser.add_argument("-prefix_rand", type=str, help="random noise file prefix") + parser.add_argument("-folder", type=str, help="folder path to noise files") + parser.add_argument("-o", type=str, 
help="name of saved png") + parser.add_argument("-n", type=str, help="name of cluster") + args = parser.parse_args() + return args +def main(): + args = get_args() + model_nt3 = tf.keras.models.load_model(args.m) -model_nt3 = tf.keras.models.load_model('/vol/ml/shahashka/xai-geom/nt3/nt3.autosave.model') -# results = [] -# for i in np.arange(0.1,1.0, 0.1): -# cf_dataset = pickle.load(open("nt3.data.scale_1.0.cluster_0_1.noise_{}.pkl".format(round(i,2)), "rb")) -# X_cf_dataset = cf_dataset[0] -# y_cf_dataset = cf_dataset[1] -# cluster_inds = cf_dataset[-1] -# print(model_nt3.metrics_names) -# acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) -# cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) -# print(i, acc, cluster_acc) -# results.append([acc[1], cluster_acc[1]]) -# plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy", marker='o') -# plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy", marker='o') + results = [] + for i in np.arange(0.1,1.0, 0.1): + cf_dataset = pickle.load(open("{}_{}.pkl".format(args.prefix, round(i,2)), "rb")) + X_cf_dataset = np.concatenate([cf_dataset[0], cf_dataset[1]]) + y_cf_dataset = np.concatenate([cf_dataset[2], cf_dataset[3]]) + #X_cf_dataset = cf_dataset[0] + #y_cf_dataset = cf_dataset[1] + cluster_inds = cf_dataset[-1] + print(model_nt3.metrics_names) + acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) + cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) + print(i, acc, cluster_acc) + results.append([acc[1], cluster_acc[1]]) + results = np.array(results) + plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy", marker='o') + plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy", marker='o') -results = [] -for i in np.arange(0.1,1.0, 0.1): - cf_dataset = pickle.load(open("nt3.data.threshold_scale_1.0_cluster_0_1.noise_{}.pkl".format(round(i,2)), "rb")) - X_cf_dataset = cf_dataset[0] - 
y_cf_dataset = cf_dataset[1] - cluster_inds = cf_dataset[-1] - print(model_nt3.metrics_names) - acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) - cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) - print(i, acc, cluster_acc) - results.append([acc[1], cluster_acc[1]]) -results = np.array(results) -plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy", marker='o') -plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy", marker='o') + results = [] + for i in np.arange(0.1,1.0, 0.1): + cf_dataset = pickle.load(open("{}_{}.pkl".format(args.prefix_rand, round(i,2)), "rb")) + X_cf_dataset = np.concatenate([cf_dataset[0], cf_dataset[1]]) + y_cf_dataset = np.concatenate([cf_dataset[2], cf_dataset[3]]) + #X_cf_dataset = cf_dataset[0] + #y_cf_dataset = cf_dataset[1] + cluster_inds = cf_dataset[-1] + print(model_nt3.metrics_names) + acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) + cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) + print(i, acc, cluster_acc) + results.append([acc[1], cluster_acc[1]]) + results = np.array(results) + plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy with Gaussian noise", marker='o') + plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with Gaussian noise", marker='o') -results = [] -for i in np.arange(0.1,1.0, 0.1): - cf_dataset = pickle.load(open("nt3.data.random.scale_1.0_cluster_0_1.noise_{}.pkl".format(round(i,2)), "rb")) - X_cf_dataset = cf_dataset[0] - y_cf_dataset = cf_dataset[1] - cluster_inds = cf_dataset[-1] - print(model_nt3.metrics_names) - acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) - cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) - print(i, acc, cluster_acc) - results.append([acc[1], cluster_acc[1]]) -results = np.array(results) -plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy with Gaussian noise", marker='o') 
-plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with Gaussian noise", marker='o') + plt.xlabel("Noise fraction in cluster") + plt.ylabel("Accuracy") + plt.legend() + plt.title("Model accuracy with counterfactual noise injection for {}".format(args.n)) + plt.savefig(args.o) -plt.xlabel("Noise fraction in cluster") -plt.ylabel("Accuracy") -plt.legend() -plt.title("Model accuracy with counterfactual noise injection for class 0, cluster 1") -plt.savefig("abstract_plot.png") +if __name__ == "__main__": + main() From 751d28cc767f959c666d467649f8668e70106a74 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Thu, 28 Oct 2021 15:43:26 -0500 Subject: [PATCH 09/12] fix bug in inject noise --- Pilot1/NT3/nt3_cf/gen_clusters.py | 42 ++++++++++++++++--------------- Pilot1/NT3/nt3_cf/inject_noise.py | 6 +++-- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/Pilot1/NT3/nt3_cf/gen_clusters.py b/Pilot1/NT3/nt3_cf/gen_clusters.py index d49470bb..2dcd36a7 100644 --- a/Pilot1/NT3/nt3_cf/gen_clusters.py +++ b/Pilot1/NT3/nt3_cf/gen_clusters.py @@ -42,47 +42,49 @@ def get_args(): sil = [] print(len(perturb_vector_0), len(perturb_vector_1)) kmax = np.min([len(perturb_vector_0), len(perturb_vector_1),10]) + data_2D = PCA(20).fit_transform(perturb_vector_0) # dissimilarity would not be defined for a single cluster, thus, minimum number of clusters should be 2 for k in range(2, kmax + 1): print(k) - kmeans = KMeans(n_clusters=k).fit(perturb_vector_0) + kmeans = KMeans(n_clusters=k).fit(data_2D[:,0:2]) labels = kmeans.labels_ - sil.append(silhouette_score(perturb_vector_0, labels, metric='euclidean')) - plt.plot(np.arange(2, kmax+1), sil) - plt.title("Silhouette scores to determine optimal k") - plt.xlabel("k") - plt.show() + sil.append(silhouette_score(data_2D[:,0:2], labels, metric='euclidean')) + #plt.plot(np.arange(2, kmax+1), sil) + #plt.title("Silhouette scores to determine optimal k") + #plt.xlabel("k") + #plt.show() k = np.argmax(sil) + 2 if 
len(sil) > 0 else kmax print(k) - data_2D = PCA(2).fit_transform(perturb_vector_0) - kmeans_0 = KMeans(n_clusters=k).fit(perturb_vector_0) + #data_2D = PCA(2).fit_transform(perturb_vector_0) + kmeans_0 = KMeans(n_clusters=k).fit(data_2D[:,0:2]) labels_0 = kmeans_0.labels_ colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] for i in range(k): plt.scatter(data_2D[:,0][labels_0==i], data_2D[:,1][labels_0==i], c=colors[i%len(colors)]) - plt.title("CF 0 KMeans clusters with 2D PCA") + plt.title("KMeans clusters with 2D PCA") plt.savefig("CF_0.png") - + k0 = k sil=[] + data_2D = PCA(20).fit_transform(perturb_vector_1) for k in range(2, kmax + 1): - kmeans = KMeans(n_clusters=k).fit(perturb_vector_1) + kmeans = KMeans(n_clusters=k).fit(data_2D[:,0:2])#perturb_vector_1) labels = kmeans.labels_ - sil.append(silhouette_score(perturb_vector_1, labels, metric='euclidean')) - plt.plot(np.arange(2, kmax+1), sil) - plt.title("Silhouette scores to determine optimal k") - plt.xlabel("k") - plt.show() + sil.append(silhouette_score(data_2D[:,0:2], labels, metric='euclidean')) + #plt.plot(np.arange(2, kmax+1), sil) + #plt.title("Silhouette scores to determine optimal k") + #plt.xlabel("k") + #plt.show() k = np.argmax(sil) + 2 if len(sil) > 0 else kmax print(k) - data_2D = PCA(2).fit_transform(perturb_vector_1) - kmeans_1 = KMeans(n_clusters=k).fit(perturb_vector_1) + #data_2D = PCA(2).fit_transform(perturb_vector_1) + kmeans_1 = KMeans(n_clusters=k).fit(data_2D[:,0:2])#perturb_vector_1) labels_1 = kmeans_1.labels_ colors = ['b', 'g', 'r', 'c', 'm', 'y', 'k'] for i in range(k): plt.scatter(data_2D[:,0][labels_1==i], data_2D[:,1][labels_1==i], c=colors[i%len(colors)]) - plt.title("CF 1 KMeans clusters with 2D PCA") - plt.savefig("CF_1.png") + plt.title("Perturbation vectors KMeans clusters with 2D PCA") + plt.savefig("CF 1.png") for i in range(len(kmeans_0.cluster_centers_)): diff_0=kmeans_0.cluster_centers_[i] diff --git a/Pilot1/NT3/nt3_cf/inject_noise.py b/Pilot1/NT3/nt3_cf/inject_noise.py 
index 2670d5f1..be51b055 100644 --- a/Pilot1/NT3/nt3_cf/inject_noise.py +++ b/Pilot1/NT3/nt3_cf/inject_noise.py @@ -94,8 +94,10 @@ def main(): inds.append(j) X_data_noise_2 = copy.deepcopy(X_data) - for j in inds: - perturb_dataset[:,j]=0 + all_inds = np.arange(X_data.shape[0]) + for j in all_inds: + if j not in inds: + perturb_dataset[:,j]=0 X_data_noise_2[selector]-= args.scale*perturb_dataset[selector][:,:,None] # Now split back into train test From 160c8d9a085d349f82f4d1de8836dc39fafb45f1 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Sun, 31 Oct 2021 09:40:56 -0500 Subject: [PATCH 10/12] add sweep files for abstention --- Pilot1/NT3/make_csv.py | 61 +++++++++++++++++++++++++++++ Pilot1/NT3/nt3_abstention_keras2.py | 4 +- Pilot1/NT3/run_abstention_sweep.sh | 8 ++++ 3 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 Pilot1/NT3/make_csv.py create mode 100755 Pilot1/NT3/run_abstention_sweep.sh diff --git a/Pilot1/NT3/make_csv.py b/Pilot1/NT3/make_csv.py new file mode 100644 index 00000000..4b6b1f79 --- /dev/null +++ b/Pilot1/NT3/make_csv.py @@ -0,0 +1,61 @@ +import pandas as pd +import pickle +import argparse +import glob, os +from pathlib import Path +import matplotlib.pyplot as plt + +def get_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-f",type=str, help="Run folder") + parser.add_argument("-c1", type=str, help="cluster 1 name") + parser.add_argument("-c2", type=str, help="cluster 2 name") + args = parser.parse_args() + return args + +def main(): + args = get_args() + l1 = [] + l2 = [] + runs = glob.glob(args.f+"/EXP000/*/") + print(runs) + for r in runs: + print(r) + global_data = pd.read_csv(r+"training.log") + val_abs = global_data['val_abstention'].iloc[-1] + val_abs_acc = global_data['val_abstention_acc'].iloc[-1] + if os.path.exists(r+"cluster_trace.pkl"): + cluster_data = pickle.load(open(r+"cluster_trace.pkl", "rb")) + else: + continue + polluted_abs = cluster_data['Abs polluted'] + val_abs_cluster = 
cluster_data['Abs val cluster'] + val_abs_acc_cluster = cluster_data['Abs val acc'] + ratio = float(r[-8:-5]) + if args.c1 in r: + l1.append([ratio, val_abs, val_abs_acc, val_abs_cluster, val_abs_acc_cluster, polluted_abs]) + elif args.c2 in r: + l2.append([ratio, val_abs, val_abs_acc, val_abs_cluster, val_abs_acc_cluster, polluted_abs]) + + df1 = pd.DataFrame(l1, columns=['Noise Fraction', 'Val Abs', 'Val Abs Acc', 'Val Abs Cluster', 'Val Abs Acc Cluster', 'Polluted Abs']) + df2 = pd.DataFrame(l2, columns=['Noise Fraction', 'Val Abs', 'Val Abs Acc', 'Val Abs Cluster', 'Val Abs Acc Cluster', 'Polluted Abs']) + print(df1) + df1.to_csv("cluster_1.csv") + df2.to_csv("cluster_2.csv") + plt.plot(df1['Noise Fraction'], df1['Val Abs'], marker='o', label='Val Abs') + plt.plot(df1['Noise Fraction'], df1['Val Abs Acc'], marker='o',label='Val Abs Acc') + plt.plot(df1['Noise Fraction'], df1['Val Abs Cluster'], marker='o',label='Val Abs Cluster') + plt.plot(df1['Noise Fraction'], df1['Val Abs Acc Cluster'], marker='o',label='Val Abs Acc Cluster') + plt.xlabel("Noise fraction") + plt.legend() + plt.savefig('c1.png') + + plt.plot(df2['Noise Fraction'], df2['Val Abs'], marker='o',label='Val Abs') + plt.plot(df2['Noise Fraction'], df2['Val Abs Acc'], marker='o',label='Val Abs Acc') + plt.plot(df2['Noise Fraction'], df2['Val Abs Cluster'], marker='o',label='Val Abs Cluster') + plt.plot(df2['Noise Fraction'], df2['Val Abs Acc Cluster'], marker='o',label='Val Abs Acc Cluster') + plt.xlabel("Noise Fraction") + plt.legend() + plt.savefig('c2.png') +if __name__ == "__main__": + main() diff --git a/Pilot1/NT3/nt3_abstention_keras2.py b/Pilot1/NT3/nt3_abstention_keras2.py index 032730b2..5d47d75b 100644 --- a/Pilot1/NT3/nt3_abstention_keras2.py +++ b/Pilot1/NT3/nt3_abstention_keras2.py @@ -111,7 +111,7 @@ def load_data(train_path, test_path, gParameters): return X_train, Y_train, X_test, Y_test -def evaluate_cf(model, nb_classes, output_dir, y_pred, y, polluted_inds, cluster_inds, 
gParameters): +def evaluate_cf(model, nb_classes, output_dir, X_train, X_test, Y_train, Y_test, polluted_inds, cluster_inds, gParameters): if len(polluted_inds) > 0: y_pred = model.predict(X_test) abstain_inds = [] @@ -334,7 +334,7 @@ def run(gParameters): score = model.evaluate(X_test, Y_test, verbose=0) if gParameters['noise_cf'] is not None: - evaluate_cf(model, nb_classes, output_dir, y_pred, y, polluted_inds, cluster_inds, gParameters) + evaluate_cf(model, nb_classes, output_dir, X_train, X_test, Y_train, Y_test, polluted_inds, cluster_inds, gParameters) alpha_trace = open(output_dir + "/alpha_trace", "w+") for alpha in abstention_cbk.alphavalues: alpha_trace.write(str(alpha) + '\n') diff --git a/Pilot1/NT3/run_abstention_sweep.sh b/Pilot1/NT3/run_abstention_sweep.sh new file mode 100755 index 00000000..4a217cc3 --- /dev/null +++ b/Pilot1/NT3/run_abstention_sweep.sh @@ -0,0 +1,8 @@ +#!/bin/bash + +vals=0.1 +for filename in /vol/ml/shahashka/temp/Benchmarks/Pilot1/NT3/nt3_cf/noise_all_clusters_t=0.1/nt3.data.threshold.*; do + echo $filename + python nt3_abstention_keras2.py --noise_cf $filename --output_dir cf_sweep_1030 --run_id $(basename $filename) --epochs 100 + #cp cf_sweep_0902/EXP000/RUN000/training.log ${filename}_training_0902.log +done From b1dd09ead20a705c48209f031610f31ead230817 Mon Sep 17 00:00:00 2001 From: Ashka Shah Date: Sun, 31 Oct 2021 09:42:16 -0500 Subject: [PATCH 11/12] fix bug in failed index generation --- Pilot1/NT3/nt3_cf/cf_nb.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Pilot1/NT3/nt3_cf/cf_nb.py b/Pilot1/NT3/nt3_cf/cf_nb.py index 736d2847..2cd187b3 100644 --- a/Pilot1/NT3/nt3_cf/cf_nb.py +++ b/Pilot1/NT3/nt3_cf/cf_nb.py @@ -53,4 +53,4 @@ if i%100 == 0 and i is not 0: pickle.dump(results, open("cf_{}.pkl".format(i), "wb")) results = [] -pickle.dump([failed_inds], open("cf_failed_inds.pkl", "wb")) +pickle.dump(failed_inds, open("cf_failed_inds.pkl", "wb")) From 51539afd66e96eafd0dbca4ffde5178de44b3add Mon Sep 
17 00:00:00 2001 From: Ashka Shah Date: Fri, 18 Feb 2022 10:31:40 -0600 Subject: [PATCH 12/12] checkpoint --- Pilot1/NT3/nt3_baseline_keras2.py | 2 +- Pilot1/NT3/nt3_cf/inject_noise.py | 42 +++++++++++++++------------ Pilot1/NT3/nt3_cf/test_cf_accuracy.py | 28 ++++++++++++++---- Pilot1/NT3/nt3_noise_model.txt | 1 - Pilot1/NT3/run_abstention_sweep.sh | 17 +++++++---- common/file_utils.py | 1 - common/parsing_utils.py | 3 +- 7 files changed, 60 insertions(+), 34 deletions(-) diff --git a/Pilot1/NT3/nt3_baseline_keras2.py b/Pilot1/NT3/nt3_baseline_keras2.py index efe2e7d4..61e296ee 100644 --- a/Pilot1/NT3/nt3_baseline_keras2.py +++ b/Pilot1/NT3/nt3_baseline_keras2.py @@ -17,7 +17,7 @@ import candle import pickle -def initialize_parameters(default_model='nt3_default_model.txt'): +def initialize_parameters(default_model='nt3_noise_model.txt'): # Build benchmark object nt3Bmk = bmk.BenchmarkNT3( diff --git a/Pilot1/NT3/nt3_cf/inject_noise.py b/Pilot1/NT3/nt3_cf/inject_noise.py index be51b055..b8071579 100644 --- a/Pilot1/NT3/nt3_cf/inject_noise.py +++ b/Pilot1/NT3/nt3_cf/inject_noise.py @@ -20,10 +20,10 @@ def get_args(): def random_noise(s,scale,size, cluster_inds, args): X_train, X_test, y_train, y_test = pickle.load(open(args.d, 'rb')) #X_data, y_data = pickle.load(open(args.d, 'rb')) - X_data = np.concatenate([X_train, X_test]) - genes = np.random.choice(np.arange(X_data.shape[0]), replace=False, size=size) + #X_data = np.concatenate([X_train, X_test]) + genes = np.random.choice(np.arange(X_train.shape[0]), replace=False, size=size) noise = np.random.normal(0,1,size) - X_data_noise = copy.deepcopy(X_data) + X_data_noise = copy.deepcopy(X_train) s, _ = s.split(".") cluster_name = s[3:] for p in np.arange(0.1,1.0, 0.1): @@ -31,9 +31,9 @@ def random_noise(s,scale,size, cluster_inds, args): for j in range(size): X_data_noise[i][genes[j]]+=noise[j] # Now split back into train test for output - X_train = X_data_noise[0:(int)(0.8*X_data.shape[0])] - X_test = 
X_data_noise[(int)(0.8*X_data.shape[0]):] - pickle.dump([X_train, X_test, y_train, y_test, [], cluster_inds], open("{}/nt3.data.random.scale_{}_{}.noise_{}.pkl".format(args.o,scale,cluster_name,round(p,1)), "wb")) + #X_train = X_data_noise[0:(int)(0.8*X_data.shape[0])] + #X_test = X_data_noise[(int)(0.8*X_data.shape[0]):] + pickle.dump([X_data_noise, X_test, y_train, y_test, [], cluster_inds], open("{}/nt3.data.random.scale_{}_{}.noise_{}.pkl".format(args.o,scale,cluster_name,round(p,1)), "wb")) def main(): args = get_args() @@ -49,42 +49,46 @@ def main(): #combine for easier indexing later - X_data = np.concatenate([X_train, X_test]) + #X_data = np.concatenate([X_train, X_test]) #account for failed indices failed_indices = pickle.load(open(args.f, 'rb'))[0] + failed_indices=[919] print(failed_indices) for i in failed_indices: - perturb_dataset.insert(i, np.zeros(X_data.shape[1])) + perturb_dataset.insert(i, np.zeros(X_train.shape[1])) perturb_dataset = np.array(perturb_dataset) _, cf1 = os.path.split(args.c1) _, cf2 = os.path.split(args.c2) cluster_files = [cf1, cf2] + perturb_dataset = perturb_dataset[0:X_train.shape[0]] for i in range(len(cluster_files)): print(cluster_files[i]) d = pickle.load(open(cluster_files[i], "rb")) cluster_inds = d['sample indices in this cluster'] + cluster_inds_noise = list(filter(lambda val: val < 1120, cluster_inds)) + if args.r: - random_noise(cluster_files[i],args.scale,20, cluster_inds, args) + random_noise(cluster_files[i],args.scale,20, cluster_inds_noise, args) # Sweep through percentages for p in np.arange(0.1,1.0, 0.1): print("p={}".format(p)) - X_data_noise = copy.deepcopy(X_data) + X_data_noise = copy.deepcopy(X_train) #Full cf injection # Choose x% of the indices to be perturbed - selector = np.random.choice(a=cluster_inds, replace=False, size = (int)(p*len(cluster_inds))) + selector = np.random.choice(a=cluster_inds_noise, replace=False, size = (int)(p*len(cluster_inds_noise))) X_data_noise[selector]-= 
args.scale*perturb_dataset[selector][:,:,None] # Now split back into train test for output - X_train = X_data_noise[0:(int)(0.8*X_data.shape[0])] - X_test = X_data_noise[(int)(0.8*X_data.shape[0]):] + #X_train = X_data_noise[0:(int)(0.8*X_data.shape[0])] + #X_test = X_data_noise[(int)(0.8*X_data.shape[0]):] s,_ = cluster_files[i].split(".") cluster_name = s[3:] - pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale,cluster_name, round(p,1)), "wb")) + pickle.dump([X_data_noise, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale,cluster_name, round(p,1)), "wb")) # Add cf noise only to those indices that passed the threshold value (instead of the full cf profile) inds = [] @@ -92,19 +96,19 @@ def main(): inds.append(j) for j in d['negative threshold indices'][0]: inds.append(j) - X_data_noise_2 = copy.deepcopy(X_data) + X_data_noise_2 = copy.deepcopy(X_train) - all_inds = np.arange(X_data.shape[0]) + all_inds = np.arange(X_train.shape[0]) for j in all_inds: if j not in inds: perturb_dataset[:,j]=0 X_data_noise_2[selector]-= args.scale*perturb_dataset[selector][:,:,None] # Now split back into train test - X_train = X_data_noise_2[0:(int)(0.8*X_data.shape[0])] - X_test = X_data_noise_2[(int)(0.8*X_data.shape[0]):] + #X_train = X_data_noise_2[0:(int)(0.8*X_data.shape[0])] + #X_test = X_data_noise_2[(int)(0.8*X_data.shape[0]):] - pickle.dump([X_train, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.threshold.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale, cluster_name, round(p,1)), "wb")) + pickle.dump([X_data_noise_2, X_test, y_train, y_test, selector, cluster_inds], open("{}/nt3.data.threshold.scale_{}_{}.noise_{}.pkl".format(args.o, args.scale, cluster_name, round(p,1)), "wb")) if __name__ == "__main__": main() diff --git a/Pilot1/NT3/nt3_cf/test_cf_accuracy.py 
b/Pilot1/NT3/nt3_cf/test_cf_accuracy.py index 4816197b..af9b5f19 100644 --- a/Pilot1/NT3/nt3_cf/test_cf_accuracy.py +++ b/Pilot1/NT3/nt3_cf/test_cf_accuracy.py @@ -10,6 +10,7 @@ def get_args(): parser.add_argument("-m", type=str, help="model file") parser.add_argument("-prefix", type=str, help="noise file prefix") parser.add_argument("-prefix_rand", type=str, help="random noise file prefix") + parser.add_argument("-prefix_rand_cf", type=str, help="random noise along cf indices") parser.add_argument("-folder", type=str, help="folder path to noise files") parser.add_argument("-o", type=str, help="name of saved png") parser.add_argument("-n", type=str, help="name of cluster") @@ -33,8 +34,8 @@ def main(): print(i, acc, cluster_acc) results.append([acc[1], cluster_acc[1]]) results = np.array(results) - plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy", marker='o') - plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy", marker='o') +# plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="full dataset accuracy with cf pertubation", marker='o') + plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with cf perturbation", marker='o') results = [] for i in np.arange(0.1,1.0, 0.1): @@ -50,14 +51,31 @@ def main(): print(i, acc, cluster_acc) results.append([acc[1], cluster_acc[1]]) results = np.array(results) - plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="accuracy with Gaussian noise", marker='o') - plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with Gaussian noise", marker='o') +# plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="full dataset accuracy with Gaussian noise (rand indices)", marker='o') + plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with Gaussian noise (random features)", marker='o') + + results = [] + for i in np.arange(0.1,1.0, 0.1): + cf_dataset = pickle.load(open("{}_{}.pkl".format(args.prefix_rand_cf, round(i,2)), "rb")) + X_cf_dataset = 
np.concatenate([cf_dataset[0], cf_dataset[1]]) + y_cf_dataset = np.concatenate([cf_dataset[2], cf_dataset[3]]) + #X_cf_dataset = cf_dataset[0] + #y_cf_dataset = cf_dataset[1] + cluster_inds = cf_dataset[-1] + print(model_nt3.metrics_names) + acc = model_nt3.evaluate(X_cf_dataset, y_cf_dataset) + cluster_acc = model_nt3.evaluate(X_cf_dataset[cluster_inds], y_cf_dataset[cluster_inds]) + print(i, acc, cluster_acc) + results.append([acc[1], cluster_acc[1]]) + results = np.array(results) +# plt.plot(np.arange(0.1,1.0,0.1), results[:,0], label="full dataset accuracy with Gaussian noise (cf indices)", marker='o') + plt.plot(np.arange(0.1,1.0, 0.1), results[:,1], label="cluster accuracy with Gaussian noise (cf features)", marker='o') plt.xlabel("Noise fraction in cluster") plt.ylabel("Accuracy") plt.legend() - plt.title("Model accuracy with counterfactual noise injection for {}".format(args.n)) + plt.title("Model accuracy with counterfactual noise injection") plt.savefig(args.o) if __name__ == "__main__": diff --git a/Pilot1/NT3/nt3_noise_model.txt b/Pilot1/NT3/nt3_noise_model.txt index e4d929fa..261ec19c 100644 --- a/Pilot1/NT3/nt3_noise_model.txt +++ b/Pilot1/NT3/nt3_noise_model.txt @@ -34,4 +34,3 @@ init_abs_epoch = 2 task_list = 0 task_names = ['activation_5'] noise_save_cf = True -noise_cf = '/vol/ml/shahashka/nt3_cf/cf_data/data_with_cf_noise/nt3.data.scale_1.0_cluster_0_1.noise_0.1.pkl' diff --git a/Pilot1/NT3/run_abstention_sweep.sh b/Pilot1/NT3/run_abstention_sweep.sh index 4a217cc3..7978ecc1 100755 --- a/Pilot1/NT3/run_abstention_sweep.sh +++ b/Pilot1/NT3/run_abstention_sweep.sh @@ -1,8 +1,15 @@ #!/bin/bash -vals=0.1 -for filename in /vol/ml/shahashka/temp/Benchmarks/Pilot1/NT3/nt3_cf/noise_all_clusters_t=0.1/nt3.data.threshold.*; do - echo $filename - python nt3_abstention_keras2.py --noise_cf $filename --output_dir cf_sweep_1030 --run_id $(basename $filename) --epochs 100 - #cp cf_sweep_0902/EXP000/RUN000/training.log ${filename}_training_0902.log +#vals=0.1 
+#for filename in /vol/ml/shahashka/temp/Benchmarks/Pilot1/NT3/nt3_cf/noise_both_clusters/nt3.data.threshold.*; do +# echo $filename +# python nt3_abstention_keras2.py --noise_cf $filename --output_dir cf_sweep_1104 --run_id $(basename $filename) --epochs 100 +# #cp cf_sweep_0902/EXP000/RUN000/training.log ${filename}_training_0902.log +#done + +for i in $(seq 0 0.1 1); do + echo $i + for j in $(seq 1 1 5); do + python nt3_baseline_keras2.py --label_noise $i --output_dir baseline_label_noise_$i --run_id RUN$j + done done diff --git a/common/file_utils.py b/common/file_utils.py index a1cfdb0b..b11bf90d 100644 --- a/common/file_utils.py +++ b/common/file_utils.py @@ -204,7 +204,6 @@ def directory_from_parameters(params, commonroot='Output'): String to specify the common folder to store results. """ - if commonroot in set(['.', './']): # Same directory --> convert to absolute path outdir = os.path.abspath('.') else: # Create path specified diff --git a/common/parsing_utils.py b/common/parsing_utils.py index 27006000..e495dda3 100644 --- a/common/parsing_utils.py +++ b/common/parsing_utils.py @@ -386,7 +386,7 @@ class ArgumentStruct: or object entries) can be used. """ def __init__(self, **entries): - self.__dict__.update(entries) + self.__dict__.update(entries) class ListOfListsAction(argparse.Action): @@ -573,7 +573,6 @@ def args_overwrite_config(args, config): for key in args_dict.keys(): # try casting here params[key] = args_dict[key] - if 'data_type' not in params: params['data_type'] = DEFAULT_DATATYPE else: