|
| 1 | +''' |
| 2 | +######################################################### |
| 3 | +This file evaluates a trained deep learning model that predicts disruptions on |
| 4 | +time series data from plasma discharges, with and without selected signals hidden. |
| 5 | +
|
| 6 | +Must run guarantee_preprocessed.py in order for this to work. |
| 7 | +
|
| 8 | +Dependencies: |
| 9 | +conf.py: configuration of model, training, paths, and data |
| 10 | +model_builder.py: logic to construct the ML architecture |
| 11 | +data_processing.py: classes to handle data processing |
| 12 | +
|
| 13 | +Author: Julian Kates-Harbeck, [email protected] |
| 14 | +
|
| 15 | +This work was supported by the DOE CSGF program. |
| 16 | +######################################################### |
| 17 | +''' |
| 18 | + |
| 19 | +from __future__ import print_function |
| 20 | +import os |
| 21 | +import sys |
| 22 | +import time |
| 23 | +import datetime |
| 24 | +import random |
| 25 | +import numpy as np |
| 26 | +import copy |
| 27 | +from functools import partial |
| 28 | + |
| 29 | +os.environ["PYTHONHASHSEED"] = "0" |
| 30 | + |
| 31 | +import matplotlib |
| 32 | +matplotlib.use('Agg') |
| 33 | + |
| 34 | +from pprint import pprint |
| 35 | +sys.setrecursionlimit(10000) |
| 36 | + |
| 37 | +from plasma.conf import conf |
| 38 | +from plasma.models.loader import Loader |
| 39 | +from plasma.primitives.shots import ShotList |
| 40 | +from plasma.preprocessor.normalize import Normalizer |
| 41 | +from plasma.preprocessor.augment import ByShotAugmentator |
| 42 | +from plasma.preprocessor.preprocess import guarantee_preprocessed |
| 43 | + |
# Abort early when the configuration asks for a shallow model: this script
# only drives the MPI code path.
# sys.exit is used rather than the bare ``exit`` helper, because ``exit`` is
# injected by the ``site`` module for interactive use and is not available
# under ``python -S`` or in frozen/embedded interpreters.
if conf['model']['shallow']:
    print("Shallow learning using MPI is not supported yet. set conf['model']['shallow'] to false.")
    sys.exit(1)

# Bind the normalizer class named in the configuration to the common name
# ``Normalizer``; an unrecognised name terminates the run with status 1.
if conf['data']['normalizer'] == 'minmax':
    from plasma.preprocessor.normalize import MinMaxNormalizer as Normalizer
elif conf['data']['normalizer'] == 'meanvar':
    from plasma.preprocessor.normalize import MeanVarNormalizer as Normalizer
elif conf['data']['normalizer'] == 'var':
    # performs much better than MinMaxNormalizer
    from plasma.preprocessor.normalize import VarNormalizer as Normalizer
elif conf['data']['normalizer'] == 'averagevar':
    # performs much better than MinMaxNormalizer
    from plasma.preprocessor.normalize import AveragingVarNormalizer as Normalizer
else:
    print('unkown normalizer. exiting')
    sys.exit(1)
| 58 | + |
# MPI bookkeeping: this rank's index, the total number of ranks, and the GPU
# slot derived from the rank modulo the configured GPU count.
from mpi4py import MPI
comm = MPI.COMM_WORLD
task_index = comm.Get_rank()
num_workers = comm.Get_size()
NUM_GPUS = conf['num_gpus']
MY_GPU = task_index % NUM_GPUS

from plasma.models.mpi_runner import *

# Seed both random number generators with the rank index.
random.seed(task_index)
np.random.seed(task_index)

# Only rank 0 dumps the configuration.
if task_index == 0:
    pprint(conf)

# The first command-line argument, when present, is used as the custom path
# handed to the prediction routines below.
only_predict = len(sys.argv) > 1
custom_path = sys.argv[1] if only_predict else None
print("predicting using path {}".format(custom_path))

