@@ -511,7 +511,133 @@ def train( # noqa: C901
511
511
Returns:
512
512
str: ARN of the training job, if it is created.
513
513
"""
514
+ train_request = self ._get_train_request (
515
+ input_mode = input_mode ,
516
+ input_config = input_config ,
517
+ role = role ,
518
+ job_name = job_name ,
519
+ output_config = output_config ,
520
+ resource_config = resource_config ,
521
+ vpc_config = vpc_config ,
522
+ hyperparameters = hyperparameters ,
523
+ stop_condition = stop_condition ,
524
+ tags = tags ,
525
+ metric_definitions = metric_definitions ,
526
+ enable_network_isolation = enable_network_isolation ,
527
+ image_uri = image_uri ,
528
+ algorithm_arn = algorithm_arn ,
529
+ encrypt_inter_container_traffic = encrypt_inter_container_traffic ,
530
+ use_spot_instances = use_spot_instances ,
531
+ checkpoint_s3_uri = checkpoint_s3_uri ,
532
+ checkpoint_local_path = checkpoint_local_path ,
533
+ experiment_config = experiment_config ,
534
+ debugger_rule_configs = debugger_rule_configs ,
535
+ debugger_hook_config = debugger_hook_config ,
536
+ tensorboard_output_config = tensorboard_output_config ,
537
+ enable_sagemaker_metrics = enable_sagemaker_metrics ,
538
+ )
539
+ LOGGER .info ("Creating training-job with name: %s" , job_name )
540
+ LOGGER .debug ("train request: %s" , json .dumps (train_request , indent = 4 ))
541
+ self .sagemaker_client .create_training_job (** train_request )
542
+
543
+ def _get_train_request ( # noqa: C901
544
+ self ,
545
+ input_mode ,
546
+ input_config ,
547
+ role ,
548
+ job_name ,
549
+ output_config ,
550
+ resource_config ,
551
+ vpc_config ,
552
+ hyperparameters ,
553
+ stop_condition ,
554
+ tags ,
555
+ metric_definitions ,
556
+ enable_network_isolation = False ,
557
+ image_uri = None ,
558
+ algorithm_arn = None ,
559
+ encrypt_inter_container_traffic = False ,
560
+ use_spot_instances = False ,
561
+ checkpoint_s3_uri = None ,
562
+ checkpoint_local_path = None ,
563
+ experiment_config = None ,
564
+ debugger_rule_configs = None ,
565
+ debugger_hook_config = None ,
566
+ tensorboard_output_config = None ,
567
+ enable_sagemaker_metrics = None ,
568
+ ):
569
+ """Constructs a request compatible for creating an Amazon SageMaker training job.
570
+
571
+ Args:
572
+ input_mode (str): The input mode that the algorithm supports. Valid modes:
573
+ * 'File' - Amazon SageMaker copies the training dataset from the S3 location to
574
+ a directory in the Docker container.
575
+ * 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
576
+ Unix-named pipe.
577
+
578
+ input_config (list): A list of Channel objects. Each channel is a named input source.
579
+ Please refer to the format details described:
580
+ https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
581
+ role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
582
+ jobs and APIs that create Amazon SageMaker endpoints use this role to access
583
+ training data and model artifacts. You must grant sufficient permissions to this
584
+ role.
585
+ job_name (str): Name of the training job being created.
586
+ output_config (dict): The S3 URI where you want to store the training results and
587
+ optional KMS key ID.
588
+ resource_config (dict): Contains values for ResourceConfig:
589
+ * instance_count (int): Number of EC2 instances to use for training.
590
+ The key in resource_config is 'InstanceCount'.
591
+ * instance_type (str): Type of EC2 instance to use for training, for example,
592
+ 'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
593
+
594
+ vpc_config (dict): Contains values for VpcConfig:
595
+ * subnets (list[str]): List of subnet ids.
596
+ The key in vpc_config is 'Subnets'.
597
+ * security_group_ids (list[str]): List of security group ids.
598
+ The key in vpc_config is 'SecurityGroupIds'.
599
+
600
+ hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
601
+ made accessible as a dict[str, str] to the training code on SageMaker. For
602
+ convenience, this accepts other types for keys and values, but ``str()`` will be
603
+ called to convert them before training.
604
+ stop_condition (dict): Defines when training shall finish. Contains entries that can
605
+ be understood by the service like ``MaxRuntimeInSeconds``.
606
+ tags (list[dict]): List of tags for labeling a training job. For more, see
607
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
608
+ metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
609
+ used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
610
+ the name of the metric, and 'Regex' for the regular expression used to extract the
611
+ metric from the logs.
612
+ enable_network_isolation (bool): Whether to request for the training job to run with
613
+ network isolation or not.
614
+ image_uri (str): Docker image containing training code.
615
+ algorithm_arn (str): Algorithm Arn from Marketplace.
616
+ encrypt_inter_container_traffic (bool): Specifies whether traffic between training
617
+ containers is encrypted for the training job (default: ``False``).
618
+ use_spot_instances (bool): whether to use spot instances for training.
619
+ checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
620
+ that the algorithm persists (if any) during training. (default:
621
+ ``None``).
622
+ checkpoint_local_path (str): The local path that the algorithm
623
+ writes its checkpoints to. SageMaker will persist all files
624
+ under this path to `checkpoint_s3_uri` continually during
625
+ training. On job startup the reverse happens - data from the
626
+ s3 location is downloaded to this path before the algorithm is
627
+ started. If the path is unset then SageMaker assumes the
628
+ checkpoints will be provided under `/opt/ml/checkpoints/`.
629
+ (default: ``None``).
630
+ experiment_config (dict): Experiment management configuration. Dictionary contains
631
+ three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
632
+ (default: ``None``)
633
+ enable_sagemaker_metrics (bool): enable SageMaker Metrics Time
634
+ Series. For more information see:
635
+ https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
636
+ (default: ``None``).
514
637
638
+ Returns:
639
+ Dict: a training request dictionary
640
+ """
515
641
train_request = {
516
642
"AlgorithmSpecification" : {"TrainingInputMode" : input_mode },
517
643
"OutputDataConfig" : output_config ,
@@ -583,9 +709,7 @@ def train( # noqa: C901
583
709
if tensorboard_output_config is not None :
584
710
train_request ["TensorBoardOutputConfig" ] = tensorboard_output_config
585
711
586
- LOGGER .info ("Creating training-job with name: %s" , job_name )
587
- LOGGER .debug ("train request: %s" , json .dumps (train_request , indent = 4 ))
588
- self .sagemaker_client .create_training_job (** train_request )
712
+ return train_request
589
713
590
714
def process (
591
715
self ,
0 commit comments