Skip to content

Commit d7f26de

Browse files
author
Julian Kates-Harbeck
committed
signal importance through occlusion
1 parent 5890a1a commit d7f26de

File tree

1 file changed

+162
-0
lines changed

1 file changed

+162
-0
lines changed

examples/simple_augmentation.py

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
'''
2+
#########################################################
3+
This file measures signal importance by occlusion: it runs a trained
4+
disruption prediction model on plasma discharge test shots with signals hidden.
5+
6+
Must run guarantee_preprocessed.py in order for this to work.
7+
8+
Dependencies:
9+
conf.py: configuration of model,training,paths, and data
10+
model_builder.py: logic to construct the ML architecture
11+
data_processing.py: classes to handle data processing
12+
13+
Author: Julian Kates-Harbeck, [email protected]
14+
15+
This work was supported by the DOE CSGF program.
16+
#########################################################
17+
'''
18+
19+
from __future__ import print_function
20+
import os
21+
import sys
22+
import time
23+
import datetime
24+
import random
25+
import numpy as np
26+
import copy
27+
from functools import partial
28+
29+
os.environ["PYTHONHASHSEED"] = "0"
30+
31+
import matplotlib
32+
matplotlib.use('Agg')
33+
34+
from pprint import pprint
35+
sys.setrecursionlimit(10000)
36+
37+
from plasma.conf import conf
38+
from plasma.models.loader import Loader
39+
from plasma.primitives.shots import ShotList
40+
from plasma.preprocessor.normalize import Normalizer
41+
from plasma.preprocessor.augment import ByShotAugmentator
42+
from plasma.preprocessor.preprocess import guarantee_preprocessed
43+
44+
# The MPI runner only supports the deep model, so stop right away otherwise.
if conf['model']['shallow']:
    print("Shallow learning using MPI is not supported yet. set conf['model']['shallow'] to false.")
    exit(1)

# Bind whichever normalizer class the configuration names to `Normalizer`,
# so the rest of the script does not care which one was chosen.
normalizer_kind = conf['data']['normalizer']
if normalizer_kind == 'minmax':
    from plasma.preprocessor.normalize import MinMaxNormalizer as Normalizer
elif normalizer_kind == 'meanvar':
    from plasma.preprocessor.normalize import MeanVarNormalizer as Normalizer
elif normalizer_kind == 'var':
    # Author's note: performs much better than MinMaxNormalizer.
    from plasma.preprocessor.normalize import VarNormalizer as Normalizer
elif normalizer_kind == 'averagevar':
    # Author's note: performs much better than MinMaxNormalizer.
    from plasma.preprocessor.normalize import AveragingVarNormalizer as Normalizer
else:
    print('unkown normalizer. exiting')
    exit(1)
58+
59+
from mpi4py import MPI
60+
# Every MPI rank is one worker; ranks are assigned to GPUs by rank modulo
# the configured GPU count.
comm = MPI.COMM_WORLD
task_index, num_workers = comm.Get_rank(), comm.Get_size()
NUM_GPUS = conf['num_gpus']
MY_GPU = task_index % NUM_GPUS
65+
66+
from plasma.models.mpi_runner import *
67+
68+
# Seed both random number generators with the MPI rank so that each worker
# gets its own reproducible stream.
random.seed(task_index)
np.random.seed(task_index)

# Only the first rank dumps the configuration, to avoid duplicate output.
if task_index == 0:
    pprint(conf)

# An optional first command-line argument is taken as the model path to
# predict with.
only_predict = len(sys.argv) > 1
custom_path = sys.argv[1] if only_predict else None
if only_predict:
    print("predicting using path {}".format(custom_path))

# This script never trains: a model path is mandatory.
assert only_predict
80+
#####################################################
81+
####################Normalization####################
82+
#####################################################
83+
# Let rank 0 go first so that preprocessing has been run and saved as a file
# before anyone else asks for it. Its return value is not kept here because
# every rank (rank 0 included) fetches the shot lists again after the barrier.
if task_index == 0:
    guarantee_preprocessed(conf)
comm.Barrier()
shot_list_train, shot_list_validate, shot_list_test = guarantee_preprocessed(conf)
87+
88+
89+
def chunks(l, n):
90+
"""Yield successive n-sized chunks from l."""
91+
return[ l[i:i + n] for i in range(0, len(l), n)]
92+
93+
def hide_signal_data(shot, t=0, sigs_to_hide=None):
    """Freeze selected signals of a shot at their value at time index t.

    For every signal in ``shot.signals`` that is selected, all rows from
    index ``t`` onwards of ``shot.signals_dict[signal]`` are overwritten in
    place with row ``t``.

    shot: object exposing ``signals`` (iterable) and ``signals_dict``
        (mapping signal -> 2D array indexed as [time, channel]).
    t: first row index to overwrite; its row supplies the fill value.
    sigs_to_hide: container of signals to freeze, or None to freeze all.
    """
    hide_everything = sigs_to_hide is None
    for signal in shot.signals:
        if not (hide_everything or signal in sigs_to_hide):
            continue
        data = shot.signals_dict[signal]
        data[t:, :] = data[t, :]
