Skip to content

Commit aba3b50

Browse files
Add tooling for feature importance (#109)
* Laid groundwork for feature importance stuff * Got stuff initially working * Cleaned up comments/functionality * Worked on some stuff * Probably finished up feature importance tool for grabbing shap values * Fixed CI (probably) * Added shap to development dependencies * Moved requirements to a more appropriate place * Got plots initially working * Minor changes, fix CI * Added functionality to collapse along LRs in the regalloc case * Added typing info to main feature_importance script * Added typing annotations * Added unit tests for feature_importance script * Fix pytest absl flag errors * Added in inline documentation * Added documentation on graphing shap data in an IPython notebook * Added notes on model path * Added missing import * Removed inline inner functions and added documentation to functions I missed * Refactored utilities to feature_importance_utils.py Co-authored-by: Yundi Qian <[email protected]>
1 parent fc26cfc commit aba3b50

9 files changed

+734
-3
lines changed
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""A tool for analyzing which features a model uses to make a decision.
16+
17+
This script allows for processing a set of examples generated from a trace
18+
created through generate_default_trace into a set of shap values which
19+
represent how much that specific feature contributes to the final output of
20+
the model. These values can then be imported into an IPython notebook and
21+
graphed with the help of the feature_importance_graphs.py module in the same
22+
folder.
23+
24+
Usage:
25+
PYTHONPATH=$PYTHONPATH:. python3 compiler_opt/tools/feature_importance.py \
26+
--gin_files=compiler_opt/rl/regalloc/gin_configs/common.gin \
27+
--gin_bindings=config_registry.get_configuration.implementation=\
28+
@configs.RegallocEvictionConfig \
29+
--data_path=/default_trace \
30+
--model_path=/warmstart/saved_policy \
31+
--num_examples=5 \
32+
--output_file=./explanation_data.json
33+
34+
The type of trace that is performed (i.e. whether it is just tracing the
default heuristic or tracing an ML model) doesn't matter, as the only data
that matters is the input features. The num_examples flag sets the number of
37+
examples that get processed into shap values. Increasing this value will
38+
potentially allow you to reach better conclusions depending upon how you're
39+
viewing the data, but increasing it will also increase the runtime of this
40+
script quite significantly as the process is not multithreaded.
41+
"""
42+
43+
from absl import app
44+
from absl import flags
45+
from absl import logging
46+
import gin
47+
48+
from compiler_opt.rl import data_reader
49+
from compiler_opt.rl import constant
50+
from compiler_opt.rl import registry
51+
52+
from compiler_opt.tools import feature_importance_utils
53+
54+
import tensorflow as tf
55+
import shap
56+
import numpy
57+
import numpy.typing
58+
import json
59+
60+
# Command-line flags: where to read the trace, which model to explain, and
# where/how much SHAP output to produce. Gin flags configure the problem
# (e.g. regalloc) so the signature specs can be resolved.
_DATA_PATH = flags.DEFINE_multi_string(
    'data_path', [], 'Path to TFRecord file(s) containing trace data.')
_MODEL_PATH = flags.DEFINE_string('model_path', '',
                                  'Path to the model to explain')
_OUTPUT_FILE = flags.DEFINE_string(
    'output_file', '', 'The path to the output file containing the SHAP values')
_NUM_EXAMPLES = flags.DEFINE_integer(
    'num_examples', 1, 'The number of examples to process from the trace')
_GIN_FILES = flags.DEFINE_multi_string(
    'gin_files', [], 'List of paths to gin configuration files.')
_GIN_BINDINGS = flags.DEFINE_multi_string(
    'gin_bindings', [],
    'Gin bindings to override the values set in the config files.')
73+
74+
75+
def main(_):
  """Computes SHAP values for a model over trace examples and saves them.

  Reads `_NUM_EXAMPLES` examples from the TFRecord trace at `_DATA_PATH`,
  runs the saved policy at `_MODEL_PATH` through a shap KernelExplainer, and
  writes the resulting SHAP values (plus the raw data and feature names) as
  JSON to `_OUTPUT_FILE`.
  """
  gin.parse_config_files_and_bindings(
      _GIN_FILES.value, bindings=_GIN_BINDINGS.value, skip_unknown=False)
  logging.info(gin.config_str())

  problem_config = registry.get_configuration()
  time_step_spec, action_spec = problem_config.get_signature_spec()

  # Batch size and sequence length of 1 so each dataset element is a single
  # example we can flatten into one row of the SHAP input matrix.
  tfrecord_dataset_fn = data_reader.create_tfrecord_dataset_fn(
      agent_name=constant.AgentName.BEHAVIORAL_CLONE,
      time_step_spec=time_step_spec,
      action_spec=action_spec,
      batch_size=1,
      train_sequence_length=1)

  dataset_iter = iter(tfrecord_dataset_fn(_DATA_PATH.value).repeat())

  # The first example is consumed only to derive the input signature and the
  # flattened feature-vector size.
  raw_trajectory = next(dataset_iter)

  saved_policy = tf.saved_model.load(_MODEL_PATH.value)
  action_fn = saved_policy.signatures['action']

  observation = feature_importance_utils.process_raw_trajectory(raw_trajectory)
  input_sig = feature_importance_utils.get_input_signature(observation)

  run_model = feature_importance_utils.create_run_model_function(
      action_fn, input_sig)

  total_size = feature_importance_utils.get_signature_total_size(input_sig)
  # NOTE: the original code also flattened and expand_dims-ed this first
  # observation into a `flattened_input` that was never used; that dead code
  # has been removed.
  dataset = numpy.empty((_NUM_EXAMPLES.value, total_size))
  for i in range(_NUM_EXAMPLES.value):
    raw_trajectory = next(dataset_iter)
    observation = feature_importance_utils.process_raw_trajectory(
        raw_trajectory)
    dataset[i] = feature_importance_utils.flatten_input(
        observation, total_size)

  # All-zeros background sample: SHAP values measure each feature's
  # contribution relative to this baseline.
  explainer = shap.KernelExplainer(run_model, numpy.zeros((1, total_size)))
  shap_values = explainer.shap_values(dataset, nsamples=1000)
  processed_shap_values = feature_importance_utils.collapse_values(
      input_sig, shap_values, _NUM_EXAMPLES.value)

  # if we have more than one value per feature, just set the dataset to zeros
  # as summing across a dimension produces data that doesn't really mean
  # anything
  if feature_importance_utils.get_max_part_size(input_sig) > 1:
    dataset = numpy.zeros(processed_shap_values.shape)

  feature_names = list(input_sig.keys())

  output_file_data = {
      'expected_values': explainer.expected_value,
      'shap_values': processed_shap_values.tolist(),
      'data': dataset.tolist(),
      'feature_names': feature_names
  }

  with open(_OUTPUT_FILE.value, 'w', encoding='utf-8') as output_file:
    json.dump(output_file_data, output_file)


if __name__ == '__main__':
  app.run(main)
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# coding=utf-8
2+
# Copyright 2020 Google LLC
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
"""A module that allows for easily graphing feature importance data in
16+
notebooks"""
17+
18+
import numpy
19+
import numpy.typing
20+
import shap
21+
import json
22+
23+
from typing import Dict, List, Union, Optional
24+
25+
# Shape of the dictionary produced by load_shap_values: numpy arrays for the
# numeric entries plus a plain list of feature-name strings.
DataType = Dict[str, Union[numpy.typing.ArrayLike, List[str]]]


def load_shap_values(file_name: str) -> DataType:
  """Loads a set of shap values created with the feature_importance.py script
  into a dictionary that can then be used for creating graphs.

  Args:
    file_name: The name of the file in which the shap values are stored. What
      the --output_file flag was set to in the feature importance script.

  Returns:
    A dictionary with 'expected_values', 'shap_values' and 'data' as numpy
    arrays and 'feature_names' as a list of strings.
  """
  with open(file_name, encoding='utf-8') as file_to_load:
    data = json.load(file_to_load)
  # A scalar expected value (single model output) gets wrapped in a list so
  # indexing works uniformly downstream. The original check used
  # `data['expected_values'] is not list`, which compares the value against
  # the `list` type object itself and is therefore always true — it wrapped
  # list inputs a second time. isinstance is the correct check.
  if not isinstance(data['expected_values'], list):
    data['expected_values'] = [data['expected_values']]
  return {
      'expected_values': numpy.asarray(data['expected_values']),
      'shap_values': numpy.asarray(data['shap_values']),
      'data': numpy.asarray(data['data']),
      'feature_names': data['feature_names']
  }
46+
47+
48+
def init_shap_for_notebook():
  """Loads the javascript shap needs to render interactive plots.

  Call this once in a notebook before drawing force plots.
  """
  shap.initjs()
51+
52+
53+
def graph_individual_example(data: DataType, index: Optional[int]):
  """Renders a force plot for a single example from the dataset.

  Args:
    data: A dictionary of shap values, raw feature data, and feature names as
      produced by load_shap_values.
    index: The index of the example that you wish to plot.
  """
  shap_row = data['shap_values'][index, :]
  feature_row = data['data'][index, :]
  return shap.force_plot(
      data['expected_values'],
      shap_row,
      feature_row,
      feature_names=data['feature_names'])
66+
67+
68+
def graph_summary_plot(data: DataType):
  """Draws a summary plot over every example in the dataset.

  Args:
    data: A dictionary of shap values, raw feature data, and feature names as
      produced by load_shap_values.
  """
  shap_values = data['shap_values']
  feature_data = data['data']
  return shap.summary_plot(
      shap_values, feature_data, feature_names=data['feature_names'])

0 commit comments

Comments
 (0)