Skip to content

Commit 7b5137d

Browse files
jarverha
authored and committed
added documentation
1 parent 90ed570 commit 7b5137d

File tree

3 files changed

+36
-33
lines changed

3 files changed

+36
-33
lines changed

powershap/powershap.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -534,9 +534,8 @@ def transform(self, X):
534534
)
535535
return super().transform(X)
536536

537-
# def _more_tags(self):
538-
# return self._explainer._get_more_tags()
539-
537+
# Since sklearn 1.6, the tag system changed so this function is necessary to make it compatible
538+
# https://scikit-learn.org/stable/auto_examples/release_highlights/plot_release_highlights_1_6_0.html#improvements-to-the-developer-api-for-third-party-libraries
540539
def __sklearn_tags__(self):
541540
return self._explainer._get_more_tags()
542541

powershap/shap_wrappers/shap_explainer.py

Lines changed: 32 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,10 @@ def __init__(self, model: Any):
4646
def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -> np.array:
4747
raise NotImplementedError
4848

49-
# Should be implemented by explainers themselves
49+
# If the explainer supports nan values, infinite values, or others, the explainer must override this function
5050
def validate_data(self, _estimator, X, y, **kwargs):
5151
return validate_data(_estimator, X, y, **kwargs)
5252

53-
# def _validate_data(self, validate_data: Callable, X, y, **kwargs):
54-
# return validate_data(X, y, **kwargs)
55-
5653
# Should be implemented by subclass
5754
@staticmethod
5855
def supports_model(model) -> bool:
@@ -241,6 +238,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
241238
C_explainer = shap.TreeExplainer(PowerShap_model)
242239
return C_explainer.shap_values(X_val)
243240

241+
# Function to define the tags which will be used in sklearn pipelines
244242
def _get_more_tags(self):
245243
return Tags(
246244
estimator_type=None,
@@ -273,6 +271,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
273271
C_explainer = shap.TreeExplainer(PowerShap_model)
274272
return C_explainer.shap_values(X_val)
275273

274+
# Function to define the tags which will be used in sklearn pipelines
276275
def _get_more_tags(self):
277276
return Tags(
278277
estimator_type=None,
@@ -305,6 +304,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
305304
C_explainer = shap.TreeExplainer(PowerShap_model)
306305
return C_explainer.shap_values(X_val)
307306

307+
# Function to define the tags which will be used in sklearn pipelines
308308
def _get_more_tags(self):
309309
return Tags(
310310
estimator_type=None,
@@ -370,7 +370,8 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
370370

371371
### DEEP LEARNING
372372

373-
373+
# Tensorflow has been phased out and current version does not support deepLearning approach
374+
# TODO add support for Pytorch instead
374375
class DeepLearningExplainer(ShapExplainer):
375376
@staticmethod
376377
def supports_model(model) -> bool:
@@ -379,30 +380,31 @@ def supports_model(model) -> bool:
379380
# import torch ## TODO: do we support pytorch??
380381

381382
# supported_models = [tf.keras.Model] # , torch.nn.Module]
382-
return None #isinstance(model, tuple(supported_models))
383+
# return isinstance(model, tuple(supported_models))
384+
return False
383385

384386

385-
def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -> np.array:
386-
# import tensorflow as tf
387-
388-
# tf.compat.v1.disable_v2_behavior() # https://github.com/slundberg/shap/issues/2189
389-
# Fit the model
390-
# PowerShap_model = tf.keras.models.clone_model(self.model)
391-
# metrics = kwargs.get("nn_metric")
392-
# PowerShap_model.compile(
393-
# loss=kwargs["loss"],
394-
# optimizer=kwargs["optimizer"],
395-
# metrics=metrics if metrics is None else [metrics],
396-
# # run_eagerly=True,
397-
# )
398-
# _ = PowerShap_model.fit(
399-
# X_train,
400-
# Y_train,
401-
# batch_size=kwargs["batch_size"],
402-
# epochs=kwargs["epochs"],
403-
# validation_data=(X_val, Y_val),
404-
# verbose=False,
405-
# )
406-
# # Calculate the shap values
407-
# C_explainer = shap.DeepExplainer(PowerShap_model, X_train)
408-
return None# C_explainer.shap_values(X_val)
387+
# def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -> np.array:
388+
# # import tensorflow as tf
389+
390+
# # tf.compat.v1.disable_v2_behavior() # https://github.com/slundberg/shap/issues/2189
391+
# # Fit the model
392+
# # PowerShap_model = tf.keras.models.clone_model(self.model)
393+
# # metrics = kwargs.get("nn_metric")
394+
# # PowerShap_model.compile(
395+
# # loss=kwargs["loss"],
396+
# # optimizer=kwargs["optimizer"],
397+
# # metrics=metrics if metrics is None else [metrics],
398+
# # # run_eagerly=True,
399+
# # )
400+
# # _ = PowerShap_model.fit(
401+
# # X_train,
402+
# # Y_train,
403+
# # batch_size=kwargs["batch_size"],
404+
# # epochs=kwargs["epochs"],
405+
# # validation_data=(X_val, Y_val),
406+
# # verbose=False,
407+
# # )
408+
# # # Calculate the shap values
409+
# # C_explainer = shap.DeepExplainer(PowerShap_model, X_train)
410+
# return C_explainer.shap_values(X_val)

powershap/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def powerSHAP_statistical_analysis(
6262
effect_size.append(0)
6363
power_list.append(0)
6464

65+
# The solve power of statsmodels might not always converge. If that happens it outputs a list of length 1 instead of a float
66+
# Numpy typing requires nonambiguous typing so this line is to ensure the cast to np.array later on will not result in a ValueError due to a failed converge
6567
required_iterations = [x[0] if not (isinstance(x, float) or isinstance(x, int)) else x for x in required_iterations]
6668

6769
processed_shaps_df = pd.DataFrame(

0 commit comments

Comments
 (0)