97+
98+
def create_shot_list_tmp(original_shot, time_points, sigs=None):
    """Build a ShotList of copies of one shot, each occluded from a different time.

    ``time_points`` evenly spaced integer indices are chosen between 0 and
    ``len(original_shot.ttd) - 1`` (inclusive). For each index ``t`` a shallow
    copy of the shot is made and given an ``augmentation_fn`` that calls
    ``hide_signal_data`` with that ``t`` and ``sigs``.

    original_shot: shot to copy; its ``augmentation_fn`` must be None.
    time_points: number of occlusion start indices to generate.
    sigs: container of signals to hide, or None to hide all signals.

    Returns a tuple ``(shot_list_tmp, t_range)`` with the new ShotList and
    the integer array of occlusion start indices, in matching order.
    """
    shot_list_tmp = ShotList()
    T = len(original_shot.ttd)
    # Use the builtin int: the np.int alias was removed in NumPy 1.24.
    t_range = np.linspace(0, T - 1, time_points, dtype=int)
    for t in t_range:
        # Shallow copy: the copies share the original's underlying data and
        # differ only in the augmentation function attached below.
        # NOTE(review): presumably the augmentation is applied when the
        # shot's data is loaded - confirm against the loader.
        new_shot = copy.copy(original_shot)
        assert new_shot.augmentation_fn is None
        new_shot.augmentation_fn = partial(hide_signal_data, t=t, sigs_to_hide=sigs)
        shot_list_tmp.append(new_shot)
    return shot_list_tmp, t_range
109+
110+
def get_importance_measure(original_shot, loader, custom_path, metric, time_points=10, sig=None):
    """Score how much hiding signals from various times changes the prediction.

    Creates ``time_points`` occluded copies of ``original_shot`` with
    ``create_shot_list_tmp``, runs ``mpi_make_predictions`` on them, and
    scores each prediction against the last one with ``metric``.

    original_shot: the shot to analyse.
    loader: data loader handed to ``mpi_make_predictions``.
    custom_path: model path handed to ``mpi_make_predictions``.
    metric: callable ``metric(y_prime_i, y_prime_reference)``.
    time_points: number of occlusion start times to evaluate.
    sig: forwarded unchanged as the set of signals to hide. It is tested
        with ``in`` downstream, so it should be a container of signals, or
        None to hide every signal.

    Returns ``(t_range, importance, y_prime[-1])``: the occlusion start
    indices, the importance array from
    ``get_importance_measure_given_y_prime``, and the prediction for the
    last (latest-occluded) copy.
    """
    # The original passed the undefined name `sigs` here, which raised
    # NameError on every call; the caller's `sig` argument is what is meant.
    shot_list_tmp, t_range = create_shot_list_tmp(original_shot, time_points, sig)
    y_prime, _, _ = mpi_make_predictions(conf, shot_list_tmp, loader, custom_path)
    shot_list_tmp.make_light()
    importance = get_importance_measure_given_y_prime(y_prime, metric)
    return t_range, importance, y_prime[-1]
115+
116+
def difference_metric(y_prime, y_prime_orig):
    """Return the normalised drop of y_prime at the peak of y_prime_orig.

    The peak position is the argmax of ``y_prime_orig``. The result is
    (peak value - y_prime at that position) divided by the range
    (max - min) of ``y_prime_orig``.
    """
    peak_index = np.argmax(y_prime_orig)
    peak_value = np.max(y_prime_orig)
    value_range = peak_value - np.min(y_prime_orig)
    drop_at_peak = peak_value - y_prime[peak_index]
    return drop_at_peak / value_range
119+
120+
def get_importance_measure_given_y_prime(y_prime, metric):
    """Return 1 - metric(prediction, last prediction) for every prediction.

    y_prime: sequence of predictions; the last entry is the reference.
    metric: callable ``metric(prediction, reference)``.

    Returns a numpy array with one score per entry of ``y_prime``.
    """
    scores = [metric(prediction, y_prime[-1]) for prediction in y_prime]
    return 1.0 - np.array(scores)
123+
124+
125+
# Train the configured normalizer, then wrap it so that each shot's
# augmentation function can be taken into account by the loader.
print("normalization", end='')
base_normalizer = Normalizer(conf)
base_normalizer.train()
normalizer = ByShotAugmentator(base_normalizer)
loader = Loader(conf, normalizer)
print("...done")

# Training is deliberately skipped: this script only evaluates an existing
# model, so the loader is switched to inference mode.
loader.set_inference_mode(True)

# Work on a copy of the configured signal list so that conf is not modified;
# the trailing None is a placeholder entry after the real signals.
use_signals = copy.copy(conf['paths']['use_signals'])
use_signals.append(None)
139+
140+
141+
142+
# Baseline: clear any augmentation so the test set is evaluated with every
# signal intact.
for shot in shot_list_test:
    shot.augmentation_fn = None

print("All signals:")
y_prime, y_gold, disruptive, roc, loss = mpi_make_predictions_and_evaluate(
    conf, shot_list_test, loader, custom_path)
print(roc)
print(loss)
149+
150+
# Occlusion study. Each real signal (everything before the trailing None) is
# hidden on its own, followed by one run hiding the last two real signals
# together.
occlusion_groups = [[s] for s in use_signals[:-1]]
occlusion_groups.append(use_signals[-3:-1])

for sigs_to_hide in occlusion_groups:
    # t=0 freezes the selected signals at their first value for the whole shot.
    occlude = partial(hide_signal_data, t=0, sigs_to_hide=sigs_to_hide)
    for shot in shot_list_test:
        shot.augmentation_fn = occlude
    print("Hiding: {}".format(sigs_to_hide))
    y_prime, y_gold, disruptive, roc, loss = mpi_make_predictions_and_evaluate(
        conf, shot_list_test, loader, custom_path)
    print(roc)
    print(loss)

# Only the first rank reports completion.
if task_index == 0:
    print('finished.')

0 commit comments

Comments
 (0)