@@ -511,7 +511,133 @@ def train( # noqa: C901
511511 Returns:
512512 str: ARN of the training job, if it is created.
513513 """
514+ train_request = self ._get_train_request (
515+ input_mode = input_mode ,
516+ input_config = input_config ,
517+ role = role ,
518+ job_name = job_name ,
519+ output_config = output_config ,
520+ resource_config = resource_config ,
521+ vpc_config = vpc_config ,
522+ hyperparameters = hyperparameters ,
523+ stop_condition = stop_condition ,
524+ tags = tags ,
525+ metric_definitions = metric_definitions ,
526+ enable_network_isolation = enable_network_isolation ,
527+ image_uri = image_uri ,
528+ algorithm_arn = algorithm_arn ,
529+ encrypt_inter_container_traffic = encrypt_inter_container_traffic ,
530+ use_spot_instances = use_spot_instances ,
531+ checkpoint_s3_uri = checkpoint_s3_uri ,
532+ checkpoint_local_path = checkpoint_local_path ,
533+ experiment_config = experiment_config ,
534+ debugger_rule_configs = debugger_rule_configs ,
535+ debugger_hook_config = debugger_hook_config ,
536+ tensorboard_output_config = tensorboard_output_config ,
537+ enable_sagemaker_metrics = enable_sagemaker_metrics ,
538+ )
539+ LOGGER .info ("Creating training-job with name: %s" , job_name )
540+ LOGGER .debug ("train request: %s" , json .dumps (train_request , indent = 4 ))
541+ self .sagemaker_client .create_training_job (** train_request )
542+
543+ def _get_train_request ( # noqa: C901
544+ self ,
545+ input_mode ,
546+ input_config ,
547+ role ,
548+ job_name ,
549+ output_config ,
550+ resource_config ,
551+ vpc_config ,
552+ hyperparameters ,
553+ stop_condition ,
554+ tags ,
555+ metric_definitions ,
556+ enable_network_isolation = False ,
557+ image_uri = None ,
558+ algorithm_arn = None ,
559+ encrypt_inter_container_traffic = False ,
560+ use_spot_instances = False ,
561+ checkpoint_s3_uri = None ,
562+ checkpoint_local_path = None ,
563+ experiment_config = None ,
564+ debugger_rule_configs = None ,
565+ debugger_hook_config = None ,
566+ tensorboard_output_config = None ,
567+ enable_sagemaker_metrics = None ,
568+ ):
569+ """Constructs a request suitable for creating an Amazon SageMaker training job.
570+
571+ Args:
572+ input_mode (str): The input mode that the algorithm supports. Valid modes:
573+ * 'File' - Amazon SageMaker copies the training dataset from the S3 location to
574+ a directory in the Docker container.
575+ * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
576+ Unix-named pipe.
577+
578+ input_config (list): A list of Channel objects. Each channel is a named input source.
579+ Please refer to the format details described:
580+ https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
581+ role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
582+ jobs and APIs that create Amazon SageMaker endpoints use this role to access
583+ training data and model artifacts. You must grant sufficient permissions to this
584+ role.
585+ job_name (str): Name of the training job being created.
586+ output_config (dict): The S3 URI where you want to store the training results and
587+ optional KMS key ID.
588+ resource_config (dict): Contains values for ResourceConfig:
589+ * instance_count (int): Number of EC2 instances to use for training.
590+ The key in resource_config is 'InstanceCount'.
591+ * instance_type (str): Type of EC2 instance to use for training, for example,
592+ 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
593+
594+ vpc_config (dict): Contains values for VpcConfig:
595+ * subnets (list[str]): List of subnet ids.
596+ The key in vpc_config is 'Subnets'.
597+ * security_group_ids (list[str]): List of security group ids.
598+ The key in vpc_config is 'SecurityGroupIds'.
599+
600+ hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
601+ made accessible as a dict[str, str] to the training code on SageMaker. For
602+ convenience, this accepts other types for keys and values, but ``str()`` will be
603+ called to convert them before training.
604+ stop_condition (dict): Defines when training shall finish. Contains entries that can
605+ be understood by the service like ``MaxRuntimeInSeconds``.
606+ tags (list[dict]): List of tags for labeling a training job. For more, see
607+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
608+ metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
609+ used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
610+ the name of the metric, and 'Regex' for the regular expression used to extract the
611+ metric from the logs.
612+ enable_network_isolation (bool): Whether to request that the training job run with
613+ network isolation (default: ``False``).
614+ image_uri (str): Docker image containing training code.
615+ algorithm_arn (str): Algorithm Arn from Marketplace.
616+ encrypt_inter_container_traffic (bool): Specifies whether traffic between training
617+ containers is encrypted for the training job (default: ``False``).
618+ use_spot_instances (bool): Whether to use spot instances for training (default: ``False``).
619+ checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
620+ that the algorithm writes (if any) during training. (default:
621+ ``None``).
622+ checkpoint_local_path (str): The local path that the algorithm
623+ writes its checkpoints to. SageMaker will persist all files
624+ under this path to `checkpoint_s3_uri` continually during
625+ training. On job startup the reverse happens - data from the
626+ s3 location is downloaded to this path before the algorithm is
627+ started. If the path is unset then SageMaker assumes the
628+ checkpoints will be provided under `/opt/ml/checkpoints/`.
629+ (default: ``None``).
630+ experiment_config (dict): Experiment management configuration. Dictionary contains
631+ three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
632+ (default: ``None``)
633+ enable_sagemaker_metrics (bool): Whether to enable SageMaker Metrics Time
634+ Series. For more information see:
635+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
636+ (default: ``None``).
514637
638+ Returns:
639+ dict: The training request, to be unpacked into ``create_training_job``.
640+ """
515641 train_request = {
516642 "AlgorithmSpecification" : {"TrainingInputMode" : input_mode },
517643 "OutputDataConfig" : output_config ,
@@ -583,9 +709,7 @@ def train( # noqa: C901
583709 if tensorboard_output_config is not None :
584710 train_request ["TensorBoardOutputConfig" ] = tensorboard_output_config
585711
586- LOGGER .info ("Creating training-job with name: %s" , job_name )
587- LOGGER .debug ("train request: %s" , json .dumps (train_request , indent = 4 ))
588- self .sagemaker_client .create_training_job (** train_request )
712+ return train_request
589713
590714 def process (
591715 self ,
0 commit comments