Skip to content

Commit d54ae22

Browse files
authored
change: break out methods to get train arguments (#1850)
1 parent 8b82099 commit d54ae22

File tree

2 files changed

+149
-7
lines changed

2 files changed

+149
-7
lines changed

src/sagemaker/estimator.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1037,11 +1037,31 @@ def start_new(cls, estimator, inputs, experiment_config):
10371037
:meth:`~sagemaker.estimator.EstimatorBase.fit`. Dictionary contains
10381038
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
10391039
1040-
10411040
Returns:
10421041
sagemaker.estimator._TrainingJob: Constructed object that captures
10431042
all information about the started training job.
10441043
"""
1044+
train_args = cls._get_train_args(estimator, inputs, experiment_config)
1045+
estimator.sagemaker_session.train(**train_args)
1046+
1047+
return cls(estimator.sagemaker_session, estimator._current_job_name)
1048+
1049+
@classmethod
1050+
def _get_train_args(cls, estimator, inputs, experiment_config):
1051+
"""Constructs a dict of arguments for an Amazon SageMaker training job from the estimator.
1052+
1053+
Args:
1054+
estimator (sagemaker.estimator.EstimatorBase): Estimator object
1055+
created by the user.
1056+
inputs (str): Parameters used when called
1057+
:meth:`~sagemaker.estimator.EstimatorBase.fit`.
1058+
experiment_config (dict[str, str]): Experiment management configuration used when called
1059+
:meth:`~sagemaker.estimator.EstimatorBase.fit`. Dictionary contains
1060+
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
1061+
1062+
Returns:
1063+
dict: Keyword arguments to pass to the `sagemaker.session.Session.train` method.
1064+
"""
10451065

10461066
local_mode = estimator.sagemaker_session.local_mode
10471067
model_uri = estimator.model_uri
@@ -1102,9 +1122,7 @@ def start_new(cls, estimator, inputs, experiment_config):
11021122
if estimator.enable_sagemaker_metrics is not None:
11031123
train_args["enable_sagemaker_metrics"] = estimator.enable_sagemaker_metrics
11041124

1105-
estimator.sagemaker_session.train(**train_args)
1106-
1107-
return cls(estimator.sagemaker_session, estimator._current_job_name)
1125+
return train_args
11081126

11091127
@classmethod
11101128
def _add_spot_checkpoint_args(cls, local_mode, estimator, train_args):

src/sagemaker/session.py

Lines changed: 127 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,7 +511,133 @@ def train( # noqa: C901
511511
Returns:
512512
str: ARN of the training job, if it is created.
513513
"""
514+
train_request = self._get_train_request(
515+
input_mode=input_mode,
516+
input_config=input_config,
517+
role=role,
518+
job_name=job_name,
519+
output_config=output_config,
520+
resource_config=resource_config,
521+
vpc_config=vpc_config,
522+
hyperparameters=hyperparameters,
523+
stop_condition=stop_condition,
524+
tags=tags,
525+
metric_definitions=metric_definitions,
526+
enable_network_isolation=enable_network_isolation,
527+
image_uri=image_uri,
528+
algorithm_arn=algorithm_arn,
529+
encrypt_inter_container_traffic=encrypt_inter_container_traffic,
530+
use_spot_instances=use_spot_instances,
531+
checkpoint_s3_uri=checkpoint_s3_uri,
532+
checkpoint_local_path=checkpoint_local_path,
533+
experiment_config=experiment_config,
534+
debugger_rule_configs=debugger_rule_configs,
535+
debugger_hook_config=debugger_hook_config,
536+
tensorboard_output_config=tensorboard_output_config,
537+
enable_sagemaker_metrics=enable_sagemaker_metrics,
538+
)
539+
LOGGER.info("Creating training-job with name: %s", job_name)
540+
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
541+
self.sagemaker_client.create_training_job(**train_request)
542+
543+
def _get_train_request( # noqa: C901
544+
self,
545+
input_mode,
546+
input_config,
547+
role,
548+
job_name,
549+
output_config,
550+
resource_config,
551+
vpc_config,
552+
hyperparameters,
553+
stop_condition,
554+
tags,
555+
metric_definitions,
556+
enable_network_isolation=False,
557+
image_uri=None,
558+
algorithm_arn=None,
559+
encrypt_inter_container_traffic=False,
560+
use_spot_instances=False,
561+
checkpoint_s3_uri=None,
562+
checkpoint_local_path=None,
563+
experiment_config=None,
564+
debugger_rule_configs=None,
565+
debugger_hook_config=None,
566+
tensorboard_output_config=None,
567+
enable_sagemaker_metrics=None,
568+
):
569+
"""Constructs a request suitable for creating an Amazon SageMaker training job.
570+
571+
Args:
572+
input_mode (str): The input mode that the algorithm supports. Valid modes:
573+
* 'File' - Amazon SageMaker copies the training dataset from the S3 location to
574+
a directory in the Docker container.
575+
* 'Pipe' - Amazon SageMaker streams data directly from S3 to the container via a
576+
Unix-named pipe.
577+
578+
input_config (list): A list of Channel objects. Each channel is a named input source.
579+
Please refer to the format details described:
580+
https://botocore.readthedocs.io/en/latest/reference/services/sagemaker.html#SageMaker.Client.create_training_job
581+
role (str): An AWS IAM role (either name or full ARN). The Amazon SageMaker training
582+
jobs and APIs that create Amazon SageMaker endpoints use this role to access
583+
training data and model artifacts. You must grant sufficient permissions to this
584+
role.
585+
job_name (str): Name of the training job being created.
586+
output_config (dict): The S3 URI where you want to store the training results and
587+
optional KMS key ID.
588+
resource_config (dict): Contains values for ResourceConfig:
589+
* instance_count (int): Number of EC2 instances to use for training.
590+
The key in resource_config is 'InstanceCount'.
591+
* instance_type (str): Type of EC2 instance to use for training, for example,
592+
'ml.c4.xlarge'. The key in resource_config is 'InstanceType'.
593+
594+
vpc_config (dict): Contains values for VpcConfig:
595+
* subnets (list[str]): List of subnet ids.
596+
The key in vpc_config is 'Subnets'.
597+
* security_group_ids (list[str]): List of security group ids.
598+
The key in vpc_config is 'SecurityGroupIds'.
599+
600+
hyperparameters (dict): Hyperparameters for model training. The hyperparameters are
601+
made accessible as a dict[str, str] to the training code on SageMaker. For
602+
convenience, this accepts other types for keys and values, but ``str()`` will be
603+
called to convert them before training.
604+
stop_condition (dict): Defines when training shall finish. Contains entries that can
605+
be understood by the service like ``MaxRuntimeInSeconds``.
606+
tags (list[dict]): List of tags for labeling a training job. For more, see
607+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_Tag.html.
608+
metric_definitions (list[dict]): A list of dictionaries that defines the metric(s)
609+
used to evaluate the training jobs. Each dictionary contains two keys: 'Name' for
610+
the name of the metric, and 'Regex' for the regular expression used to extract the
611+
metric from the logs.
612+
enable_network_isolation (bool): Whether to request that the training job run with
613+
network isolation or not.
614+
image_uri (str): Docker image containing training code.
615+
algorithm_arn (str): Algorithm Arn from Marketplace.
616+
encrypt_inter_container_traffic (bool): Specifies whether traffic between training
617+
containers is encrypted for the training job (default: ``False``).
618+
use_spot_instances (bool): Whether to use spot instances for training.
619+
checkpoint_s3_uri (str): The S3 URI in which to persist checkpoints
620+
that the algorithm persists (if any) during training. (default:
621+
``None``).
622+
checkpoint_local_path (str): The local path that the algorithm
623+
writes its checkpoints to. SageMaker will persist all files
624+
under this path to `checkpoint_s3_uri` continually during
625+
training. On job startup the reverse happens - data from the
626+
s3 location is downloaded to this path before the algorithm is
627+
started. If the path is unset then SageMaker assumes the
628+
checkpoints will be provided under `/opt/ml/checkpoints/`.
629+
(default: ``None``).
630+
experiment_config (dict): Experiment management configuration. Dictionary contains
631+
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
632+
(default: ``None``)
633+
enable_sagemaker_metrics (bool): Whether to enable SageMaker Metrics Time
634+
Series. For more information see:
635+
https://docs.aws.amazon.com/sagemaker/latest/dg/API_AlgorithmSpecification.html#SageMaker-Type-AlgorithmSpecification-EnableSageMakerMetricsTimeSeries
636+
(default: ``None``).
514637
638+
Returns:
639+
dict: A training request dictionary whose entries are passed as keyword arguments to ``create_training_job``.
640+
"""
515641
train_request = {
516642
"AlgorithmSpecification": {"TrainingInputMode": input_mode},
517643
"OutputDataConfig": output_config,
@@ -583,9 +709,7 @@ def train( # noqa: C901
583709
if tensorboard_output_config is not None:
584710
train_request["TensorBoardOutputConfig"] = tensorboard_output_config
585711

586-
LOGGER.info("Creating training-job with name: %s", job_name)
587-
LOGGER.debug("train request: %s", json.dumps(train_request, indent=4))
588-
self.sagemaker_client.create_training_job(**train_request)
712+
return train_request
589713

590714
def process(
591715
self,

0 commit comments

Comments
 (0)