@@ -227,6 +227,7 @@ def _build_create_job_definition_request(
227227 env = None ,
228228 tags = None ,
229229 network_config = None ,
230+ batch_transform_input = None ,
230231 ):
231232 """Build the request for job definition creation API
232233
@@ -270,6 +271,8 @@ def _build_create_job_definition_request(
270271 network_config (sagemaker.network.NetworkConfig): A NetworkConfig
271272 object that configures network isolation, encryption of
272273 inter-container traffic, security group IDs, and subnets.
274+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
275+ the monitoring schedule on the batch transform
273276
274277 Returns:
275278 dict: request parameters to create job definition.
@@ -366,6 +369,27 @@ def _build_create_job_definition_request(
366369 latest_baselining_job_config .probability_threshold_attribute
367370 )
368371 job_input = normalized_endpoint_input ._to_request_dict ()
372+ elif batch_transform_input is not None :
373+ # backfill attributes to batch transform input
374+ if latest_baselining_job_config is not None :
375+ if batch_transform_input .features_attribute is None :
376+ batch_transform_input .features_attribute = (
377+ latest_baselining_job_config .features_attribute
378+ )
379+ if batch_transform_input .inference_attribute is None :
380+ batch_transform_input .inference_attribute = (
381+ latest_baselining_job_config .inference_attribute
382+ )
383+ if batch_transform_input .probability_attribute is None :
384+ batch_transform_input .probability_attribute = (
385+ latest_baselining_job_config .probability_attribute
386+ )
387+ if batch_transform_input .probability_threshold_attribute is None :
388+ batch_transform_input .probability_threshold_attribute = (
389+ latest_baselining_job_config .probability_threshold_attribute
390+ )
391+ job_input = batch_transform_input ._to_request_dict ()
392+
369393 if ground_truth_input is not None :
370394 job_input ["GroundTruthS3Input" ] = dict (S3Uri = ground_truth_input )
371395
@@ -500,37 +524,46 @@ def suggest_baseline(
500524 # noinspection PyMethodOverriding
501525 def create_monitoring_schedule (
502526 self ,
503- endpoint_input ,
504- ground_truth_input ,
527+ endpoint_input = None ,
528+ ground_truth_input = None ,
505529 analysis_config = None ,
506530 output_s3_uri = None ,
507531 constraints = None ,
508532 monitor_schedule_name = None ,
509533 schedule_cron_expression = None ,
510534 enable_cloudwatch_metrics = True ,
535+ batch_transform_input = None ,
511536 ):
512537 """Creates a monitoring schedule.
513538
514539 Args:
515540 endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
516- This can either be the endpoint name or an EndpointInput.
517- ground_truth_input (str): S3 URI to ground truth dataset.
541+ This can either be the endpoint name or an EndpointInput. (default: None)
542+ ground_truth_input (str): S3 URI to ground truth dataset. (default: None)
518543 analysis_config (str or BiasAnalysisConfig): URI to analysis_config for the bias job.
519544 If it is None then configuration of the latest baselining job will be reused, but
520- if no baselining job then fail the call.
545+ if no baselining job then fail the call. (default: None)
521546 output_s3_uri (str): S3 destination of the constraint_violations and analysis result.
522- Default: "s3://<default_session_bucket>/<job_name>/output"
547+ Default: "s3://<default_session_bucket>/<job_name>/output" (default: None)
523548 constraints (sagemaker.model_monitor.Constraints or str): If provided it will be used
524549 for monitoring the endpoint. It can be a Constraints object or an S3 uri pointing
525- to a constraints JSON file.
550+ to a constraints JSON file. (default: None)
526551 monitor_schedule_name (str): Schedule name. If not specified, the processor generates
527552 a default job name, based on the image name and current timestamp.
553+ (default: None)
528554 schedule_cron_expression (str): The cron expression that dictates the frequency that
529555 this job run. See sagemaker.model_monitor.CronExpressionGenerator for valid
530- expressions. Default: Daily.
556+ expressions. Default: Daily. (default: None)
531557 enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
532- the baselining or monitoring jobs.
558+ the baselining or monitoring jobs. (default: True)
559+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
560+ the monitoring schedule on the batch transform (default: None)
533561 """
562+ # we default ground_truth_input to None in the function signature
563+ # but verify they are giving here for positional argument
564+ # backward compatibility reason.
565+ if not ground_truth_input :
566+ raise ValueError ("ground_truth_input can not be None." )
534567 if self .job_definition_name is not None or self .monitoring_schedule_name is not None :
535568 message = (
536569 "It seems that this object was already used to create an Amazon Model "
@@ -540,6 +573,15 @@ def create_monitoring_schedule(
540573 _LOGGER .error (message )
541574 raise ValueError (message )
542575
576+ if (batch_transform_input is not None ) ^ (endpoint_input is None ):
577+ message = (
578+ "Need to have either batch_transform_input or endpoint_input to create an "
579+ "Amazon Model Monitoring Schedule. "
580+ "Please provide only one of the above required inputs"
581+ )
582+ _LOGGER .error (message )
583+ raise ValueError (message )
584+
543585 # create job definition
544586 monitor_schedule_name = self ._generate_monitoring_schedule_name (
545587 schedule_name = monitor_schedule_name
@@ -569,6 +611,7 @@ def create_monitoring_schedule(
569611 env = self .env ,
570612 tags = self .tags ,
571613 network_config = self .network_config ,
614+ batch_transform_input = batch_transform_input ,
572615 )
573616 self .sagemaker_session .sagemaker_client .create_model_bias_job_definition (** request_dict )
574617
@@ -612,6 +655,7 @@ def update_monitoring_schedule(
612655 max_runtime_in_seconds = None ,
613656 env = None ,
614657 network_config = None ,
658+ batch_transform_input = None ,
615659 ):
616660 """Updates the existing monitoring schedule.
617661
@@ -651,6 +695,8 @@ def update_monitoring_schedule(
651695 network_config (sagemaker.network.NetworkConfig): A NetworkConfig
652696 object that configures network isolation, encryption of
653697 inter-container traffic, security group IDs, and subnets.
698+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to run
699+ the monitoring schedule on the batch transform
654700 """
655701 valid_args = {
656702 arg : value for arg , value in locals ().items () if arg != "self" and value is not None
@@ -660,6 +706,15 @@ def update_monitoring_schedule(
660706 if len (valid_args ) <= 0 :
661707 return
662708
709+ if batch_transform_input is not None and endpoint_input is not None :
710+ message = (
711+ "Need to have either batch_transform_input or endpoint_input to create an "
712+ "Amazon Model Monitoring Schedule. "
713+ "Please provide only one of the above required inputs"
714+ )
715+ _LOGGER .error (message )
716+ raise ValueError (message )
717+
663718 # Only need to update schedule expression
664719 if len (valid_args ) == 1 and schedule_cron_expression is not None :
665720 self ._update_monitoring_schedule (self .job_definition_name , schedule_cron_expression )
@@ -691,6 +746,7 @@ def update_monitoring_schedule(
691746 env = env ,
692747 tags = self .tags ,
693748 network_config = network_config ,
749+ batch_transform_input = batch_transform_input ,
694750 )
695751 self .sagemaker_session .sagemaker_client .create_model_bias_job_definition (** request_dict )
696752 try :
@@ -895,19 +951,20 @@ def suggest_baseline(
895951 # noinspection PyMethodOverriding
896952 def create_monitoring_schedule (
897953 self ,
898- endpoint_input ,
954+ endpoint_input = None ,
899955 analysis_config = None ,
900956 output_s3_uri = None ,
901957 constraints = None ,
902958 monitor_schedule_name = None ,
903959 schedule_cron_expression = None ,
904960 enable_cloudwatch_metrics = True ,
961+ batch_transform_input = None ,
905962 ):
906963 """Creates a monitoring schedule.
907964
908965 Args:
909966 endpoint_input (str or sagemaker.model_monitor.EndpointInput): The endpoint to monitor.
910- This can either be the endpoint name or an EndpointInput.
967+ This can either be the endpoint name or an EndpointInput. (default: None)
911968 analysis_config (str or ExplainabilityAnalysisConfig): URI to the analysis_config for
912969 the explainability job. If it is None then configuration of the latest baselining
913970 job will be reused, but if no baselining job then fail the call.
@@ -923,6 +980,8 @@ def create_monitoring_schedule(
923980 expressions. Default: Daily.
924981 enable_cloudwatch_metrics (bool): Whether to publish cloudwatch metrics as part of
925982 the baselining or monitoring jobs.
983+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
984+ run the monitoring schedule on the batch transform
926985 """
927986 if self .job_definition_name is not None or self .monitoring_schedule_name is not None :
928987 message = (
@@ -933,6 +992,15 @@ def create_monitoring_schedule(
933992 _LOGGER .error (message )
934993 raise ValueError (message )
935994
995+ if (batch_transform_input is not None ) ^ (endpoint_input is None ):
996+ message = (
997+ "Need to have either batch_transform_input or endpoint_input to create an "
998+ "Amazon Model Monitoring Schedule."
999+ "Please provide only one of the above required inputs"
1000+ )
1001+ _LOGGER .error (message )
1002+ raise ValueError (message )
1003+
9361004 # create job definition
9371005 monitor_schedule_name = self ._generate_monitoring_schedule_name (
9381006 schedule_name = monitor_schedule_name
@@ -961,6 +1029,7 @@ def create_monitoring_schedule(
9611029 env = self .env ,
9621030 tags = self .tags ,
9631031 network_config = self .network_config ,
1032+ batch_transform_input = batch_transform_input ,
9641033 )
9651034 self .sagemaker_session .sagemaker_client .create_model_explainability_job_definition (
9661035 ** request_dict
@@ -1005,6 +1074,7 @@ def update_monitoring_schedule(
10051074 max_runtime_in_seconds = None ,
10061075 env = None ,
10071076 network_config = None ,
1077+ batch_transform_input = None ,
10081078 ):
10091079 """Updates the existing monitoring schedule.
10101080
@@ -1043,6 +1113,8 @@ def update_monitoring_schedule(
10431113 network_config (sagemaker.network.NetworkConfig): A NetworkConfig
10441114 object that configures network isolation, encryption of
10451115 inter-container traffic, security group IDs, and subnets.
1116+ batch_transform_input (sagemaker.model_monitor.BatchTransformInput): Inputs to
1117+ run the monitoring schedule on the batch transform
10461118 """
10471119 valid_args = {
10481120 arg : value for arg , value in locals ().items () if arg != "self" and value is not None
@@ -1052,6 +1124,15 @@ def update_monitoring_schedule(
10521124 if len (valid_args ) <= 0 :
10531125 raise ValueError ("Nothing to update." )
10541126
1127+ if batch_transform_input is not None and endpoint_input is not None :
1128+ message = (
1129+ "Need to have either batch_transform_input or endpoint_input to create an "
1130+ "Amazon Model Monitoring Schedule. "
1131+ "Please provide only one of the above required inputs"
1132+ )
1133+ _LOGGER .error (message )
1134+ raise ValueError (message )
1135+
10551136 # Only need to update schedule expression
10561137 if len (valid_args ) == 1 and schedule_cron_expression is not None :
10571138 self ._update_monitoring_schedule (self .job_definition_name , schedule_cron_expression )
@@ -1084,6 +1165,7 @@ def update_monitoring_schedule(
10841165 env = env ,
10851166 tags = self .tags ,
10861167 network_config = network_config ,
1168+ batch_transform_input = batch_transform_input ,
10871169 )
10881170 self .sagemaker_session .sagemaker_client .create_model_explainability_job_definition (
10891171 ** request_dict
0 commit comments