2525
2626import tempfile
2727from abc import ABC , abstractmethod
28- from typing import List , Union , Dict
28+ from typing import List , Union , Dict , Optional , Any
2929
3030from sagemaker import image_uris , s3 , utils
31+ from sagemaker .session import Session
32+ from sagemaker .network import NetworkConfig
3133from sagemaker .processing import ProcessingInput , ProcessingOutput , Processor
3234
3335logger = logging .getLogger (__name__ )
@@ -38,21 +40,21 @@ class DataConfig:
3840
3941 def __init__ (
4042 self ,
41- s3_data_input_path ,
42- s3_output_path ,
43- s3_analysis_config_output_path = None ,
44- label = None ,
45- headers = None ,
46- features = None ,
47- dataset_type = "text/csv" ,
48- s3_compression_type = "None" ,
49- joinsource = None ,
50- facet_dataset_uri = None ,
51- facet_headers = None ,
52- predicted_label_dataset_uri = None ,
53- predicted_label_headers = None ,
54- predicted_label = None ,
55- excluded_columns = None ,
43+ s3_data_input_path : str ,
44+ s3_output_path : str ,
45+ s3_analysis_config_output_path : Optional [ str ] = None ,
46+ label : Optional [ str ] = None ,
47+ headers : Optional [ List [ str ]] = None ,
48+ features : Optional [ List [ str ]] = None ,
49+ dataset_type : str = "text/csv" ,
50+ s3_compression_type : str = "None" ,
51+ joinsource : Optional [ Union [ str , int ]] = None ,
52+ facet_dataset_uri : Optional [ str ] = None ,
53+ facet_headers : Optional [ List [ str ]] = None ,
54+ predicted_label_dataset_uri : Optional [ str ] = None ,
55+ predicted_label_headers : Optional [ List [ str ]] = None ,
56+ predicted_label : Optional [ Union [ str , int ]] = None ,
57+ excluded_columns : Optional [ Union [ List [ int ], List [ str ]]] = None ,
5658 ):
5759 """Initializes a configuration of both input and output datasets.
5860
@@ -65,7 +67,7 @@ def __init__(
6567 label (str): Target attribute of the model required by bias metrics.
6668 Specified as column name or index for CSV dataset or as JSONPath for JSONLines.
6769 *Required parameter* except for when the input dataset does not contain the label.
68- features (str): JSONPath for locating the feature columns for bias metrics if the
70+ features (List[str]): JSONPath for locating the feature columns for bias metrics if the
6971 dataset format is JSONLines.
7072 dataset_type (str): Format of the dataset. Valid values are ``"text/csv"`` for CSV,
7173 ``"application/jsonlines"`` for JSONLines, and
@@ -191,10 +193,10 @@ class BiasConfig:
191193
192194 def __init__ (
193195 self ,
194- label_values_or_threshold ,
195- facet_name ,
196- facet_values_or_threshold = None ,
197- group_name = None ,
196+ label_values_or_threshold : Union [ int , float , str ] ,
197+ facet_name : Union [ str , int , List [ str ], List [ int ]] ,
198+ facet_values_or_threshold : Optional [ Union [ int , float , str ]] = None ,
199+ group_name : Optional [ str ] = None ,
198200 ):
199201 """Initializes a configuration of the sensitive groups in the dataset.
200202
@@ -275,17 +277,17 @@ class ModelConfig:
275277
276278 def __init__ (
277279 self ,
278- model_name : str = None ,
279- instance_count : int = None ,
280- instance_type : str = None ,
281- accept_type : str = None ,
282- content_type : str = None ,
283- content_template : str = None ,
284- custom_attributes : str = None ,
285- accelerator_type : str = None ,
286- endpoint_name_prefix : str = None ,
287- target_model : str = None ,
288- endpoint_name : str = None ,
280+ model_name : Optional [ str ] = None ,
281+ instance_count : Optional [ int ] = None ,
282+ instance_type : Optional [ str ] = None ,
283+ accept_type : Optional [ str ] = None ,
284+ content_type : Optional [ str ] = None ,
285+ content_template : Optional [ str ] = None ,
286+ custom_attributes : Optional [ str ] = None ,
287+ accelerator_type : Optional [ str ] = None ,
288+ endpoint_name_prefix : Optional [ str ] = None ,
289+ target_model : Optional [ str ] = None ,
290+ endpoint_name : Optional [ str ] = None ,
289291 ):
290292 r"""Initializes a configuration of a model and the endpoint to be created for it.
291293
@@ -414,10 +416,10 @@ class ModelPredictedLabelConfig:
414416
415417 def __init__ (
416418 self ,
417- label = None ,
418- probability = None ,
419- probability_threshold = None ,
420- label_headers = None ,
419+ label : Optional [ Union [ str , int ]] = None ,
420+ probability : Optional [ Union [ str , int ]] = None ,
421+ probability_threshold : Optional [ float ] = None ,
422+ label_headers : Optional [ List [ str ]] = None ,
421423 ):
422424 """Initializes a model output config to extract the predicted label or predicted score(s).
423425
@@ -509,7 +511,9 @@ class PDPConfig(ExplainabilityConfig):
509511 and the corresponding values are included in the analysis output.
510512 """ # noqa E501
511513
512- def __init__ (self , features = None , grid_resolution = 15 , top_k_features = 10 ):
514+ def __init__ (
515+ self , features : Optional [List ] = None , grid_resolution : int = 15 , top_k_features : int = 10
516+ ):
513517 """Initializes PDP config.
514518
515519 Args:
@@ -680,8 +684,8 @@ class TextConfig:
680684
681685 def __init__ (
682686 self ,
683- granularity ,
684- language ,
687+ granularity : str ,
688+ language : str ,
685689 ):
686690 """Initializes a text configuration.
687691
@@ -736,13 +740,13 @@ class ImageConfig:
736740
737741 def __init__ (
738742 self ,
739- model_type ,
740- num_segments = None ,
741- feature_extraction_method = None ,
742- segment_compactness = None ,
743- max_objects = None ,
744- iou_threshold = None ,
745- context = None ,
743+ model_type : str ,
744+ num_segments : Optional [ int ] = None ,
745+ feature_extraction_method : Optional [ str ] = None ,
746+ segment_compactness : Optional [ float ] = None ,
747+ max_objects : Optional [ int ] = None ,
748+ iou_threshold : Optional [ float ] = None ,
749+ context : Optional [ float ] = None ,
746750 ):
747751 """Initializes a config object for Computer Vision (CV) Image explainability.
748752
@@ -817,15 +821,15 @@ class SHAPConfig(ExplainabilityConfig):
817821
818822 def __init__ (
819823 self ,
820- baseline = None ,
821- num_samples = None ,
822- agg_method = None ,
823- use_logit = False ,
824- save_local_shap_values = True ,
825- seed = None ,
826- num_clusters = None ,
827- text_config = None ,
828- image_config = None ,
824+ baseline : Optional [ Union [ str , List ]] = None ,
825+ num_samples : Optional [ int ] = None ,
826+ agg_method : Optional [ str ] = None ,
827+ use_logit : bool = False ,
828+ save_local_shap_values : bool = True ,
829+ seed : Optional [ int ] = None ,
830+ num_clusters : Optional [ int ] = None ,
831+ text_config : Optional [ TextConfig ] = None ,
832+ image_config : Optional [ ImageConfig ] = None ,
829833 ):
830834 """Initializes config for SHAP analysis.
831835
@@ -909,19 +913,19 @@ class SageMakerClarifyProcessor(Processor):
909913
910914 def __init__ (
911915 self ,
912- role ,
913- instance_count ,
914- instance_type ,
915- volume_size_in_gb = 30 ,
916- volume_kms_key = None ,
917- output_kms_key = None ,
918- max_runtime_in_seconds = None ,
919- sagemaker_session = None ,
920- env = None ,
921- tags = None ,
922- network_config = None ,
923- job_name_prefix = None ,
924- version = None ,
916+ role : str ,
917+ instance_count : int ,
918+ instance_type : str ,
919+ volume_size_in_gb : int = 30 ,
920+ volume_kms_key : Optional [ str ] = None ,
921+ output_kms_key : Optional [ str ] = None ,
922+ max_runtime_in_seconds : Optional [ int ] = None ,
923+ sagemaker_session : Optional [ Session ] = None ,
924+ env : Optional [ Dict [ str , str ]] = None ,
925+ tags : Optional [ List [ Dict [ str , str ]]] = None ,
926+ network_config : Optional [ NetworkConfig ] = None ,
927+ job_name_prefix : Optional [ str ] = None ,
928+ version : Optional [ str ] = None ,
925929 ):
926930 """Initializes a SageMakerClarifyProcessor to compute bias metrics and model explanations.
927931
@@ -993,13 +997,13 @@ def run(self, **_):
993997
994998 def _run (
995999 self ,
996- data_config ,
997- analysis_config ,
998- wait ,
999- logs ,
1000- job_name ,
1001- kms_key ,
1002- experiment_config ,
1000+ data_config : DataConfig ,
1001+ analysis_config : Dict [ str , Any ] ,
1002+ wait : bool ,
1003+ logs : bool ,
1004+ job_name : str ,
1005+ kms_key : str ,
1006+ experiment_config : Dict [ str , str ] ,
10031007 ):
10041008 """Runs a :class:`~sagemaker.processing.ProcessingJob` with the SageMaker Clarify container
10051009
@@ -1077,14 +1081,14 @@ def _run(
10771081
10781082 def run_pre_training_bias (
10791083 self ,
1080- data_config ,
1081- data_bias_config ,
1082- methods = "all" ,
1083- wait = True ,
1084- logs = True ,
1085- job_name = None ,
1086- kms_key = None ,
1087- experiment_config = None ,
1084+ data_config : DataConfig ,
1085+ data_bias_config : BiasConfig ,
1086+ methods : Union [ str , List [ str ]] = "all" ,
1087+ wait : bool = True ,
1088+ logs : bool = True ,
1089+ job_name : Optional [ str ] = None ,
1090+ kms_key : Optional [ str ] = None ,
1091+ experiment_config : Optional [ Dict [ str , str ]] = None ,
10881092 ):
10891093 """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute pre-training bias methods
10901094
@@ -1146,16 +1150,16 @@ def run_pre_training_bias(
11461150
11471151 def run_post_training_bias (
11481152 self ,
1149- data_config ,
1150- data_bias_config ,
1151- model_config ,
1152- model_predicted_label_config ,
1153- methods = "all" ,
1154- wait = True ,
1155- logs = True ,
1156- job_name = None ,
1157- kms_key = None ,
1158- experiment_config = None ,
1153+ data_config : DataConfig ,
1154+ data_bias_config : BiasConfig ,
1155+ model_config : ModelConfig ,
1156+ model_predicted_label_config : ModelPredictedLabelConfig ,
1157+ methods : Union [ str , List [ str ]] = "all" ,
1158+ wait : bool = True ,
1159+ logs : bool = True ,
1160+ job_name : Optional [ str ] = None ,
1161+ kms_key : Optional [ str ] = None ,
1162+ experiment_config : Optional [ Dict [ str , str ]] = None ,
11591163 ):
11601164 """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute post-training bias
11611165
@@ -1231,17 +1235,17 @@ def run_post_training_bias(
12311235
12321236 def run_bias (
12331237 self ,
1234- data_config ,
1235- bias_config ,
1236- model_config ,
1237- model_predicted_label_config = None ,
1238- pre_training_methods = "all" ,
1239- post_training_methods = "all" ,
1240- wait = True ,
1241- logs = True ,
1242- job_name = None ,
1243- kms_key = None ,
1244- experiment_config = None ,
1238+ data_config : DataConfig ,
1239+ bias_config : BiasConfig ,
1240+ model_config : ModelConfig ,
1241+ model_predicted_label_config : Optional [ ModelPredictedLabelConfig ] = None ,
1242+ pre_training_methods : Union [ str , List [ str ]] = "all" ,
1243+ post_training_methods : Union [ str , List [ str ]] = "all" ,
1244+ wait : bool = True ,
1245+ logs : bool = True ,
1246+ job_name : Optional [ str ] = None ,
1247+ kms_key : Optional [ str ] = None ,
1248+ experiment_config : Optional [ Dict [ str , str ]] = None ,
12451249 ):
12461250 """Runs a :class:`~sagemaker.processing.ProcessingJob` to compute the requested bias methods
12471251
@@ -1325,15 +1329,15 @@ def run_bias(
13251329
13261330 def run_explainability (
13271331 self ,
1328- data_config ,
1329- model_config ,
1330- explainability_config ,
1331- model_scores = None ,
1332- wait = True ,
1333- logs = True ,
1334- job_name = None ,
1335- kms_key = None ,
1336- experiment_config = None ,
1332+ data_config : DataConfig ,
1333+ model_config : ModelConfig ,
1334+ explainability_config : Union [ ExplainabilityConfig , List ] ,
1335+ model_scores : Optional [ Union [ int , str , ModelPredictedLabelConfig ]] = None ,
1336+ wait : bool = True ,
1337+ logs : bool = True ,
1338+ job_name : Optional [ str ] = None ,
1339+ kms_key : Optional [ str ] = None ,
1340+ experiment_config : Optional [ Dict [ str , str ]] = None ,
13371341 ):
13381342 """Runs a :class:`~sagemaker.processing.ProcessingJob` computing feature attributions.
13391343
0 commit comments