Skip to content

Commit 834ee40

Browse files
committed
Filled out preserve method, including error handling. Refactor on spec test.
1 parent d16f7af commit 834ee40

File tree

5 files changed

+107
-52
lines changed

5 files changed

+107
-52
lines changed

src/python/interpret/test/test_develop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,3 @@ def test_print_debug_info():
1515
# Very light check, just testing if the function runs.
1616
print_debug_info()
1717
assert 1 == 1
18-

src/python/interpret/test/test_explainers.py

Lines changed: 4 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,15 @@
11
# Copyright (c) 2019 Microsoft Corporation
22
# Distributed under the MIT software license
33

4-
from ..data import ClassHistogram
5-
from ..perf import ROC, RegressionPerf
64

7-
from ..blackbox import LimeTabular
8-
from ..blackbox import ShapKernel
9-
from ..blackbox import MorrisSensitivity
10-
from ..blackbox import PartialDependence
11-
12-
from ..glassbox import LogisticRegression, LinearRegression
13-
from ..glassbox import ClassificationTree, RegressionTree
14-
from ..glassbox import DecisionListClassifier
15-
from ..glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
16-
17-
from .utils import synthetic_classification
5+
from .utils import synthetic_classification, get_all_explainers
186
from .utils import assert_valid_explanation, assert_valid_model_explainer
197

8+
from ..glassbox import LogisticRegression
209

21-
def test_spec_synthetic():
22-
data_explainer_classes = [ClassHistogram]
23-
perf_explainer_classes = [ROC, RegressionPerf]
24-
model_explainer_classes = [
25-
ClassificationTree,
26-
DecisionListClassifier,
27-
LogisticRegression,
28-
ExplainableBoostingClassifier,
29-
RegressionTree,
30-
LinearRegression,
31-
ExplainableBoostingRegressor,
32-
]
33-
blackbox_explainer_classes = [
34-
LimeTabular,
35-
ShapKernel,
36-
MorrisSensitivity,
37-
PartialDependence,
38-
]
39-
all_explainers = []
40-
all_explainers.extend(model_explainer_classes)
41-
all_explainers.extend(blackbox_explainer_classes)
42-
all_explainers.extend(data_explainer_classes)
43-
all_explainers.extend(perf_explainer_classes)
4410

11+
def test_spec_synthetic():
12+
all_explainers = get_all_explainers()
4513
data = synthetic_classification()
4614
blackbox = LogisticRegression()
4715
blackbox.fit(data["train"]["X"], data["train"]["y"])

src/python/interpret/test/test_interactive.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright (c) 2019 Microsoft Corporation
22
# Distributed under the MIT software license
3+
# TODO: Testing for show/snap functions.
34

45
from ..visual.interactive import set_show_addr, get_show_addr, shutdown_show_server
56

src/python/interpret/test/utils.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,19 @@
11
# Copyright (c) 2019 Microsoft Corporation
22
# Distributed under the MIT software license
33

4+
from ..data import ClassHistogram
5+
from ..perf import ROC, RegressionPerf
6+
7+
from ..blackbox import LimeTabular
8+
from ..blackbox import ShapKernel
9+
from ..blackbox import MorrisSensitivity
10+
from ..blackbox import PartialDependence
11+
12+
from ..glassbox import LogisticRegression, LinearRegression
13+
from ..glassbox import ClassificationTree, RegressionTree
14+
from ..glassbox import DecisionListClassifier
15+
from ..glassbox import ExplainableBoostingClassifier, ExplainableBoostingRegressor
16+
417
import pandas as pd
518
import numpy as np
619
from sklearn.model_selection import train_test_split
@@ -10,6 +23,33 @@
1023
from sklearn.base import is_classifier
1124

1225

26+
def get_all_explainers():
27+
data_explainer_classes = [ClassHistogram]
28+
perf_explainer_classes = [ROC, RegressionPerf]
29+
model_explainer_classes = [
30+
ClassificationTree,
31+
DecisionListClassifier,
32+
LogisticRegression,
33+
ExplainableBoostingClassifier,
34+
RegressionTree,
35+
LinearRegression,
36+
ExplainableBoostingRegressor,
37+
]
38+
blackbox_explainer_classes = [
39+
LimeTabular,
40+
ShapKernel,
41+
MorrisSensitivity,
42+
PartialDependence,
43+
]
44+
all_explainers = []
45+
all_explainers.extend(model_explainer_classes)
46+
all_explainers.extend(blackbox_explainer_classes)
47+
all_explainers.extend(data_explainer_classes)
48+
all_explainers.extend(perf_explainer_classes)
49+
50+
return all_explainers
51+
52+
1353
def synthetic_regression():
1454
dataset = _synthetic("regression")
1555
return dataset

