1313
1414from __future__ import print_function , absolute_import
1515
16+ import copy
17+
1618from mock import patch , Mock , MagicMock
1719import pytest
1820
2325 ModelConfig ,
2426 ModelPredictedLabelConfig ,
2527 SHAPConfig ,
28+ PDPConfig ,
2629)
2730from sagemaker import image_uris , Processor
2831
@@ -304,6 +307,14 @@ def test_shap_config_no_parameters():
304307 assert expected_config == shap_config .get_explainability_config ()
305308
306309
310+ def test_pdp_config ():
311+ pdp_config = PDPConfig (features = ["f1" , "f2" ], grid_resolution = 20 )
312+ expected_config = {
313+ "pdp" : {"features" : ["f1" , "f2" ], "grid_resolution" : 20 , "top_k_features" : 10 }
314+ }
315+ assert expected_config == pdp_config .get_explainability_config ()
316+
317+
307318def test_invalid_shap_config ():
308319 with pytest .raises (ValueError ) as error :
309320 SHAPConfig (
@@ -409,13 +420,18 @@ def shap_config():
409420 0.26124998927116394 ,
410421 0.2824999988079071 ,
411422 0.06875000149011612 ,
412- ]
423+ ],
413424 ],
414425 num_samples = 100 ,
415426 agg_method = "mean_sq" ,
416427 )
417428
418429
430+ @pytest .fixture (scope = "module" )
431+ def pdp_config ():
432+ return PDPConfig (features = ["F1" , "F2" ], grid_resolution = 20 )
433+
434+
419435@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
420436def test_pre_training_bias (
421437 name_from_base ,
@@ -594,21 +610,30 @@ def test_run_on_s3_analysis_config_file(
594610 )
595611
596612
597- def _run_test_shap (
613+ def _run_test_explain (
598614 name_from_base ,
599615 clarify_processor ,
600616 clarify_processor_with_job_name_prefix ,
601617 data_config ,
602618 model_config ,
603619 shap_config ,
620+ pdp_config ,
604621 model_scores ,
605622 expected_predictor_config ,
606623):
607624 with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
625+ explanation_configs = None
626+ if shap_config and pdp_config :
627+ explanation_configs = [shap_config , pdp_config ]
628+ elif shap_config :
629+ explanation_configs = shap_config
630+ elif pdp_config :
631+ explanation_configs = pdp_config
632+
608633 clarify_processor .run_explainability (
609634 data_config ,
610635 model_config ,
611- shap_config ,
636+ explanation_configs ,
612637 model_scores = model_scores ,
613638 wait = True ,
614639 job_name = "test" ,
@@ -623,23 +648,30 @@ def _run_test_shap(
623648 "F3" ,
624649 ],
625650 "label" : "Label" ,
626- "methods" : {
627- "shap" : {
628- "baseline" : [
629- [
630- 0.26124998927116394 ,
631- 0.2824999988079071 ,
632- 0.06875000149011612 ,
633- ]
634- ],
635- "num_samples" : 100 ,
636- "agg_method" : "mean_sq" ,
637- "use_logit" : False ,
638- "save_local_shap_values" : True ,
639- }
640- },
641651 "predictor" : expected_predictor_config ,
642652 }
653+ expected_explanation_configs = {}
654+ if shap_config :
655+ expected_explanation_configs ["shap" ] = {
656+ "baseline" : [
657+ [
658+ 0.26124998927116394 ,
659+ 0.2824999988079071 ,
660+ 0.06875000149011612 ,
661+ ]
662+ ],
663+ "num_samples" : 100 ,
664+ "agg_method" : "mean_sq" ,
665+ "use_logit" : False ,
666+ "save_local_shap_values" : True ,
667+ }
668+ if pdp_config :
669+ expected_explanation_configs ["pdp" ] = {
670+ "features" : ["F1" , "F2" ],
671+ "grid_resolution" : 20 ,
672+ "top_k_features" : 10 ,
673+ }
674+ expected_analysis_config ["methods" ] = expected_explanation_configs
643675 mock_method .assert_called_with (
644676 data_config ,
645677 expected_analysis_config ,
@@ -652,7 +684,7 @@ def _run_test_shap(
652684 clarify_processor_with_job_name_prefix .run_explainability (
653685 data_config ,
654686 model_config ,
655- shap_config ,
687+ explanation_configs ,
656688 model_scores = model_scores ,
657689 wait = True ,
658690 experiment_config = {"ExperimentName" : "AnExperiment" },
@@ -669,6 +701,34 @@ def _run_test_shap(
669701 )
670702
671703
704+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
705+ def test_pdp (
706+ name_from_base ,
707+ clarify_processor ,
708+ clarify_processor_with_job_name_prefix ,
709+ data_config ,
710+ model_config ,
711+ shap_config ,
712+ pdp_config ,
713+ ):
714+ expected_predictor_config = {
715+ "model_name" : "xgboost-model" ,
716+ "instance_type" : "ml.c5.xlarge" ,
717+ "initial_instance_count" : 1 ,
718+ }
719+ _run_test_explain (
720+ name_from_base ,
721+ clarify_processor ,
722+ clarify_processor_with_job_name_prefix ,
723+ data_config ,
724+ model_config ,
725+ None ,
726+ pdp_config ,
727+ None ,
728+ expected_predictor_config ,
729+ )
730+
731+
672732@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
673733def test_shap (
674734 name_from_base ,
@@ -683,18 +743,78 @@ def test_shap(
683743 "instance_type" : "ml.c5.xlarge" ,
684744 "initial_instance_count" : 1 ,
685745 }
686- _run_test_shap (
746+ _run_test_explain (
687747 name_from_base ,
688748 clarify_processor ,
689749 clarify_processor_with_job_name_prefix ,
690750 data_config ,
691751 model_config ,
692752 shap_config ,
693753 None ,
754+ None ,
694755 expected_predictor_config ,
695756 )
696757
697758
759+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
760+ def test_explainability_with_invalid_config (
761+ name_from_base ,
762+ clarify_processor ,
763+ clarify_processor_with_job_name_prefix ,
764+ data_config ,
765+ model_config ,
766+ ):
767+ expected_predictor_config = {
768+ "model_name" : "xgboost-model" ,
769+ "instance_type" : "ml.c5.xlarge" ,
770+ "initial_instance_count" : 1 ,
771+ }
772+ with pytest .raises (
773+ AttributeError , match = "'NoneType' object has no attribute 'get_explainability_config'"
774+ ):
775+ _run_test_explain (
776+ name_from_base ,
777+ clarify_processor ,
778+ clarify_processor_with_job_name_prefix ,
779+ data_config ,
780+ model_config ,
781+ None ,
782+ None ,
783+ None ,
784+ expected_predictor_config ,
785+ )
786+
787+
788+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
789+ def test_explainability_with_multiple_shap_config (
790+ name_from_base ,
791+ clarify_processor ,
792+ clarify_processor_with_job_name_prefix ,
793+ data_config ,
794+ model_config ,
795+ shap_config ,
796+ ):
797+ expected_predictor_config = {
798+ "model_name" : "xgboost-model" ,
799+ "instance_type" : "ml.c5.xlarge" ,
800+ "initial_instance_count" : 1 ,
801+ }
802+ with pytest .raises (ValueError , match = "Duplicate explainability configs are provided" ):
803+ second_shap_config = copy .deepcopy (shap_config )
804+ second_shap_config .shap_config ["num_samples" ] = 200
805+ _run_test_explain (
806+ name_from_base ,
807+ clarify_processor ,
808+ clarify_processor_with_job_name_prefix ,
809+ data_config ,
810+ model_config ,
811+ [shap_config , second_shap_config ],
812+ None ,
813+ None ,
814+ expected_predictor_config ,
815+ )
816+
817+
698818@patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
699819def test_shap_with_predicted_label (
700820 name_from_base ,
@@ -703,6 +823,7 @@ def test_shap_with_predicted_label(
703823 data_config ,
704824 model_config ,
705825 shap_config ,
826+ pdp_config ,
706827):
707828 probability = "pr"
708829 label_headers = ["success" ]
@@ -717,13 +838,14 @@ def test_shap_with_predicted_label(
717838 "probability" : probability ,
718839 "label_headers" : label_headers ,
719840 }
720- _run_test_shap (
841+ _run_test_explain (
721842 name_from_base ,
722843 clarify_processor ,
723844 clarify_processor_with_job_name_prefix ,
724845 data_config ,
725846 model_config ,
726847 shap_config ,
848+ pdp_config ,
727849 model_scores ,
728850 expected_predictor_config ,
729851 )
0 commit comments