Skip to content

Commit 56f3ccf

Browse files
committed
Adding the predict_proba method to ShapWrapper and Evaluation class
1 parent 074a910 commit 56f3ccf

File tree

3 files changed

+139
-1
lines changed

3 files changed

+139
-1
lines changed

bigml/evaluation.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# -*- coding: utf-8 -*-
2+
#
3+
# Copyright 2023 BigML
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License"); you may
6+
# not use this file except in compliance with the License. You may obtain
7+
# a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
13+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
14+
# License for the specific language governing permissions and limitations
15+
# under the License.
16+
17+
"""An local Evaluation object.
18+
19+
This module defines a local class to handle the results of an evaluation
20+
21+
"""
22+
import json
23+
24+
25+
from bigml.api import get_api_connection, ID_GETTERS
26+
from bigml.basemodel import retrieve_resource, get_resource_dict
27+
28+
# Names of the metrics that are copied as attributes for classification
# evaluations. "phi" and "phi_coefficient" are two separate entries: without
# the comma between them Python concatenates the adjacent string literals
# into the single (non-existent) metric name "phiphi_coefficient".
CLASSIFICATION_METRICS = [
    "accuracy", "precision", "recall", "phi", "phi_coefficient",
    "f_measure", "confusion_matrix", "per_class_statistics"]

# Names of the metrics that are copied as attributes for regression
# evaluations.
REGRESSION_METRICS = ["mean_absolute_error", "mean_squared_error", "r_squared"]
33+
34+
35+
class ClassificationEval():
    """A class to store the classification metrics of a single class.

    The entry of `per_class_statistics` whose "class_name" equals `name` is
    looked up and every metric listed in CLASSIFICATION_METRICS that the
    entry contains is copied to the instance as an attribute.

    :param name: name of the class whose statistics are stored
    :param per_class_statistics: iterable of dictionaries, each of them
                                 holding a "class_name" key
    :raises ValueError: when no entry matches `name`
    """
    def __init__(self, name, per_class_statistics):
        self.name = name
        for statistics in per_class_statistics or []:
            if statistics["class_name"] == name:
                break
        else:
            # Without this guard, the statistics of the last class in the
            # list would be silently used for an unknown name, and an empty
            # list would end up in a NameError.
            raise ValueError("No statistics found for class: %s" % name)
        for metric in CLASSIFICATION_METRICS:
            if metric in statistics:
                setattr(self, metric, statistics[metric])
