@@ -46,13 +46,10 @@ def __init__(self, model: Any):
4646 def _fit_get_shap (self , X_train , Y_train , X_val , Y_val , random_seed , ** kwargs ) -> np .array :
4747 raise NotImplementedError
4848
49- # Should be implemented by explainers themselves
49+ # If the explainer supports nan values, infinite values, or others, the explainer must override this function
5050 def validate_data (self , _estimator , X , y , ** kwargs ):
5151 return validate_data (_estimator , X , y , ** kwargs )
5252
53- # def _validate_data(self, validate_data: Callable, X, y, **kwargs):
54- # return validate_data(X, y, **kwargs)
55-
5653 # Should be implemented by subclass
5754 @staticmethod
5855 def supports_model (model ) -> bool :
@@ -241,6 +238,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
241238 C_explainer = shap .TreeExplainer (PowerShap_model )
242239 return C_explainer .shap_values (X_val )
243240
241+ # Function to define the tags which will be used in sklearn pipelines
244242 def _get_more_tags (self ):
245243 return Tags (
246244 estimator_type = None ,
@@ -273,6 +271,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
273271 C_explainer = shap .TreeExplainer (PowerShap_model )
274272 return C_explainer .shap_values (X_val )
275273
274+ # Function to define the tags which will be used in sklearn pipelines
276275 def _get_more_tags (self ):
277276 return Tags (
278277 estimator_type = None ,
@@ -305,6 +304,7 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
305304 C_explainer = shap .TreeExplainer (PowerShap_model )
306305 return C_explainer .shap_values (X_val )
307306
307+ # Function to define the tags which will be used in sklearn pipelines
308308 def _get_more_tags (self ):
309309 return Tags (
310310 estimator_type = None ,
@@ -370,7 +370,8 @@ def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -
370370
371371### DEEP LEARNING
372372
373-
373+ # Tensorflow has been phased out and current version does not support deepLearning approach
374+ # TODO add support for Pytorch instead
374375class DeepLearningExplainer (ShapExplainer ):
375376 @staticmethod
376377 def supports_model (model ) -> bool :
@@ -379,30 +380,31 @@ def supports_model(model) -> bool:
379380 # import torch ## TODO: do we support pytorch??
380381
381382 # supported_models = [tf.keras.Model] # , torch.nn.Module]
382- return None #isinstance(model, tuple(supported_models))
383+ # return isinstance(model, tuple(supported_models))
384+ return False
383385
384386
385- def _fit_get_shap (self , X_train , Y_train , X_val , Y_val , random_seed , ** kwargs ) -> np .array :
386- # import tensorflow as tf
387-
388- # tf.compat.v1.disable_v2_behavior() # https://github.com/slundberg/shap/issues/2189
389- # Fit the model
390- # PowerShap_model = tf.keras.models.clone_model(self.model)
391- # metrics = kwargs.get("nn_metric")
392- # PowerShap_model.compile(
393- # loss=kwargs["loss"],
394- # optimizer=kwargs["optimizer"],
395- # metrics=metrics if metrics is None else [metrics],
396- # # run_eagerly=True,
397- # )
398- # _ = PowerShap_model.fit(
399- # X_train,
400- # Y_train,
401- # batch_size=kwargs["batch_size"],
402- # epochs=kwargs["epochs"],
403- # validation_data=(X_val, Y_val),
404- # verbose=False,
405- # )
406- # # Calculate the shap values
407- # C_explainer = shap.DeepExplainer(PowerShap_model, X_train)
408- return None # C_explainer.shap_values(X_val)
387+ # def _fit_get_shap(self, X_train, Y_train, X_val, Y_val, random_seed, **kwargs) -> np.array:
388+ # # import tensorflow as tf
389+
390+ # # tf.compat.v1.disable_v2_behavior() # https://github.com/slundberg/shap/issues/2189
391+ # # Fit the model
392+ # # PowerShap_model = tf.keras.models.clone_model(self.model)
393+ # # metrics = kwargs.get("nn_metric")
394+ # # PowerShap_model.compile(
395+ # # loss=kwargs["loss"],
396+ # # optimizer=kwargs["optimizer"],
397+ # # metrics=metrics if metrics is None else [metrics],
398+ # # # run_eagerly=True,
399+ # # )
400+ # # _ = PowerShap_model.fit(
401+ # # X_train,
402+ # # Y_train,
403+ # # batch_size=kwargs["batch_size"],
404+ # # epochs=kwargs["epochs"],
405+ # # validation_data=(X_val, Y_val),
406+ # # verbose=False,
407+ # # )
408+ # # # Calculate the shap values
409+ # # C_explainer = shap.DeepExplainer(PowerShap_model, X_train)
410+ # return C_explainer.shap_values(X_val)
0 commit comments