1616import pytest
1717
1818from sagemaker .model_monitor import DatasetFormat
19- from sagemaker .workflow .parameters import ParameterString
19+ from sagemaker .workflow .execution_variables import ExecutionVariable
20+ from sagemaker .workflow .parameters import ParameterString , ParameterInteger
2021from sagemaker .workflow .pipeline import Pipeline
2122from sagemaker .workflow .pipeline import PipelineDefinitionConfig
2223from sagemaker .workflow .quality_check_step import (
178179 "dataset_source" : "/opt/ml/processing/input/baseline_dataset_input" ,
179180 "analysis_type" : "MODEL_QUALITY" ,
180181 "problem_type" : "BinaryClassification" ,
181- "probability_attribute" : "0" ,
182- "probability_threshold_attribute" : "0.5" ,
183182 },
184183 "StoppingCondition" : {"MaxRuntimeInSeconds" : 1800 },
185184 },
@@ -269,23 +268,54 @@ def test_data_quality_check_step(
269268 assert step_definition == _expected_data_quality_dsl
270269
271270
271+ @pytest .mark .parametrize (
272+ "quality_cfg_attr_value, expected_value_in_dsl" ,
273+ [
274+ (0 , "0" ),
275+ ("attr" , "attr" ),
276+ (None , None ),
277+ (ParameterString (name = "ParamStringEnvVar" ), {"Get" : "Parameters.ParamStringEnvVar" }),
278+ (ExecutionVariable ("PipelineArn" ), {"Get" : "Execution.PipelineArn" }),
279+ (ParameterInteger (name = "ParamIntEnvVar" ), "Error" ),
280+ ],
281+ )
272282def test_model_quality_check_step (
273283 sagemaker_session ,
274284 check_job_config ,
275285 model_package_group_name ,
276286 supplied_baseline_statistics_uri ,
277287 supplied_baseline_constraints_uri ,
288+ quality_cfg_attr_value ,
289+ expected_value_in_dsl ,
278290):
279291 model_quality_check_config = ModelQualityCheckConfig (
280292 baseline_dataset = "baseline_dataset_s3_url" ,
281293 dataset_format = DatasetFormat .csv (header = True ),
282294 problem_type = "BinaryClassification" ,
283- probability_attribute = 0 , # the integer should be converted to str by SDK
284- ground_truth_attribute = None ,
285- probability_threshold_attribute = 0.5 , # the float should be converted to str by SDK
295+ inference_attribute = quality_cfg_attr_value ,
296+ probability_attribute = quality_cfg_attr_value ,
297+ ground_truth_attribute = quality_cfg_attr_value ,
298+ probability_threshold_attribute = quality_cfg_attr_value ,
286299 post_analytics_processor_script = "s3://my_bucket/data_quality/postprocessor.py" ,
287300 output_s3_uri = "" ,
288301 )
302+
303+ if expected_value_in_dsl == "Error" :
304+ with pytest .raises (ValueError ) as err :
305+ QualityCheckStep (
306+ name = "ModelQualityCheckStep" ,
307+ register_new_baseline = False ,
308+ skip_check = False ,
309+ fail_on_violation = True ,
310+ quality_check_config = model_quality_check_config ,
311+ check_job_config = check_job_config ,
312+ model_package_group_name = model_package_group_name ,
313+ supplied_baseline_statistics = supplied_baseline_statistics_uri ,
314+ supplied_baseline_constraints = supplied_baseline_constraints_uri ,
315+ )
316+ assert "cannot be Parameter types other than ParameterString" in str (err )
317+ return
318+
289319 model_quality_check_step = QualityCheckStep (
290320 name = "ModelQualityCheckStep" ,
291321 register_new_baseline = False ,
@@ -297,6 +327,7 @@ def test_model_quality_check_step(
297327 supplied_baseline_statistics = supplied_baseline_statistics_uri ,
298328 supplied_baseline_constraints = supplied_baseline_constraints_uri ,
299329 )
330+
300331 pipeline = Pipeline (
301332 name = "MyPipeline" ,
302333 parameters = [
@@ -310,6 +341,16 @@ def test_model_quality_check_step(
310341
311342 step_definition = _get_step_definition_for_test (pipeline )
312343
344+ step_def_env = step_definition ["Arguments" ]["Environment" ]
345+ for var in [
346+ "inference_attribute" ,
347+ "probability_attribute" ,
348+ "ground_truth_attribute" ,
349+ "probability_threshold_attribute" ,
350+ ]:
351+ env_var_dsl = step_def_env .pop (var , None )
352+ assert env_var_dsl == expected_value_in_dsl
353+
313354 assert step_definition == _expected_model_quality_dsl
314355
315356
0 commit comments