@@ -442,21 +442,22 @@ def test_post_training_bias(
442442 )
443443
444444
445- @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
446- def test_shap (
445+ def _run_test_shap (
447446 name_from_base ,
448447 clarify_processor ,
449448 clarify_processor_with_job_name_prefix ,
450449 data_config ,
451450 model_config ,
452451 shap_config ,
452+ model_scores ,
453+ expected_predictor_config ,
453454):
454455 with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
455456 clarify_processor .run_explainability (
456457 data_config ,
457458 model_config ,
458459 shap_config ,
459- model_scores = None ,
460+ model_scores = model_scores ,
460461 wait = True ,
461462 job_name = "test" ,
462463 experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -485,11 +486,7 @@ def test_shap(
485486 "save_local_shap_values" : True ,
486487 }
487488 },
488- "predictor" : {
489- "model_name" : "xgboost-model" ,
490- "instance_type" : "ml.c5.xlarge" ,
491- "initial_instance_count" : 1 ,
492- },
489+ "predictor" : expected_predictor_config ,
493490 }
494491 mock_method .assert_called_with (
495492 data_config ,
@@ -504,7 +501,7 @@ def test_shap(
504501 data_config ,
505502 model_config ,
506503 shap_config ,
507- model_scores = None ,
504+ model_scores = model_scores ,
508505 wait = True ,
509506 experiment_config = {"ExperimentName" : "AnExperiment" },
510507 )
@@ -518,3 +515,63 @@ def test_shap(
518515 None ,
519516 {"ExperimentName" : "AnExperiment" },
520517 )
518+
519+
520+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
521+ def test_shap (
522+ name_from_base ,
523+ clarify_processor ,
524+ clarify_processor_with_job_name_prefix ,
525+ data_config ,
526+ model_config ,
527+ shap_config ,
528+ ):
529+ expected_predictor_config = {
530+ "model_name" : "xgboost-model" ,
531+ "instance_type" : "ml.c5.xlarge" ,
532+ "initial_instance_count" : 1 ,
533+ }
534+ _run_test_shap (
535+ name_from_base ,
536+ clarify_processor ,
537+ clarify_processor_with_job_name_prefix ,
538+ data_config ,
539+ model_config ,
540+ shap_config ,
541+ None ,
542+ expected_predictor_config ,
543+ )
544+
545+
546+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
547+ def test_shap_with_predicted_label (
548+ name_from_base ,
549+ clarify_processor ,
550+ clarify_processor_with_job_name_prefix ,
551+ data_config ,
552+ model_config ,
553+ shap_config ,
554+ ):
555+ probability = "pr"
556+ label_headers = ["success" ]
557+ model_scores = ModelPredictedLabelConfig (
558+ probability = probability ,
559+ label_headers = label_headers ,
560+ )
561+ expected_predictor_config = {
562+ "model_name" : "xgboost-model" ,
563+ "instance_type" : "ml.c5.xlarge" ,
564+ "initial_instance_count" : 1 ,
565+ "probability" : probability ,
566+ "label_headers" : label_headers ,
567+ }
568+ _run_test_shap (
569+ name_from_base ,
570+ clarify_processor ,
571+ clarify_processor_with_job_name_prefix ,
572+ data_config ,
573+ model_config ,
574+ shap_config ,
575+ model_scores ,
576+ expected_predictor_config ,
577+ )
0 commit comments