2626from sagemaker import image_uris , s3
2727from sagemaker .session import Session
2828from sagemaker .utils import name_from_base
29- from sagemaker .clarify import SageMakerClarifyProcessor
29+ from sagemaker .clarify import SageMakerClarifyProcessor , ModelPredictedLabelConfig
3030
3131_LOGGER = logging .getLogger (__name__ )
3232
@@ -833,9 +833,10 @@ def suggest_baseline(
833833 specific explainability method. Currently, only SHAP is supported.
834834 model_config (:class:`~sagemaker.clarify.ModelConfig`): Config of the model and its
835835 endpoint to be created.
836- model_scores (int or str): Index or JSONPath location in the model output for the
837- predicted scores to be explained. This is not required if the model output is
838- a single score.
836+ model_scores (int or str or :class:`~sagemaker.clarify.ModelPredictedLabelConfig`):
837+ Index or JSONPath to locate the predicted scores in the model output. This is not
838+ required if the model output is a single score. Alternatively, it can be an instance
839+ of ModelPredictedLabelConfig to provide more parameters like label_headers.
839840 wait (bool): Whether the call should wait until the job completes (default: False).
840841 logs (bool): Whether to show the logs produced by the job.
841842 Only meaningful when wait is True (default: False).
@@ -865,14 +866,24 @@ def suggest_baseline(
865866 headers = copy .deepcopy (data_config .headers )
866867 if headers and data_config .label in headers :
867868 headers .remove (data_config .label )
869+ if model_scores is None :
870+ inference_attribute = None
871+ label_headers = None
872+ elif isinstance (model_scores , ModelPredictedLabelConfig ):
873+ inference_attribute = str (model_scores .label )
874+ label_headers = model_scores .label_headers
875+ else :
876+ inference_attribute = str (model_scores )
877+ label_headers = None
868878 self .latest_baselining_job_config = ClarifyBaseliningConfig (
869879 analysis_config = ExplainabilityAnalysisConfig (
870880 explainability_config = explainability_config ,
871881 model_config = model_config ,
872882 headers = headers ,
883+ label_headers = label_headers ,
873884 ),
874885 features_attribute = data_config .features ,
875- inference_attribute = model_scores if model_scores is None else str ( model_scores ) ,
886+ inference_attribute = inference_attribute ,
876887 )
877888 self .latest_baselining_job_name = baselining_job_name
878889 self .latest_baselining_job = ClarifyBaseliningJob (
@@ -1166,7 +1177,7 @@ def attach(cls, monitor_schedule_name, sagemaker_session=None):
11661177class ExplainabilityAnalysisConfig :
11671178 """Analysis configuration for ModelExplainabilityMonitor."""
11681179
1169- def __init__ (self , explainability_config , model_config , headers = None ):
1180+ def __init__ (self , explainability_config , model_config , headers = None , label_headers = None ):
11701181 """Creates an analysis config dictionary.
11711182
11721183 Args:
@@ -1175,13 +1186,19 @@ def __init__(self, explainability_config, model_config, headers=None):
11751186 model_config (sagemaker.clarify.ModelConfig): Config object related to bias
11761187 configurations.
11771188 headers (list[str]): A list of feature names (without label) of model/endpint input.
1189+ label_headers (list[str]): List of headers, each for a predicted score in model output.
1190+ It is used to beautify the analysis report by replacing placeholders like "label0".
1191+
11781192 """
1193+ predictor_config = model_config .get_predictor_config ()
11791194 self .analysis_config = {
11801195 "methods" : explainability_config .get_explainability_config (),
1181- "predictor" : model_config . get_predictor_config () ,
1196+ "predictor" : predictor_config ,
11821197 }
11831198 if headers is not None :
11841199 self .analysis_config ["headers" ] = headers
1200+ if label_headers is not None :
1201+ predictor_config ["label_headers" ] = label_headers
11851202
11861203 def _to_dict (self ):
11871204 """Generates a request dictionary using the parameters provided to the class."""
0 commit comments