5252 MonitoringAlertActions ,
5353 ModelDashboardIndicatorAction ,
5454)
55+ from sagemaker .model_monitor .data_quality_monitoring_config import DataQualityMonitoringConfig
5556from sagemaker .model_monitor .dataset_format import MonitoringDatasetFormat
5657from sagemaker .network import NetworkConfig
5758from sagemaker .processing import Processor , ProcessingInput , ProcessingJob , ProcessingOutput
9899_INFERENCE_ATTRIBUTE_ENV_NAME = "inference_attribute"
99100_PROBABILITY_ATTRIBUTE_ENV_NAME = "probability_attribute"
100101_PROBABILITY_THRESHOLD_ATTRIBUTE_ENV_NAME = "probability_threshold_attribute"
102+ _CATEGORICAL_DRIFT_METHOD_ENV_NAME = "categorical_drift_method"
101103
102104_LOGGER = logging .getLogger (__name__ )
103105
@@ -1136,6 +1138,7 @@ def _generate_env_map(
11361138 probability_attribute = None ,
11371139 ground_truth_attribute = None ,
11381140 probability_threshold_attribute = None ,
1141+ categorical_drift_method = None ,
11391142 ):
11401143 """Generate a list of environment variables from first-class parameters.
11411144
@@ -1157,6 +1160,9 @@ def _generate_env_map(
11571160 Only used for ModelQualityMonitor.
11581161 probability_threshold_attribute (float): threshold to convert probabilities to binaries
11591162 Only used for ModelQualityMonitor.
1163+ categorical_drift_method (str): categorical_drift_method to override the
1164+ categorical_drift_method of global monitoring_config in constraints
1165+ suggested by Model Monitor container. Only used for DataQualityMonitor.
11601166
11611167 Returns:
11621168 dict: Dictionary of environment keys and values.
@@ -1206,6 +1212,9 @@ def _generate_env_map(
12061212 if probability_threshold_attribute is not None :
12071213 env [_PROBABILITY_THRESHOLD_ATTRIBUTE_ENV_NAME ] = probability_threshold_attribute
12081214
1215+ if categorical_drift_method is not None :
1216+ env [_CATEGORICAL_DRIFT_METHOD_ENV_NAME ] = categorical_drift_method
1217+
12091218 return env
12101219
12111220 @staticmethod
@@ -1647,6 +1656,7 @@ def suggest_baseline(
16471656 wait = True ,
16481657 logs = True ,
16491658 job_name = None ,
1659+ monitoring_config_override = None ,
16501660 ):
16511661 """Suggest baselines for use with Amazon SageMaker Model Monitoring Schedules.
16521662
@@ -1666,12 +1676,18 @@ def suggest_baseline(
16661676 Only meaningful when wait is True (default: True).
16671677 job_name (str): Processing job name. If not specified, the processor generates
16681678 a default job name, based on the image name and current timestamp.
1669-
1679+ monitoring_config_override (DataQualityMonitoringConfig): monitoring_config object to
1680+ override the global monitoring_config parameter of constraints suggested by
1681+ Model Monitor Container. If not specified, the values suggested by container is
1682+ set.
16701683 Returns:
16711684 sagemaker.processing.ProcessingJob: The ProcessingJob object representing the
16721685 baselining job.
16731686
16741687 """
1688+ if not DataQualityMonitoringConfig .valid_monitoring_config (monitoring_config_override ):
1689+ raise RuntimeError ("Invalid value for monitoring_config_override." )
1690+
16751691 self .latest_baselining_job_name = self ._generate_baselining_job_name (job_name = job_name )
16761692
16771693 normalized_baseline_dataset_input = self ._upload_and_convert_to_processing_input (
@@ -1731,6 +1747,11 @@ def suggest_baseline(
17311747
17321748 normalized_baseline_output = self ._normalize_baseline_output (output_s3_uri = output_s3_uri )
17331749
1750+ categorical_drift_method = None
1751+ if monitoring_config_override and monitoring_config_override .distribution_constraints :
1752+ distribution_constraints = monitoring_config_override .distribution_constraints
1753+ categorical_drift_method = distribution_constraints .categorical_drift_method
1754+
17341755 normalized_env = self ._generate_env_map (
17351756 env = self .env ,
17361757 dataset_format = dataset_format ,
@@ -1739,6 +1760,7 @@ def suggest_baseline(
17391760 dataset_source_container_path = baseline_dataset_container_path ,
17401761 record_preprocessor_script_container_path = record_preprocessor_script_container_path ,
17411762 post_processor_script_container_path = post_processor_script_container_path ,
1763+ categorical_drift_method = categorical_drift_method ,
17421764 )
17431765
17441766 baselining_processor = Processor (
0 commit comments