1111# ANY KIND, either express or implied. See the License for the specific
1212# language governing permissions and limitations under the License.
1313"""This module configures the SageMaker Clarify bias and model explainability processor job."""
14- from __future__ import print_function , absolute_import
14+ from __future__ import absolute_import , print_function
1515
1616import copy
17-
18- from abc import ABC , abstractmethod
1917import json
18+ import logging
2019import os
21- import tempfile
2220import re
23- from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
21+ import tempfile
22+ from abc import ABC , abstractmethod
23+
2424from sagemaker import image_uris , s3 , utils
25+ from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
26+
27+ logger = logging .getLogger (__name__ )
2528
2629
2730class DataConfig :
@@ -338,6 +341,121 @@ def get_explainability_config(self):
338341 return copy .deepcopy ({"pdp" : self .pdp_config })
339342
340343
344+ class TextConfig :
345+ """Config object to handle text features.
346+
347+ The SHAP analysis will break down longer text into chunks (e.g. tokens, sentences, or paragraphs
348+ ) and replace them with the strings specified in the baseline for that feature. The shap value
349+ of a chunk then captures how much replacing it affects the prediction.
350+ """
351+
352+ _SUPPORTED_GRANULARITIES = ["token" , "sentence" , "paragraph" ]
353+ _SUPPORTED_LANGUAGES = [
354+ "chinese" ,
355+ "danish" ,
356+ "dutch" ,
357+ "english" ,
358+ "french" ,
359+ "german" ,
360+ "greek" ,
361+ "italian" ,
362+ "japanese" ,
363+ "lithuanian" ,
364+ "multi-language" ,
365+ "norwegian bokmål" ,
366+ "polish" ,
367+ "portuguese" ,
368+ "romanian" ,
369+ "russian" ,
370+ "spanish" ,
371+ "afrikaans" ,
372+ "albanian" ,
373+ "arabic" ,
374+ "armenian" ,
375+ "basque" ,
376+ "bengali" ,
377+ "bulgarian" ,
378+ "catalan" ,
379+ "croatian" ,
380+ "czech" ,
381+ "estonian" ,
382+ "finnish" ,
383+ "gujarati" ,
384+ "hebrew" ,
385+ "hindi" ,
386+ "hungarian" ,
387+ "icelandic" ,
388+ "indonesian" ,
389+ "irish" ,
390+ "kannada" ,
391+ "kyrgyz" ,
392+ "latvian" ,
393+ "ligurian" ,
394+ "luxembourgish" ,
395+ "macedonian" ,
396+ "malayalam" ,
397+ "marathi" ,
398+ "nepali" ,
399+ "persian" ,
400+ "sanskrit" ,
401+ "serbian" ,
402+ "setswana" ,
403+ "sinhala" ,
404+ "slovak" ,
405+ "slovenian" ,
406+ "swedish" ,
407+ "tagalog" ,
408+ "tamil" ,
409+ "tatar" ,
410+ "telugu" ,
411+ "thai" ,
412+ "turkish" ,
413+ "ukrainian" ,
414+ "urdu" ,
415+ "vietnamese" ,
416+ "yoruba" ,
417+ ]
418+
419+ def __init__ (
420+ self ,
421+ granularity ,
422+ language ,
423+ ):
424+ """Initializes a text configuration.
425+
426+ Args: granularity (str): Determines the granularity in which text features are broken down
427+ to, can be "token", "sentence", or "paragraph". Shap values are computed for these units.
428+ language (str): Specifies the language of the text features, can be "chinese", "danish",
429+ "dutch", "english", "french", "german", "greek", "italian", "japanese", "lithuanian",
430+ "multi-language", "norwegian bokmål", "polish", "portuguese", "romanian", "russian",
431+ "spanish", "afrikaans", "albanian", "arabic", "armenian", "basque", "bengali", "bulgarian",
432+ "catalan", "croatian", "czech", "estonian", "finnish", "gujarati", "hebrew", "hindi",
433+ "hungarian", "icelandic", "indonesian", "irish", "kannada", "kyrgyz", "latvian", "ligurian",
434+ "luxembourgish", "macedonian", "malayalam", "marathi", "nepali", "persian", "sanskrit",
435+ "serbian", "setswana", "sinhala", "slovak", "slovenian", "swedish", "tagalog", "tamil",
436+ "tatar", "telugu", "thai", "turkish", "ukrainian", "urdu", "vietnamese", "yoruba". Use
437+ "multi-language" for a mix of mulitple languages.
438+ """
439+ if granularity not in TextConfig ._SUPPORTED_GRANULARITIES :
440+ raise ValueError (
441+ f"Invalid granularity { granularity } . Please choose among "
442+ f"{ TextConfig ._SUPPORTED_GRANULARITIES } "
443+ )
444+ if language not in TextConfig ._SUPPORTED_LANGUAGES :
445+ raise ValueError (
446+ f"Invalid language { language } . Please choose among "
447+ f"{ TextConfig ._SUPPORTED_LANGUAGES } "
448+ )
449+ self .text_config = {
450+ "granularity" : granularity ,
451+ "language" : language ,
452+ }
453+
454+ def get_text_config (self ):
455+ """Returns part of an analysis config dictionary."""
456+ return copy .deepcopy (self .text_config )
457+
458+
341459class SHAPConfig (ExplainabilityConfig ):
342460 """Config class of SHAP."""
343461
@@ -350,6 +468,7 @@ def __init__(
350468 save_local_shap_values = True ,
351469 seed = None ,
352470 num_clusters = None ,
471+ text_config = None ,
353472 ):
354473 """Initializes config for SHAP.
355474
@@ -378,6 +497,7 @@ def __init__(
378497 computes a baseline dataset via a clustering algorithm (K-means/K-prototypes).
379498 num_clusters is a parameter for this algorithm. num_clusters will be the resulting
380499 size of the baseline dataset. If not provided, Clarify job will use a default value.
500+ text_config (:class:`~sagemaker.clarify.TextConfig`): Config to handle text features
381501 """
382502 if agg_method is not None and agg_method not in ["mean_abs" , "median" , "mean_sq" ]:
383503 raise ValueError (
@@ -402,6 +522,15 @@ def __init__(
402522 self .shap_config ["seed" ] = seed
403523 if num_clusters is not None :
404524 self .shap_config ["num_clusters" ] = num_clusters
525+ _set (seed , "seed" , self .shap_config )
526+ if text_config :
527+ _set (text_config .get_text_config (), "text_config" , self .shap_config )
528+ if not save_local_shap_values :
529+ logger .warning (
530+ "Global aggregation is not yet supported for text features. "
531+ "Consider setting save_local_shap_values=True to inspect local text "
532+ "explanations."
533+ )
405534
406535 def get_explainability_config (self ):
407536 """Returns config."""
@@ -525,7 +654,10 @@ def _run(
525654 will be unassociated.
526655 * `TrialComponentDisplayName` is used for display in Studio.
527656 """
528- analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
657+ analysis_config ["methods" ]["report" ] = {
658+ "name" : "report" ,
659+ "title" : "Analysis Report" ,
660+ }
529661 with tempfile .TemporaryDirectory () as tmpdirname :
530662 analysis_config_file = os .path .join (tmpdirname , "analysis_config.json" )
531663 with open (analysis_config_file , "w" ) as f :
@@ -627,7 +759,15 @@ def run_pre_training_bias(
627759 job_name = utils .name_from_base (self .job_name_prefix )
628760 else :
629761 job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
630- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
762+ self ._run (
763+ data_config ,
764+ analysis_config ,
765+ wait ,
766+ logs ,
767+ job_name ,
768+ kms_key ,
769+ experiment_config ,
770+ )
631771
632772 def run_post_training_bias (
633773 self ,
@@ -705,7 +845,15 @@ def run_post_training_bias(
705845 job_name = utils .name_from_base (self .job_name_prefix )
706846 else :
707847 job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
708- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
848+ self ._run (
849+ data_config ,
850+ analysis_config ,
851+ wait ,
852+ logs ,
853+ job_name ,
854+ kms_key ,
855+ experiment_config ,
856+ )
709857
710858 def run_bias (
711859 self ,
@@ -800,7 +948,15 @@ def run_bias(
800948 job_name = utils .name_from_base (self .job_name_prefix )
801949 else :
802950 job_name = utils .name_from_base ("Clarify-Bias" )
803- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
951+ self ._run (
952+ data_config ,
953+ analysis_config ,
954+ wait ,
955+ logs ,
956+ job_name ,
957+ kms_key ,
958+ experiment_config ,
959+ )
804960
805961 def run_explainability (
806962 self ,
@@ -861,7 +1017,10 @@ def run_explainability(
8611017 analysis_config = data_config .get_config ()
8621018 predictor_config = model_config .get_predictor_config ()
8631019 if isinstance (model_scores , ModelPredictedLabelConfig ):
864- probability_threshold , predicted_label_config = model_scores .get_predictor_config ()
1020+ (
1021+ probability_threshold ,
1022+ predicted_label_config ,
1023+ ) = model_scores .get_predictor_config ()
8651024 _set (probability_threshold , "probability_threshold" , analysis_config )
8661025 predictor_config .update (predicted_label_config )
8671026 else :
@@ -896,7 +1055,15 @@ def run_explainability(
8961055 job_name = utils .name_from_base (self .job_name_prefix )
8971056 else :
8981057 job_name = utils .name_from_base ("Clarify-Explainability" )
899- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
1058+ self ._run (
1059+ data_config ,
1060+ analysis_config ,
1061+ wait ,
1062+ logs ,
1063+ job_name ,
1064+ kms_key ,
1065+ experiment_config ,
1066+ )
9001067
9011068
9021069def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments