@@ -403,6 +403,7 @@ def _run(
403403 logs ,
404404 job_name ,
405405 kms_key ,
406+ experiment_config ,
406407 ):
407408 """Runs a ProcessingJob with the Sagemaker Clarify container and an analysis config.
408409
@@ -415,6 +416,9 @@ def _run(
415416 job_name (str): Processing job name.
416417 kms_key (str): The ARN of the KMS key that is used to encrypt the
417418 user code file (default: None).
419+ experiment_config (dict[str, str]): Experiment management configuration.
420+ Dictionary contains three optional keys:
421+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
418422 """
419423 analysis_config ["methods" ]["report" ] = {"name" : "report" , "title" : "Analysis Report" }
420424 with tempfile .TemporaryDirectory () as tmpdirname :
@@ -457,6 +461,7 @@ def _run(
457461 logs = logs ,
458462 job_name = job_name ,
459463 kms_key = kms_key ,
464+ experiment_config = experiment_config ,
460465 )
461466
462467 def run_pre_training_bias (
@@ -468,6 +473,7 @@ def run_pre_training_bias(
468473 logs = True ,
469474 job_name = None ,
470475 kms_key = None ,
476+ experiment_config = None ,
471477 ):
472478 """Runs a ProcessingJob to compute the requested bias 'methods' of the input data.
473479
@@ -487,13 +493,16 @@ def run_pre_training_bias(
487493 "Clarify-Pretraining-Bias" and current timestamp.
488494 kms_key (str): The ARN of the KMS key that is used to encrypt the
489495 user code file (default: None).
496+ experiment_config (dict[str, str]): Experiment management configuration.
497+ Dictionary contains three optional keys:
498+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
490499 """
491500 analysis_config = data_config .get_config ()
492501 analysis_config .update (data_bias_config .get_config ())
493502 analysis_config ["methods" ] = {"pre_training_bias" : {"methods" : methods }}
494503 if job_name is None :
495504 job_name = utils .name_from_base ("Clarify-Pretraining-Bias" )
496- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
505+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
497506
498507 def run_post_training_bias (
499508 self ,
@@ -506,6 +515,7 @@ def run_post_training_bias(
506515 logs = True ,
507516 job_name = None ,
508517 kms_key = None ,
518+ experiment_config = None ,
509519 ):
510520 """Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
511521
@@ -532,6 +542,9 @@ def run_post_training_bias(
532542 "Clarify-Posttraining-Bias" and current timestamp.
533543 kms_key (str): The ARN of the KMS key that is used to encrypt the
534544 user code file (default: None).
545+ experiment_config (dict[str, str]): Experiment management configuration.
546+ Dictionary contains three optional keys:
547+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
535548 """
536549 analysis_config = data_config .get_config ()
537550 analysis_config .update (data_bias_config .get_config ())
@@ -545,7 +558,7 @@ def run_post_training_bias(
545558 _set (probability_threshold , "probability_threshold" , analysis_config )
546559 if job_name is None :
547560 job_name = utils .name_from_base ("Clarify-Posttraining-Bias" )
548- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
561+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
549562
550563 def run_bias (
551564 self ,
@@ -559,6 +572,7 @@ def run_bias(
559572 logs = True ,
560573 job_name = None ,
561574 kms_key = None ,
575+ experiment_config = None ,
562576 ):
563577 """Runs a ProcessingJob to compute the requested bias 'methods' of the model predictions.
564578
@@ -589,6 +603,9 @@ def run_bias(
589603 "Clarify-Bias" and current timestamp.
590604 kms_key (str): The ARN of the KMS key that is used to encrypt the
591605 user code file (default: None).
606+ experiment_config (dict[str, str]): Experiment management configuration.
607+ Dictionary contains three optional keys:
608+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
592609 """
593610 analysis_config = data_config .get_config ()
594611 analysis_config .update (bias_config .get_config ())
@@ -609,7 +626,7 @@ def run_bias(
609626 }
610627 if job_name is None :
611628 job_name = utils .name_from_base ("Clarify-Bias" )
612- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
629+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
613630
614631 def run_explainability (
615632 self ,
@@ -621,6 +638,7 @@ def run_explainability(
621638 logs = True ,
622639 job_name = None ,
623640 kms_key = None ,
641+ experiment_config = None ,
624642 ):
625643 """Runs a ProcessingJob computing for each example in the input the feature importance.
626644
@@ -649,6 +667,9 @@ def run_explainability(
649667 "Clarify-Explainability" and current timestamp.
650668 kms_key (str): The ARN of the KMS key that is used to encrypt the
651669 user code file (default: None).
670+ experiment_config (dict[str, str]): Experiment management configuration.
671+ Dictionary contains three optional keys:
672+ 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
652673 """
653674 analysis_config = data_config .get_config ()
654675 predictor_config = model_config .get_predictor_config ()
@@ -657,7 +678,7 @@ def run_explainability(
657678 analysis_config ["predictor" ] = predictor_config
658679 if job_name is None :
659680 job_name = utils .name_from_base ("Clarify-Explainability" )
660- self ._run (data_config , analysis_config , wait , logs , job_name , kms_key )
681+ self ._run (data_config , analysis_config , wait , logs , job_name , kms_key , experiment_config )
661682
662683
663684def _upload_analysis_config (analysis_config_file , s3_output_path , sagemaker_session , kms_key ):
0 commit comments