# This script has no training step, so a path argument is mandatory.
assert only_predict

#####################################################
####################Normalization####################
#####################################################
# Rank 0 calls guarantee_preprocessed first while the other ranks wait at the
# barrier; afterwards every rank calls it to obtain its own shot lists.
if task_index == 0:
    shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(conf)
comm.Barrier()
shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(conf)
| 87 | + |
| 88 | + |
def chunks(l, n):
    """Return a list of successive n-sized chunks from l.

    Args:
        l: a sliceable sequence supporting ``len()``.
        n: chunk length; must be a positive integer.

    Returns:
        A list of slices of ``l``. Every slice has length ``n`` except
        possibly the last, which holds the remainder. An empty ``l`` gives
        an empty list.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    # A negative step would make range() empty and silently return [], and a
    # zero step raises an unrelated-looking error, so reject both up front.
    if n < 1:
        raise ValueError("chunk size n must be a positive integer, got {!r}".format(n))
    return [l[start:start + n] for start in range(0, len(l), n)]
| 92 | + |
def hide_signal_data(shot, t=0, sigs_to_hide=None):
    """Freeze selected signals of ``shot`` at their value at row ``t``.

    For every signal in ``shot.signals`` that is selected, all rows from
    ``t`` onwards of ``shot.signals_dict[signal]`` are overwritten in place
    with row ``t``.

    Args:
        shot: object exposing ``signals`` (iterable of keys) and
            ``signals_dict`` (key -> 2-D array).
        t: first row to overwrite; also the row whose values are repeated.
        sigs_to_hide: container of signals to freeze, or None to freeze
            every signal.
    """
    hide_everything = sigs_to_hide is None
    for signal in shot.signals:
        if not (hide_everything or signal in sigs_to_hide):
            continue
        values = shot.signals_dict[signal]
        values[t:, :] = values[t, :]
| 97 | + |
def create_shot_list_tmp(original_shot, time_points, sigs=None):
    """Build a ShotList of copies of one shot, each hiding data from a later time.

    ``time_points`` indices are spread evenly over ``[0, T - 1]`` where
    ``T = len(original_shot.ttd)``. For each index ``t`` a shallow copy of
    ``original_shot`` is appended whose ``augmentation_fn`` is
    ``hide_signal_data`` bound to that ``t`` and to ``sigs``.

    Args:
        original_shot: the shot to replicate; its ``augmentation_fn`` must
            be None.
        time_points: number of time indices (and therefore copies) to create.
        sigs: container of signals to hide, or None to hide all signals.

    Returns:
        A tuple ``(shot_list_tmp, t_range)``: the ShotList of copies and the
        integer array of time indices, in matching order.
    """
    shot_list_tmp = ShotList()
    T = len(original_shot.ttd)
    # The builtin int replaces the np.int alias, which was removed in NumPy 1.24.
    t_range = np.linspace(0, T - 1, time_points, dtype=int)
    for t in t_range:
        # Shallow copy: the copy shares its attribute objects with original_shot.
        new_shot = copy.copy(original_shot)
        assert new_shot.augmentation_fn is None
        new_shot.augmentation_fn = partial(hide_signal_data, t=t, sigs_to_hide=sigs)
        shot_list_tmp.append(new_shot)
    return shot_list_tmp, t_range
| 109 | + |
def get_importance_measure(original_shot, loader, custom_path, metric, time_points=10, sig=None):
    """Score how predictions for one shot change as signal data is hidden.

    Copies of ``original_shot`` are created with data hidden from
    ``time_points`` evenly spaced time indices onwards, predictions are made
    for all copies, and each prediction is compared with the prediction of
    the last copy using ``metric``.

    Args:
        original_shot: the shot to analyse.
        loader: loader passed through to ``mpi_make_predictions``.
        custom_path: path passed through to ``mpi_make_predictions``.
        metric: callable ``metric(y_prime_i, y_prime_last)``.
        time_points: number of hiding start indices to evaluate.
        sig: signals to hide, forwarded as ``sigs`` to
            ``create_shot_list_tmp``. It is ultimately tested with ``in``,
            so it should be a container of signals, or None for all signals.

    Returns:
        A tuple ``(t_range, importance, y_prime_last)``: the time indices,
        the array ``1.0 - metric(...)`` per index, and the prediction of the
        last copy.
    """
    # Forward the ``sig`` parameter; the name ``sigs`` is not defined in this scope.
    shot_list_tmp, t_range = create_shot_list_tmp(original_shot, time_points, sig)
    y_prime, _, _ = mpi_make_predictions(conf, shot_list_tmp, loader, custom_path)
    shot_list_tmp.make_light()
    importance = get_importance_measure_given_y_prime(y_prime, metric)
    return t_range, importance, y_prime[-1]
| 115 | + |
def difference_metric(y_prime, y_prime_orig):
    """Return the normalised drop of ``y_prime`` at the peak of ``y_prime_orig``.

    The result is ``(max(orig) - y_prime[argmax(orig)]) / (max(orig) - min(orig))``.

    NOTE(review): a constant ``y_prime_orig`` makes the denominator zero;
    that case is not handled here.
    """
    peak_index = np.argmax(y_prime_orig)
    peak_value = np.max(y_prime_orig)
    lowest_value = np.min(y_prime_orig)
    drop_at_peak = peak_value - y_prime[peak_index]
    return drop_at_peak / (peak_value - lowest_value)
| 119 | + |
def get_importance_measure_given_y_prime(y_prime, metric):
    """Return ``1.0 - metric(y, y_prime[-1])`` for every entry ``y`` of ``y_prime``.

    Args:
        y_prime: sequence of predictions; the last entry is the reference.
        metric: callable taking ``(prediction, reference)``.

    Returns:
        A NumPy array with one value per entry of ``y_prime``.
    """
    # y_prime[-1] is looked up per element on purpose, so that an empty
    # y_prime yields an empty array instead of an IndexError.
    differences = [metric(prediction, y_prime[-1]) for prediction in y_prime]
    return 1.0 - np.array(differences)
| 123 | + |
| 124 | + |
# Train the configured normalizer, then wrap it in ByShotAugmentator before
# handing it to the loader (presumably the wrapper is what applies each
# shot's augmentation_fn; confirm in plasma.preprocessor.augment).
print("normalization",end='')
base_normalizer = Normalizer(conf)
base_normalizer.train()
normalizer = ByShotAugmentator(base_normalizer)
loader = Loader(conf, normalizer)
print("...done")

# Training (mpi_train) is intentionally not run here; the model found via
# custom_path is only evaluated.
loader.set_inference_mode(True)

# Work on a copy of the configured signal list, with a trailing None entry.
use_signals = copy.copy(conf['paths']['use_signals'])
use_signals.append(None)

# Baseline: no augmentation on any test shot.
for shot in shot_list_test:
    shot.augmentation_fn = None

print("All signals:")
y_prime, y_gold, disruptive, roc, loss = mpi_make_predictions_and_evaluate(
    conf, shot_list_test, loader, custom_path)
print(roc)
print(loss)

# Hide every configured signal on its own (the trailing None is skipped),
# then the last two configured signals together.
single_signal_groups = [[s] for s in use_signals[:-1]]
last_two_signals = use_signals[-3:-1]
for sigs_to_hide in single_signal_groups + [last_two_signals]:
    hide_fn = partial(hide_signal_data, t=0, sigs_to_hide=sigs_to_hide)
    for shot in shot_list_test:
        shot.augmentation_fn = hide_fn
    print("Hiding: {}".format(sigs_to_hide))
    y_prime, y_gold, disruptive, roc, loss = mpi_make_predictions_and_evaluate(
        conf, shot_list_test, loader, custom_path)
    print(roc)
    print(loss)

if task_index == 0:
    print('finished.')
0 commit comments