2828 SageMakerClarifyProcessor ,
2929 SHAPConfig ,
3030 TextConfig ,
31+ ImageConfig ,
3132)
3233
3334JOB_NAME_PREFIX = "my-prefix"
@@ -254,17 +255,34 @@ def test_shap_config():
254255 seed = 123
255256 granularity = "sentence"
256257 language = "german"
258+ model_type = "IMAGE_CLASSIFICATION"
259+ num_segments = 2
260+ feature_extraction_method = "segmentation"
261+ segment_compactness = 10
262+ max_objects = 4
263+ iou_threshold = 0.5
264+ context = 1.0
257265 text_config = TextConfig (
258266 granularity = granularity ,
259267 language = language ,
260268 )
269+ image_config = ImageConfig (
270+ model_type = model_type ,
271+ num_segments = num_segments ,
272+ feature_extraction_method = feature_extraction_method ,
273+ segment_compactness = segment_compactness ,
274+ max_objects = max_objects ,
275+ iou_threshold = iou_threshold ,
276+ context = context ,
277+ )
261278 shap_config = SHAPConfig (
262279 baseline = baseline ,
263280 num_samples = num_samples ,
264281 agg_method = agg_method ,
265282 use_logit = use_logit ,
266283 seed = seed ,
267284 text_config = text_config ,
285+ image_config = image_config ,
268286 )
269287 expected_config = {
270288 "shap" : {
@@ -278,6 +296,15 @@ def test_shap_config():
278296 "granularity" : granularity ,
279297 "language" : language ,
280298 },
299+ "image_config" : {
300+ "model_type" : model_type ,
301+ "num_segments" : num_segments ,
302+ "feature_extraction_method" : feature_extraction_method ,
303+ "segment_compactness" : segment_compactness ,
304+ "max_objects" : max_objects ,
305+ "iou_threshold" : iou_threshold ,
306+ "context" : context ,
307+ },
281308 }
282309 }
283310 assert expected_config == shap_config .get_explainability_config ()
@@ -359,6 +386,50 @@ def test_invalid_text_config():
359386 assert "Invalid language invalid. Please choose among ['chinese'," in str (error .value )
360387
361388
389+ def test_image_config ():
390+ model_type = "IMAGE_CLASSIFICATION"
391+ num_segments = 2
392+ feature_extraction_method = "segmentation"
393+ segment_compactness = 10
394+ max_objects = 4
395+ iou_threshold = 0.5
396+ context = 1.0
397+ image_config = ImageConfig (
398+ model_type = model_type ,
399+ num_segments = num_segments ,
400+ feature_extraction_method = feature_extraction_method ,
401+ segment_compactness = segment_compactness ,
402+ max_objects = max_objects ,
403+ iou_threshold = iou_threshold ,
404+ context = context ,
405+ )
406+ expected_config = {
407+ "model_type" : model_type ,
408+ "num_segments" : num_segments ,
409+ "feature_extraction_method" : feature_extraction_method ,
410+ "segment_compactness" : segment_compactness ,
411+ "max_objects" : max_objects ,
412+ "iou_threshold" : iou_threshold ,
413+ "context" : context ,
414+ }
415+
416+ assert expected_config == image_config .get_image_config ()
417+
418+
419+ def test_invalid_image_config ():
420+ model_type = "OBJECT_SEGMENTATION"
421+ num_segments = 2
422+ with pytest .raises (ValueError ) as error :
423+ ImageConfig (
424+ model_type = model_type ,
425+ num_segments = num_segments ,
426+ )
427+ assert (
428+ "Clarify SHAP only supports object detection and image classification methods. "
429+ "Please set model_type to OBJECT_DETECTION or IMAGE_CLASSIFICATION." in str (error .value )
430+ )
431+
432+
362433def test_invalid_shap_config ():
363434 with pytest .raises (ValueError ) as error :
364435 SHAPConfig (
@@ -665,6 +736,7 @@ def _run_test_explain(
665736 model_scores ,
666737 expected_predictor_config ,
667738 expected_text_config = None ,
739+ expected_image_config = None ,
668740):
669741 with patch .object (SageMakerClarifyProcessor , "_run" , return_value = None ) as mock_method :
670742 explanation_configs = None
@@ -684,21 +756,6 @@ def _run_test_explain(
684756 job_name = "test" ,
685757 experiment_config = {"ExperimentName" : "AnExperiment" },
686758 )
687- expected_shap_config = {
688- "baseline" : [
689- [
690- 0.26124998927116394 ,
691- 0.2824999988079071 ,
692- 0.06875000149011612 ,
693- ]
694- ],
695- "num_samples" : 100 ,
696- "agg_method" : "mean_sq" ,
697- "use_logit" : False ,
698- "save_local_shap_values" : True ,
699- }
700- if expected_text_config :
701- expected_shap_config ["text_config" ] = expected_text_config
702759 expected_analysis_config = {
703760 "dataset_type" : "text/csv" ,
704761 "headers" : [
@@ -710,9 +767,6 @@ def _run_test_explain(
710767 ],
711768 "label" : "Label" ,
712769 "joinsource_name_or_index" : "F4" ,
713- "methods" : {
714- "shap" : expected_shap_config ,
715- },
716770 "predictor" : expected_predictor_config ,
717771 }
718772 expected_explanation_configs = {}
@@ -732,6 +786,8 @@ def _run_test_explain(
732786 }
733787 if expected_text_config :
734788 expected_explanation_configs ["shap" ]["text_config" ] = expected_text_config
789+ if expected_image_config :
790+ expected_explanation_configs ["shap" ]["image_config" ] = expected_image_config
735791 if pdp_config :
736792 expected_explanation_configs ["pdp" ] = {
737793 "features" : ["F1" , "F2" ],
@@ -963,3 +1019,70 @@ def test_shap_with_text_config(
9631019 expected_predictor_config ,
9641020 expected_text_config = expected_text_config ,
9651021 )
1022+
1023+
1024+ @patch ("sagemaker.utils.name_from_base" , return_value = JOB_NAME )
1025+ def test_shap_with_image_config (
1026+ name_from_base ,
1027+ clarify_processor ,
1028+ clarify_processor_with_job_name_prefix ,
1029+ data_config ,
1030+ model_config ,
1031+ ):
1032+ model_type = "IMAGE_CLASSIFICATION"
1033+ num_segments = 2
1034+ feature_extraction_method = "segmentation"
1035+ segment_compactness = 10
1036+ max_objects = 4
1037+ iou_threshold = 0.5
1038+ context = 1.0
1039+ image_config = ImageConfig (
1040+ model_type = model_type ,
1041+ num_segments = num_segments ,
1042+ feature_extraction_method = feature_extraction_method ,
1043+ segment_compactness = segment_compactness ,
1044+ max_objects = max_objects ,
1045+ iou_threshold = iou_threshold ,
1046+ context = context ,
1047+ )
1048+
1049+ shap_config = SHAPConfig (
1050+ baseline = [
1051+ [
1052+ 0.26124998927116394 ,
1053+ 0.2824999988079071 ,
1054+ 0.06875000149011612 ,
1055+ ]
1056+ ],
1057+ num_samples = 100 ,
1058+ agg_method = "mean_sq" ,
1059+ image_config = image_config ,
1060+ )
1061+
1062+ expected_image_config = {
1063+ "model_type" : model_type ,
1064+ "num_segments" : num_segments ,
1065+ "feature_extraction_method" : feature_extraction_method ,
1066+ "segment_compactness" : segment_compactness ,
1067+ "max_objects" : max_objects ,
1068+ "iou_threshold" : iou_threshold ,
1069+ "context" : context ,
1070+ }
1071+ expected_predictor_config = {
1072+ "model_name" : "xgboost-model" ,
1073+ "instance_type" : "ml.c5.xlarge" ,
1074+ "initial_instance_count" : 1 ,
1075+ }
1076+
1077+ _run_test_explain (
1078+ name_from_base ,
1079+ clarify_processor ,
1080+ clarify_processor_with_job_name_prefix ,
1081+ data_config ,
1082+ model_config ,
1083+ shap_config ,
1084+ None ,
1085+ None ,
1086+ expected_predictor_config ,
1087+ expected_image_config = expected_image_config ,
1088+ )
0 commit comments