
Commit 074a910

Adding wrapper for supervised models to be used in shap
1 parent 7690b1f commit 074a910

7 files changed (+181 -9 lines)

bigml/constants.py

Lines changed: 2 additions & 0 deletions
@@ -332,9 +332,11 @@
 OUT_NEW_HEADERS = "output_headers"
 
 # input data allowed formats in batch predictions
+NUMPY = "numpy"
 DATAFRAME = "dataframe"
 INTERNAL = "list_of_dicts"
 
+CATEGORICAL = "categorical"
 
 IMAGE_EXTENSIONS = ['png', 'jpg', 'jpeg', 'gif', 'tiff', 'tif', 'bmp',
                     'webp', 'cur', 'ico', 'pcx', 'psd', 'psb']

bigml/deepnet.py

Lines changed: 2 additions & 2 deletions
@@ -68,14 +68,14 @@
     import tensorflow as tf
     tf.autograph.set_verbosity(0)
     LAMINAR_VERSION = False
-except ModuleNotFoundError:
+except Exception:
     LAMINAR_VERSION = True
 
 try:
     from sensenet.models.wrappers import create_model
     from bigml.images.utils import to_relative_coordinates
     from bigml.constants import IOU_REMOTE_SETTINGS
-except ModuleNotFoundError:
+except Exception:
     LAMINAR_VERSION = True
 
 LOGGER = logging.getLogger('BigML')

bigml/fields.py

Lines changed: 105 additions & 1 deletion
@@ -45,14 +45,21 @@
 import json
 import csv
 import random
+import numpy as np
+
+try:
+    from pandas import DataFrame
+    PANDAS_READY = True
+except ImportError:
+    PANDAS_READY = False
 
 
 from bigml.util import invert_dictionary, python_map_type, find_locale
 from bigml.util import DEFAULT_LOCALE
 from bigml.api_handlers.resourcehandler import get_resource_type, get_fields
 from bigml.constants import (
     SOURCE_PATH, DATASET_PATH, SUPERVISED_PATHS, FUSION_PATH,
-    RESOURCES_WITH_FIELDS, DEFAULT_MISSING_TOKENS, REGIONS)
+    RESOURCES_WITH_FIELDS, DEFAULT_MISSING_TOKENS, REGIONS, CATEGORICAL)
 from bigml.io import UnicodeReader, UnicodeWriter
 
 LIST_LIMIT = 10
@@ -193,6 +200,32 @@ def get_new_fields(output_fields):
     return new_fields
 
 
+def one_hot_code(value, field, decode=False):
+    """Translates a categorical value into its code: the index of the value
+    in the list of categories read from the field's summary. Setting decode
+    to True translates the code back into the value."""
+
+    try:
+        categories = [cat[0] for cat in field["summary"]["categories"]]
+    except KeyError:
+        raise KeyError("Failed to find the categories list. Check the field"
+                       " information.")
+
+    if decode:
+        try:
+            result = categories[int(value)]
+        except IndexError:
+            raise IndexError("Code not found in the categories list: %s" %
+                             categories)
+    else:
+        try:
+            result = categories.index(value)
+        except ValueError:
+            raise ValueError("The '%s' value is not found in the categories"
+                             " list: %s" % (value, categories))
+    return result
+
+
 class Fields():
     """A class to deal with BigML auto-generated ids.
 
@@ -483,6 +516,77 @@ def stats(self, field_name):
         summary = self.fields[field_id].get('summary', {})
         return summary
 
+    def objective_field_info(self):
+        """Returns the fields structure for the objective field"""
+        if self.objective_field is None:
+            return None
+        objective_id = self.field_id(self.objective_field)
+        return {objective_id: self.fields[objective_id]}
+
+    def sorted_field_ids(self, objective=False):
+        """List of field IDs ordered by column number. If objective is
+        set to False, the objective field will be excluded.
+        """
+        fields = {}
+        fields.update(self.fields_by_column_number)
+        if not objective and self.objective_field is not None:
+            del fields[self.objective_field]
+        field_ids = fields.values()
+        return field_ids
+
+    def to_numpy(self, input_data_list, objective=False):
+        """Transforms input data into a numpy array. Fields are sorted
+        in the dataset order and categorical fields are coded as the index
+        of their category (see one_hot_code). If objective is set to False,
+        the objective field will not be included."""
+        if PANDAS_READY and isinstance(input_data_list, DataFrame):
+            inner_data_list = input_data_list.to_dict('records')
+        else:
+            inner_data_list = input_data_list
+        field_ids = self.sorted_field_ids(objective=objective)
+        np_input_list = np.empty(shape=(len(input_data_list),
+                                        len(field_ids)))
+        for index, input_data in enumerate(inner_data_list):
+            np_input = np.array([])
+            for field_id in field_ids:
+                field_input = input_data.get(field_id,
+                    input_data.get(self.field_name(field_id)))
+                field = self.fields[field_id]
+                if field["optype"] == CATEGORICAL:
+                    field_input = one_hot_code(field_input, field)
+                np_input = np.append(np_input, field_input)
+            np_input_list[index] = np_input
+        return np_input_list
+
+    def from_numpy(self, np_data_list, objective=False, by_name=True):
+        """Transforms numpy rows back into input data dictionaries. Fields
+        are expected in the dataset order and categorical fields are decoded
+        from their category index."""
+        input_data_list = []
+        field_ids = self.sorted_field_ids(objective=objective)
+        for np_data in np_data_list:
+            if len(np_data) != len(field_ids):
+                raise ValueError("Wrong number of features in data: %s"
+                                 " found, %s expected" %
+                                 (len(np_data), len(field_ids)))
+            input_data = {}
+            for index, field_id in enumerate(field_ids):
+                field_input = None if np.isnan(np_data[index]) else \
+                    np_data[index]
+                field = self.fields[field_id]
+                if field["optype"] == CATEGORICAL:
+                    field_input = one_hot_code(field_input, field,
+                                               decode=True)
+                if by_name:
+                    field_id = self.fields[field_id]["name"]
+                input_data.update({field_id: field_input})
+            input_data_list.append(input_data)
+        return input_data_list
+
+    def one_hot_codes(self, field_name):
+        """Returns the codes used for every category in a categorical field"""
+        field = self.fields[self.field_id(field_name)]
+        if field["optype"] != CATEGORICAL:
+            raise ValueError("Only categorical fields are encoded")
+        categories = [cat[0] for cat in field["summary"]["categories"]]
+        return dict(zip(categories, range(0, len(categories))))
+
     def summary_csv(self, filename=None):
         """Summary of the contents of the fields
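
Note: as a quick illustration of the new helper (not part of this commit), one_hot_code maps a category to its index in the field's categories list and back. The field structure below is a hypothetical minimal example of the summary metadata a BigML categorical field carries:

    from bigml.fields import one_hot_code

    # Hypothetical categorical field summary, as stored in BigML resources
    field = {"optype": "categorical",
             "summary": {"categories": [["Iris-setosa", 50],
                                        ["Iris-versicolor", 50],
                                        ["Iris-virginica", 50]]}}

    one_hot_code("Iris-versicolor", field)   # -> 1
    one_hot_code(1, field, decode=True)      # -> 'Iris-versicolor'

The new to_numpy and from_numpy methods apply this per-field coding (one numeric column per field) when building and reading back the numpy matrix.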

bigml/modelfields.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def add_terms(self, categories=False, numerics=False):
                     self.fields[field_id]["summary"]["categories"]:
                 self.categories[field_id] = [category for \
                     [category, _] in field['summary']['categories']]
-                del self.fields[field_id]["summary"]["categories"]
+                # del self.fields[field_id]["summary"]["categories"]
             if field['optype'] == 'datetime' and \
                     hasattr(self, "coeff_ids"):
                 self.coeff_id = [coeff_id for coeff_id in self.coeff_ids \

bigml/shapwrapper.py

Lines changed: 57 additions & 0 deletions
@@ -0,0 +1,57 @@
+# -*- coding: utf-8 -*-
+# pylint: disable=super-init-not-called
+#
+# Copyright 2023 BigML
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may
+# not use this file except in compliance with the License. You may obtain
+# a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations
+# under the License.
+
+"""A wrapper for models to produce predictions as expected by the Shap Explainer
+
+"""
+import numpy as np
+
+from bigml.supervised import SupervisedModel, extract_id
+from bigml.fields import Fields
+from bigml.api import get_resource_type, get_api_connection
+
+
+class ShapWrapper():
+    """A lightweight wrapper around any supervised model that offers a
+    predict method adapted to the syntax expected by the Shap Explainer."""
+
+    def __init__(self, model, api=None, cache_get=None,
+                 operation_settings=None):
+
+        self.api = get_api_connection(api)
+        resource_id, model = extract_id(model, self.api)
+        resource_type = get_resource_type(resource_id)
+        self.local_model = SupervisedModel(model, api=api, cache_get=cache_get,
+            operation_settings=operation_settings)
+        objective_id = getattr(self.local_model, "objective_id", None)
+        self.fields = Fields(self.local_model.fields,
+                             objective_field=objective_id)
+        self.x_headers = [self.fields.field_name(field_id) for field_id in
+                          self.fields.sorted_field_ids()]
+        self.y_header = self.fields.field_name(self.fields.objective_field)
+
+    def predict(self, x_test, **kwargs):
+        """Prediction method that interfaces with the Shap library"""
+        input_data_list = self.fields.from_numpy(x_test)
+        batch_prediction = self.local_model.batch_predict(
+            input_data_list, outputs={"output_fields": ["prediction"],
+            "output_headers": [self.y_header]},
+            all_fields=False, **kwargs)
+        objective_field = self.fields.objective_field_info()
+        pred_fields = Fields(objective_field)
+        return pred_fields.to_numpy(batch_prediction,
+                                    objective=True).reshape(-1)
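
Note: a minimal usage sketch (not part of this commit) of how the wrapper could be plugged into the shap library. The model id and the input field names are hypothetical, and the shap package must be installed separately:

    import shap

    from bigml.shapwrapper import ShapWrapper

    # Hypothetical model id: replace with one of your own supervised models
    wrapper = ShapWrapper("model/63fb1f5c8f679a7d3b000000")

    # Background rows keyed by field name, turned into the numeric matrix
    # layout the wrapper expects (dataset order, categorical codes)
    background_rows = [
        {"petal length": 4.5, "petal width": 1.5,
         "sepal length": 6.0, "sepal width": 3.0}]
    x_background = wrapper.fields.to_numpy(background_rows)

    # KernelExplainer only needs a predict callable plus background samples
    explainer = shap.KernelExplainer(wrapper.predict, x_background)
    shap_values = explainer.shap_values(x_background)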

bigml/supervised.py

Lines changed: 13 additions & 5 deletions
@@ -170,7 +170,8 @@ def data_transformations(self):
         """
         return self.local_model.data_transformations()
 
-    def batch_predict(self, input_data_list, outputs=None, **kwargs):
+    def batch_predict(self, input_data_list, outputs=None, all_fields=True,
+                      **kwargs):
         """Creates a batch prediction for a list of inputs using the local
         supervised model. Allows to define some output settings to
         decide the fields to be added to the input_data (prediction,
@@ -185,6 +186,8 @@ def batch_predict(self, input_data_list, outputs=None, **kwargs):
         :type input_data_list: list or Pandas dataframe
         :param dict outputs: properties that define the headers and fields to
                              be added to the input data
+        :param boolean all_fields: whether all the fields in the input data
+                                   should be part of the response
         :return: the list of input data plus the predicted values
         :rtype: list or Pandas dataframe depending on the input type in
                 input_data_list
@@ -199,17 +202,22 @@ def batch_predict(self, input_data_list, outputs=None, **kwargs):
             new_headers = new_headers[0: len(new_fields)]
         data_format = get_data_format(input_data_list)
         inner_data_list = get_formatted_data(input_data_list, INTERNAL)
+        predictions_list = []
+        kwargs.update({"full": True})
         for input_data in inner_data_list:
-            kwargs.update({"full": True})
             prediction = self.predict(input_data, **kwargs)
+            prediction_data = {}
+            if all_fields:
+                prediction_data.update(input_data)
             for index, key in enumerate(new_fields):
                 try:
-                    input_data[new_headers[index]] = prediction[key]
+                    prediction_data[new_headers[index]] = prediction[key]
                 except KeyError:
                     pass
+            predictions_list.append(prediction_data)
         if data_format != INTERNAL:
-            return format_data(inner_data_list, out_format=data_format)
-        return inner_data_list
+            return format_data(predictions_list, out_format=data_format)
+        return predictions_list
 
     #pylint: disable=locally-disabled,arguments-differ
     def dump(self, **kwargs):
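
Note: a short sketch (not part of this commit) of the effect of the new all_fields flag; the model id and the input rows are hypothetical:

    from bigml.supervised import SupervisedModel

    # Hypothetical model id: replace with one of your own supervised models
    local_model = SupervisedModel("model/63fb1f5c8f679a7d3b000000")
    inputs = [{"petal length": 4.5}, {"petal length": 1.2}]

    # Default behaviour (all_fields=True): each result dict echoes the
    # input row and appends the prediction fields
    local_model.batch_predict(inputs)

    # all_fields=False keeps only the prediction fields, which is what
    # ShapWrapper.predict relies on to build its numpy output
    local_model.batch_predict(inputs, all_fields=False)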

bigml/util.py

Lines changed: 1 addition & 0 deletions
@@ -723,6 +723,7 @@ def get_data_format(input_data_list):
     raise ValueError("Data is expected to be provided as a list of "
                      "dictionaries or Pandas' DataFrame.")
 
+
 #pylint: disable=locally-disabled,comparison-with-itself
 def format_data(input_data_list, out_format=None):
     """Transforms the input data format to the one expected """

0 commit comments