@@ -42,39 +42,57 @@ def test_uri():
     assert "306415355426.dkr.ecr.us-west-2.amazonaws.com/sagemaker-clarify-processing:1.0" == uri


-def test_data_config():
+@pytest.mark.parametrize(
+    ("dataset_type", "features", "excluded_columns", "predicted_label"),
+    [
+        ("text/csv", None, ["F4"], "Predicted Label"),
+        ("application/jsonlines", None, ["F4"], "Predicted Label"),
+        ("application/json", "[*].[F1,F2,F3]", ["F4"], "Predicted Label"),
+        ("application/x-parquet", None, ["F4"], "Predicted Label"),
+    ],
+)
+def test_data_config(dataset_type, features, excluded_columns, predicted_label):
     # facets in input dataset
     s3_data_input_path = "s3://path/to/input.csv"
     s3_output_path = "s3://path/to/output"
     label_name = "Label"
-    headers = [
-        "Label",
-        "F1",
-        "F2",
-        "F3",
-        "F4",
-    ]
-    dataset_type = "text/csv"
+    headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"]
     data_config = DataConfig(
         s3_data_input_path=s3_data_input_path,
         s3_output_path=s3_output_path,
+        features=features,
         label=label_name,
         headers=headers,
         dataset_type=dataset_type,
+        excluded_columns=excluded_columns,
+        predicted_label=predicted_label,
     )

     expected_config = {
-        "dataset_type": "text/csv",
+        "dataset_type": dataset_type,
         "headers": headers,
         "label": "Label",
     }
+    if features:
+        expected_config["features"] = features
+    if excluded_columns:
+        expected_config["excluded_columns"] = excluded_columns
+    if predicted_label:
+        expected_config["predicted_label"] = predicted_label

     assert expected_config == data_config.get_config()
     assert s3_data_input_path == data_config.s3_data_input_path
     assert s3_output_path == data_config.s3_output_path
     assert "None" == data_config.s3_compression_type
     assert "FullyReplicated" == data_config.s3_data_distribution_type

+
+def test_data_config_with_separate_facet_dataset():
+    s3_data_input_path = "s3://path/to/input.csv"
+    s3_output_path = "s3://path/to/output"
+    label_name = "Label"
+    headers = ["Label", "F1", "F2", "F3", "F4"]
+
     # facets NOT in input dataset
     joinsource = 5
     facet_dataset_uri = "s3://path/to/facet.csv"
@@ -89,7 +107,7 @@ def test_data_config():
         s3_output_path=s3_output_path,
         label=label_name,
         headers=headers,
-        dataset_type=dataset_type,
+        dataset_type="text/csv",
         joinsource=joinsource,
         facet_dataset_uri=facet_dataset_uri,
         facet_headers=facet_headers,
@@ -126,7 +144,7 @@ def test_data_config():
         s3_output_path=s3_output_path,
         label=label_name,
         headers=headers,
-        dataset_type=dataset_type,
+        dataset_type="text/csv",
         joinsource=joinsource,
         excluded_columns=excluded_columns,
     )
@@ -158,7 +176,7 @@ def test_invalid_data_config():
         DataConfig(
             s3_data_input_path="s3://bucket/inputpath",
             s3_output_path="s3://bucket/outputpath",
-            dataset_type="application/x-parquet",
+            dataset_type="application/x-image",
             predicted_label="label",
         )
     error_msg = r"^The parameter 'excluded_columns' is not supported for dataset_type"
@@ -189,6 +207,28 @@ def test_invalid_data_config():
         )


+# features JMESPath is required for JSON dataset types
+def test_json_type_data_config_missing_features():
+    # facets in input dataset
+    s3_data_input_path = "s3://path/to/input.csv"
+    s3_output_path = "s3://path/to/output"
+    label_name = "Label"
+    headers = ["Label", "F1", "F2", "F3", "F4", "Predicted Label"]
+    with pytest.raises(
+        ValueError, match="features JMESPath is required for application/json dataset_type"
+    ):
+        DataConfig(
+            s3_data_input_path=s3_data_input_path,
+            s3_output_path=s3_output_path,
+            features=None,
+            label=label_name,
+            headers=headers,
+            dataset_type="application/json",
+            excluded_columns=["F4"],
+            predicted_label="Predicted Label",
+        )
+
+
 def test_s3_data_distribution_type_ignorance():
     data_config = DataConfig(
         s3_data_input_path="s3://input/train.csv",
@@ -344,12 +384,25 @@ def test_facet_of_bias_config(facet_name, facet_values_or_threshold, expected_re
     assert bias_config.get_config() == expected_config


-def test_model_config():
+@pytest.mark.parametrize(
+    ("content_type", "accept_type"),
+    [
+        # All the combinations of content_type and accept_type should be acceptable
+        ("text/csv", "text/csv"),
+        ("application/jsonlines", "application/jsonlines"),
+        ("text/csv", "application/json"),
+        ("application/jsonlines", "application/json"),
+        ("application/jsonlines", "text/csv"),
+        ("image/jpeg", "text/csv"),
+        ("image/jpg", "text/csv"),
+        ("image/png", "text/csv"),
+        ("application/x-npy", "text/csv"),
+    ],
+)
+def test_valid_model_config(content_type, accept_type):
     model_name = "xgboost-model"
     instance_type = "ml.c5.xlarge"
     instance_count = 1
-    accept_type = "text/csv"
-    content_type = "application/jsonlines"
     custom_attributes = "c000b4f9-df62-4c85-a0bf-7c525f9104a4"
     target_model = "target_model_name"
     accelerator_type = "ml.eia1.medium"