src/python/interpret/visual/interactive.py

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def preserve(explanation, selector_key=None, file_name=None, **kwargs):
133133
134134
If file_name is not None the following occurs:
135135
- For Plotly figures, saves to HTML using `plot`.
136-
- For dataframes, saves to CSV using `to_csv`.
136+
- For dataframes, saves to HTML using `to_html`.
137137
- For strings (html), saves to HTML.
138138
- For Dash components, fails with exception. This is currently not supported.
139139
@@ -147,17 +147,58 @@ def preserve(explanation, selector_key=None, file_name=None, **kwargs):
147147
None.
148148
"""
149149

150+
try:
151+
# Get explanation key
152+
if selector_key is None:
153+
key = None
154+
else:
155+
series = explanation.selector[explanation.selector.columns[0]]
156+
key = series[series == selector_key].index[0]
157+
158+
# Get visual object
159+
visual = explanation.visualize(key=key)
160+
161+
# Output to front-end/file
162+
_preserve_output(
163+
explanation.name,
164+
visual,
165+
selector_key=selector_key,
166+
file_name=file_name,
167+
**kwargs
168+
)
169+
return None
170+
except Exception as e:
171+
log.error(e, exc_info=True)
172+
raise e
173+
174+
175+
def _preserve_output(
176+
explanation_name, visual, selector_key=None, file_name=None, **kwargs
177+
):
150178
from plotly.offline import iplot, plot, init_notebook_mode
151-
from IPython.display import display, HTML
179+
from IPython.display import display, display_html
180+
from base64 import b64encode
181+
152182
init_notebook_mode(connected=True)
153183

154-
if selector_key is None:
155-
key = None
156-
else:
157-
series = explanation.selector[explanation.selector.columns[0]]
158-
key = series[series == selector_key].index[0]
184+
def render_html(html_string):
185+
base64_html = b64encode(html_string.encode("utf-8")).decode("ascii")
186+
final_html = """<iframe src="data:text/html;base64,{data}" width="100%" height=400 frameBorder="0"></iframe>""".format(
187+
data=base64_html
188+
)
189+
display_html(final_html, raw=True)
190+
191+
if visual is None:
192+
msg = "No visualization for explanation [{0}] with selector_key [{1}]".format(
193+
explanation_name, selector_key
194+
)
195+
log.error(msg)
196+
if file_name is None:
197+
render_html(msg)
198+
else:
199+
pass
200+
return False
159201

160-
visual = explanation.visualize(key=key)
161202
if isinstance(visual, go.Figure):
162203
if file_name is None:
163204
iplot(visual, **kwargs)
@@ -167,18 +208,24 @@ def preserve(explanation, selector_key=None, file_name=None, **kwargs):
167208
if file_name is None:
168209
display(visual, **kwargs)
169210
else:
170-
visual.to_csv(file_name, **kwargs)
211+
visual.to_html(file_name, **kwargs)
171212
elif isinstance(visual, str):
172213
if file_name is None:
173-
with(file_name, "w") as f:
174-
f.write(visual)
214+
render_html(visual)
175215
else:
176-
HTML(visual, **kwargs)
216+
with open(file_name, "w") as f:
217+
f.write(visual)
177218
elif isinstance(visual, dash_base.Component):
178219
msg = "Preserving dash components is currently not supported."
179-
raise Exception(msg)
220+
if file_name is None:
221+
render_html(msg)
222+
log.error(msg)
223+
return False
180224
else:
181225
msg = "Visualization cannot be preserved for type: {0}.".format(type(visual))
182-
raise Exception(msg)
226+
if file_name is None:
227+
render_html(msg)
228+
log.error(msg)
229+
return False
183230

184-
return None
231+
return True

0 commit comments

Comments
 (0)