@@ -305,12 +305,13 @@ class SHAPConfig(ExplainabilityConfig):
305305
306306 def __init__ (
307307 self ,
308- baseline ,
309- num_samples ,
310- agg_method ,
308+ baseline = None ,
309+ num_samples = None ,
310+ agg_method = None ,
311311 use_logit = False ,
312312 save_local_shap_values = True ,
313313 seed = None ,
314+ num_clusters = None ,
314315 ):
315316 """Initializes config for SHAP.
316317
@@ -320,34 +321,49 @@ def __init__(
320321 be the same as the dataset format. Each row should contain only the feature
321322 columns/values and omit the label column/values. If None a baseline will be
322323 calculated automatically by using K-means or K-prototypes in the input dataset.
323- num_samples (int): Number of samples to be used in the Kernel SHAP algorithm.
324+ num_samples (None or int): Number of samples to be used in the Kernel SHAP algorithm.
324325 This number determines the size of the generated synthetic dataset to compute the
325- SHAP values.
326- agg_method (str): Aggregation method for global SHAP values. Valid values are
326+ SHAP values. If not provided then Clarify job will choose a proper value according
327+ to the count of features.
328+ agg_method (None or str): Aggregation method for global SHAP values. Valid values are
327329 "mean_abs" (mean of absolute SHAP values for all instances),
328330 "median" (median of SHAP values for all instances) and
329331 "mean_sq" (mean of squared SHAP values for all instances).
332+ If not provided then Clarify job uses method "mean_abs"
330333 use_logit (bool): Indicator of whether the logit function is to be applied to the model
331334 predictions. Default is False. If "use_logit" is true then the SHAP values will
332335 have log-odds units.
333336 save_local_shap_values (bool): Indicator of whether to save the local SHAP values
334337 in the output location. Default is True.
335338 seed (int): seed value to get deterministic SHAP values. Default is None.
339+ num_clusters (None or int): If a baseline is not provided, Clarify automatically
340+ computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
341+ num_clusters is a parameter for this algorithm. num_clusters will be the resulting
342+ size of the baseline dataset. If not provided, Clarify job will use a default value.
336343 """
337- if agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
344+ if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
338345 raise ValueError (
339346 f"Invalid agg_method { agg_method } ." f" Please choose mean_abs, median, or mean_sq."
340347 )
341-
348+ if num_clusters is not None and baseline is not None :
349+ raise ValueError (
350+ "Baseline and num_clusters cannot be provided together. "
351+ "Please specify one of the two."
352+ )
342353 self .shap_config = {
343- "baseline" : baseline ,
344- "num_samples" : num_samples ,
345- "agg_method" : agg_method ,
346354 "use_logit" : use_logit ,
347355 "save_local_shap_values" : save_local_shap_values ,
348356 }
357+ if baseline is not None :
358+ self .shap_config ["baseline" ] = baseline
359+ if num_samples is not None :
360+ self .shap_config ["num_samples" ] = num_samples
361+ if agg_method is not None :
362+ self .shap_config ["agg_method" ] = agg_method
349363 if seed is not None :
350364 self .shap_config ["seed" ] = seed
365+ if num_clusters is not None :
366+ self .shap_config ["num_clusters" ] = num_clusters
351367
352368 def get_explainability_config (self ):
353369 """Returns config."""
0 commit comments