3434 TRAINING_DATASETS_MAX_SIZE ,
3535 TRAINING_METRICS_MAX_SIZE ,
3636 USER_PROVIDED_TRAINING_METRICS_MAX_SIZE ,
37+ HYPER_PARAMETERS_MAX_SIZE ,
38+ USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE ,
3739 EVALUATION_DATASETS_MAX_SIZE ,
3840)
3941from sagemaker .model_card .helpers import (
@@ -235,6 +237,27 @@ def __init__(
235237 self .explanations_for_risk_rating = explanations_for_risk_rating
236238
237239
240+ class BusinessDetails (_DefaultToRequestDict , _DefaultFromDict ):
241+ """The business details of a model."""
242+
243+ def __init__ (
244+ self ,
245+ business_problem : Optional [str ] = None ,
246+ business_stakeholders : Optional [str ] = None ,
247+ line_of_business : Optional [str ] = None ,
248+ ):
249+ """Initialize an Business Details object.
250+
251+ Args:
252+ business_problem (str, optional): The business problem of this model (default: None).
253+ business_stakeholders (str, optional): The business stakeholders for this model (default: None).
254+ line_of_business (str, optional): The line of business for this model (default: None).
255+ """ # noqa E501 # pylint: disable=line-too-long
256+ self .business_problem = business_problem
257+ self .business_stakeholders = business_stakeholders
258+ self .line_of_business = line_of_business
259+
260+
238261class Function (_DefaultToRequestDict , _DefaultFromDict ):
239262 """Function details."""
240263
@@ -363,6 +386,24 @@ def __init__(
363386 self .notes = notes
364387
365388
389+ class HyperParameter (_DefaultToRequestDict , _DefaultFromDict ):
390+ """Hyper-Parameters data."""
391+
392+ def __init__ (
393+ self ,
394+ name : str ,
395+ value : str ,
396+ ):
397+ """Initialize a HyperParameter object.
398+
399+ Args:
400+ name (str): The hyper parameter name.
401+ value (str): The hyper parameter value.
402+ """
403+ self .name = name
404+ self .value = value
405+
406+
366407class TrainingJobDetails (_DefaultToRequestDict , _DefaultFromDict ):
367408 """The overview of a training job."""
368409
@@ -371,6 +412,10 @@ class TrainingJobDetails(_DefaultToRequestDict, _DefaultFromDict):
371412 user_provided_training_metrics = _IsList (
372413 TrainingMetric , USER_PROVIDED_TRAINING_METRICS_MAX_SIZE
373414 )
415+ hyper_parameters = _IsList (HyperParameter , HYPER_PARAMETERS_MAX_SIZE )
416+ user_provided_hyper_parameters = _IsList (
417+ HyperParameter , USER_PROVIDED_HYPER_PARAMETERS_MAX_SIZE
418+ )
374419 training_environment = _IsModelCardObject (Environment )
375420
376421 def __init__ (
@@ -380,6 +425,8 @@ def __init__(
380425 training_environment : Optional [Environment ] = None ,
381426 training_metrics : Optional [List [TrainingMetric ]] = None ,
382427 user_provided_training_metrics : Optional [List [TrainingMetric ]] = None ,
428+ hyper_parameters : Optional [List [HyperParameter ]] = None ,
429+ user_provided_hyper_parameters : Optional [List [HyperParameter ]] = None ,
383430 ):
384431 """Initialize a Training Job Details object.
385432
@@ -389,12 +436,16 @@ def __init__(
389436 training_environment (Environment, optional): The SageMaker training image URI. (default: None).
390437 training_metrics (list[TrainingMetric], optional): SageMaker training job results. The maximum `training_metrics` list length is 50 (default: None).
391438 user_provided_training_metrics (list[TrainingMetric], optional): Custom training job results. The maximum `user_provided_training_metrics` list length is 50 (default: None).
439+ hyper_parameters (list[HyperParameter], optional): SageMaker hyper parameter results. The maximum `hyper_parameters` list length is 100 (default: None).
440+ user_provided_hyper_parameters (list[HyperParameter], optional): Custom hyper parameter results. The maximum `user_provided_hyper_parameters` list length is 100 (default: None).
392441 """ # noqa E501 # pylint: disable=line-too-long
393442 self .training_arn = training_arn
394443 self .training_datasets = training_datasets
395444 self .training_environment = training_environment
396445 self .training_metrics = training_metrics
397446 self .user_provided_training_metrics = user_provided_training_metrics
447+ self .hyper_parameters = hyper_parameters
448+ self .user_provided_hyper_parameters = user_provided_hyper_parameters
398449
399450
400451class TrainingDetails (_DefaultToRequestDict , _DefaultFromDict ):
@@ -442,6 +493,10 @@ def _create_training_details(training_job_data: dict, cls: "TrainingDetails", **
442493 ]
443494 if "FinalMetricDataList" in training_job_data
444495 else [],
496+ "hyper_parameters" : [
497+ HyperParameter (key , value )
498+ for key , value in training_job_data ["HyperParameters" ].items ()
499+ ],
445500 }
446501 kwargs .update ({"training_job_details" : TrainingJobDetails (** job )})
447502 instance = cls (** kwargs )
@@ -568,6 +623,16 @@ def add_metric(self, metric: TrainingMetric):
568623 self .training_job_details = TrainingJobDetails ()
569624 self .training_job_details .user_provided_training_metrics .append (metric )
570625
626+ def add_parameter (self , parameter : HyperParameter ):
627+ """Add custom hyper-parameter.
628+
629+ Args:
630+ parameter (HyperParameter): The custom parameter to add.
631+ """
632+ if not self .training_job_details :
633+ self .training_job_details = TrainingJobDetails ()
634+ self .training_job_details .user_provided_hyper_parameters .append (parameter )
635+
571636
572637class MetricGroup (_DefaultToRequestDict , _DefaultFromDict ):
573638 """Group of metric data"""
@@ -777,6 +842,7 @@ class ModelCard(object):
777842 status = _OneOf (ModelCardStatusEnum )
778843 model_overview = _IsModelCardObject (ModelOverview )
779844 intended_uses = _IsModelCardObject (IntendedUses )
845+ business_details = _IsModelCardObject (BusinessDetails )
780846 training_details = _IsModelCardObject (TrainingDetails )
781847 evaluation_details = _IsList (EvaluationJob )
782848 additional_information = _IsModelCardObject (AdditionalInformation )
@@ -793,6 +859,7 @@ def __init__(
793859 last_modified_by : Optional [dict ] = None ,
794860 model_overview : Optional [ModelOverview ] = None ,
795861 intended_uses : Optional [IntendedUses ] = None ,
862+ business_details : Optional [BusinessDetails ] = None ,
796863 training_details : Optional [TrainingDetails ] = None ,
797864 evaluation_details : Optional [List [EvaluationJob ]] = None ,
798865 additional_information : Optional [AdditionalInformation ] = None ,
@@ -811,6 +878,7 @@ def __init__(
811878 last_modified_by (dict, optional): The group or individual that last modified the model card (default: None).
812879 model_overview (ModelOverview, optional): An overview of the model (default: None).
813880 intended_uses (IntendedUses, optional): The intended uses of the model (default: None).
881+ business_details (BusinessDetails, optional): The business details of the model (default: None).
814882 training_details (TrainingDetails, optional): The training details of the model (default: None).
815883 evaluation_details (List[EvaluationJob], optional): The evaluation details of the model (default: None).
816884 additional_information (AdditionalInformation, optional): Additional information about the model (default: None).
@@ -826,6 +894,7 @@ def __init__(
826894 self .last_modified_by = last_modified_by
827895 self .model_overview = model_overview
828896 self .intended_uses = intended_uses
897+ self .business_details = business_details
829898 self .training_details = training_details
830899 self .evaluation_details = evaluation_details
831900 self .additional_information = additional_information
0 commit comments