45+
46+
47+
class Evaluation():
    """A class to deal with the information in an evaluation result

    The metrics found in the evaluation result are exposed as attributes
    of the object (see `add_metrics`). For classification evaluations,
    `positive_class` holds the metrics of one particular class, which is
    initially the last one in `classes`.
    """
    def __init__(self, evaluation, api=None):
        """Builds the local object from the evaluation information

        :param evaluation: evaluation reference, handed to
                           `get_resource_dict` to be resolved
        :param api: connection object handed to `get_api_connection`
        :raises ValueError: when the evaluation cannot be retrieved or
                            its structure is not the expected one
        """
        self.resource_id = None
        self.model_id = None
        self.test_dataset_id = None
        self.regression = None
        self.full = None
        self.random = None
        self.error = None
        self.error_message = None
        # Defined up front so that these attributes exist whatever the
        # evaluation type or status is.
        self.mean = None
        self.mode = None
        self.classes = None
        self.positive_class = None
        self.api = get_api_connection(api)

        try:
            self.resource_id, evaluation = get_resource_dict(
                evaluation, "evaluation", self.api)
        except ValueError as resource:
            # NOTE(review): presumably the ValueError message carries the
            # JSON of the resource in this case -- confirm with
            # get_resource_dict.
            try:
                evaluation = json.loads(str(resource))
                self.resource_id = evaluation["resource"]
            except ValueError as exc:
                raise ValueError("The evaluation resource was faulty: \n%s" %
                                 resource) from exc

        if 'object' in evaluation and isinstance(evaluation['object'], dict):
            evaluation = evaluation['object']
        self.status = evaluation["status"]
        self.error = self.status.get("error")
        if self.error is not None:
            # Evaluations that ended in error have no results to load.
            self.error_message = self.status.get("message")
            return

        self.model_id = evaluation["model"]
        self.test_dataset_id = evaluation["dataset"]

        result = evaluation.get("result")
        if not isinstance(result, dict) or \
                not isinstance(result.get("model"), dict):
            raise ValueError("Failed to find the correct evaluation"
                             " structure.")
        self.full = result["model"]
        self.random = result.get("random")
        # Only classification results contain a confusion matrix.
        self.regression = not self.full.get("confusion_matrix")
        if self.regression:
            self.add_metrics(self.full, REGRESSION_METRICS)
            self.mean = result.get("mean")
        else:
            self.add_metrics(self.full, CLASSIFICATION_METRICS)
            self.mode = result.get("mode")
            self.classes = result.get("class_names")
            if self.classes:
                self.positive_class = ClassificationEval(
                    self.classes[-1], self.per_class_statistics)

    def add_metrics(self, metrics_info, metrics_list, obj=None):
        """Adding the metrics in the `metrics_info` dictionary as attributes
        in the object passed as argument. If None is given, the metrics will
        be added to the self object.

        For every name in `metrics_list`, the value stored under that name
        is used; when the key is missing, the value under "average_<name>"
        is used instead (None if that one is missing too).
        """
        if obj is None:
            obj = self

        for metric in metrics_list:
            setattr(obj, metric, metrics_info.get(
                metric, metrics_info.get("average_%s" % metric)))

    def set_positive_class(self, positive_class):
        """Changing the positive class

        :param positive_class: name of the class to be used as positive
        :raises ValueError: when the name is None or not one of the
                            available classes (there are none for
                            regression or faulty evaluations)
        """
        classes = getattr(self, "classes", None) or []
        if positive_class is None or positive_class not in classes:
            raise ValueError("The possible classes are: %s" %
                             ", ".join(classes))
        self.positive_class = ClassificationEval(positive_class,
                                                 self.per_class_statistics)

bigml/shapwrapper.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import numpy as np
2222

2323
from bigml.supervised import SupervisedModel, extract_id
24+
from bigml.fusion import Fusion
2425
from bigml.fields import Fields
2526
from bigml.api import get_resource_type, get_api_connection
2627

@@ -35,7 +36,8 @@ def __init__(self, model, api=None, cache_get=None,
3536
self.api = get_api_connection(api)
3637
resource_id, model = extract_id(model, self.api)
3738
resource_type = get_resource_type(resource_id)
38-
self.local_model = SupervisedModel(model, api=api, cache_get=cache_get,
39+
model_class = Fusion if resource_type == "fusion" else SupervisedModel
40+
self.local_model = model_class(model, api=api, cache_get=cache_get,
3941
operation_settings=operation_settings)
4042
objective_id = getattr(self.local_model, "objective_id", None)
4143
self.fields = Fields(self.local_model.fields,
@@ -55,3 +57,15 @@ def predict(self, x_test, **kwargs):
5557
pred_fields = Fields(objective_field)
5658
return pred_fields.to_numpy(batch_prediction,
5759
objective=True).reshape(-1)
60+
61+
def predict_proba(self, x_test):
    """Prediction method that interfaces with the Shap library

    :param x_test: numpy array of inputs, converted to a list of input
                   data by `self.fields.from_numpy`
    :return: numpy array that stacks, row by row, the output of the local
             model's `predict_probability` for each input (presumably one
             probability per class -- confirm with the local model)
    :raises ValueError: when the local model is a regression model
    """
    if self.local_model.regression:
        raise ValueError("This method is only available for classification"
                         " models.")
    input_data_list = self.fields.from_numpy(x_test)
    # The probabilities are collected in a list and converted once:
    # np.append returns a new array instead of modifying its argument, and
    # np.ndarray(...) interprets its argument as a shape, not as data.
    probabilities = [
        self.local_model.predict_probability(input_data, compact=True)
        for input_data in input_data_list]
    return np.array(probabilities)

bigml/supervised.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ def __init__(self, model, api=None, cache_get=None,
137137
for attr, value in list(local_model.__dict__.items()):
138138
setattr(self, attr, value)
139139
self.local_model = local_model
140+
self.regression = resource_type == "linearregression" or \
141+
self.local_model.regression
140142
self.name = self.local_model.name
141143
self.description = self.local_model.description
142144

0 commit comments

Comments
 (0)