Skip to content

Commit 7a5e11f

Browse files
authored
change: break out methods to get processing arguments (#1851)
1 parent d54ae22 commit 7a5e11f

File tree

2 files changed

+106
-21
lines changed

2 files changed

+106
-21
lines changed

src/sagemaker/processing.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -570,11 +570,48 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
570570
:class:`~sagemaker.processing.ProcessingJob`: The instance of ``ProcessingJob`` created
571571
using the ``Processor``.
572572
"""
573+
process_args = cls._get_process_args(processor, inputs, outputs, experiment_config)
574+
575+
# Print the job name and the user's inputs and outputs as lists of dictionaries.
576+
print()
577+
print("Job Name: ", process_args["job_name"])
578+
print("Inputs: ", process_args["inputs"])
579+
print("Outputs: ", process_args["output_config"]["Outputs"])
580+
581+
# Call sagemaker_session.process using the arguments dictionary.
582+
processor.sagemaker_session.process(**process_args)
583+
584+
return cls(
585+
processor.sagemaker_session,
586+
processor._current_job_name,
587+
inputs,
588+
outputs,
589+
processor.output_kms_key,
590+
)
591+
592+
@classmethod
593+
def _get_process_args(cls, processor, inputs, outputs, experiment_config):
594+
"""Gets a dict of arguments for a new Amazon SageMaker processing job from the processor
595+
596+
Args:
597+
processor (:class:`~sagemaker.processing.Processor`): The ``Processor`` instance
598+
that started the job.
599+
inputs (list[:class:`~sagemaker.processing.ProcessingInput`]): A list of
600+
:class:`~sagemaker.processing.ProcessingInput` objects.
601+
outputs (list[:class:`~sagemaker.processing.ProcessingOutput`]): A list of
602+
:class:`~sagemaker.processing.ProcessingOutput` objects.
603+
experiment_config (dict[str, str]): Experiment management configuration.
604+
Dictionary contains three optional keys:
605+
'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
606+
607+
Returns:
608+
Dict: dict for `sagemaker.session.Session.process` method
609+
"""
573610
# Initialize an empty dictionary for arguments to be passed to sagemaker_session.process.
574611
process_request_args = {}
575612

576613
# Add arguments to the dictionary.
577-
process_request_args["inputs"] = [input._to_request_dict() for input in inputs]
614+
process_request_args["inputs"] = [inp._to_request_dict() for inp in inputs]
578615

579616
process_request_args["output_config"] = {
580617
"Outputs": [output._to_request_dict() for output in outputs]
@@ -622,22 +659,7 @@ def start_new(cls, processor, inputs, outputs, experiment_config):
622659

623660
process_request_args["tags"] = processor.tags
624661

625-
# Print the job name and the user's inputs and outputs as lists of dictionaries.
626-
print()
627-
print("Job Name: ", process_request_args["job_name"])
628-
print("Inputs: ", process_request_args["inputs"])
629-
print("Outputs: ", process_request_args["output_config"]["Outputs"])
630-
631-
# Call sagemaker_session.process using the arguments dictionary.
632-
processor.sagemaker_session.process(**process_request_args)
633-
634-
return cls(
635-
processor.sagemaker_session,
636-
processor._current_job_name,
637-
inputs,
638-
outputs,
639-
processor.output_kms_key,
640-
)
662+
return process_request_args
641663

642664
@classmethod
643665
def from_processing_name(cls, sagemaker_session, processing_job_name):

src/sagemaker/session.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -636,7 +636,7 @@ def _get_train_request( # noqa: C901
636636
(default: ``None``).
637637
638638
Returns:
639-
Dict: a training request dictionary
639+
Dict: a training request dict
640640
"""
641641
train_request = {
642642
"AlgorithmSpecification": {"TrainingInputMode": input_mode},
@@ -756,6 +756,71 @@ def process(
756756
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
757757
(default: ``None``)
758758
"""
759+
process_request = self._get_process_request(
760+
inputs=inputs,
761+
output_config=output_config,
762+
job_name=job_name,
763+
resources=resources,
764+
stopping_condition=stopping_condition,
765+
app_specification=app_specification,
766+
environment=environment,
767+
network_config=network_config,
768+
role_arn=role_arn,
769+
tags=tags,
770+
experiment_config=experiment_config,
771+
)
772+
LOGGER.info("Creating processing-job with name %s", job_name)
773+
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
774+
self.sagemaker_client.create_processing_job(**process_request)
775+
776+
def _get_process_request(
777+
self,
778+
inputs,
779+
output_config,
780+
job_name,
781+
resources,
782+
stopping_condition,
783+
app_specification,
784+
environment,
785+
network_config,
786+
role_arn,
787+
tags,
788+
experiment_config=None,
789+
):
790+
"""Constructs a request compatible for an Amazon SageMaker processing job.
791+
792+
Args:
793+
inputs ([dict]): List of up to 10 ProcessingInput dictionaries.
794+
output_config (dict): A config dictionary, which contains a list of up
795+
to 10 ProcessingOutput dictionaries, as well as an optional KMS key ID.
796+
job_name (str): The name of the processing job. The name must be unique
797+
within an AWS Region in an AWS account. Names should have minimum
798+
length of 1 and maximum length of 63 characters.
799+
resources (dict): Encapsulates the resources, including ML instances
800+
and storage, to use for the processing job.
801+
stopping_condition (dict[str,int]): Specifies a limit to how long
802+
the processing job can run, in seconds.
803+
app_specification (dict[str,str]): Configures the processing job to
804+
run the given image. Details are in the processing container
805+
specification.
806+
environment (dict): Environment variables to start the processing
807+
container with.
808+
network_config (dict): Specifies networking options, such as network
809+
traffic encryption between processing containers, whether to allow
810+
inbound and outbound network calls to and from processing containers,
811+
and VPC subnets and security groups to use for VPC-enabled processing
812+
jobs.
813+
role_arn (str): The Amazon Resource Name (ARN) of an IAM role that
814+
Amazon SageMaker can assume to perform tasks on your behalf.
815+
tags ([dict[str,str]]): A list of dictionaries containing key-value
816+
pairs.
817+
experiment_config (dict): Experiment management configuration. Dictionary contains
818+
three optional keys, 'ExperimentName', 'TrialName', and 'TrialComponentDisplayName'.
819+
(default: ``None``)
820+
821+
Returns:
822+
Dict: a processing job request dict
823+
"""
759824
process_request = {
760825
"ProcessingJobName": job_name,
761826
"ProcessingResources": resources,
@@ -784,9 +849,7 @@ def process(
784849
if experiment_config:
785850
process_request["ExperimentConfig"] = experiment_config
786851

787-
LOGGER.info("Creating processing-job with name %s", job_name)
788-
LOGGER.debug("process request: %s", json.dumps(process_request, indent=4))
789-
self.sagemaker_client.create_processing_job(**process_request)
852+
return process_request
790853

791854
def create_monitoring_schedule(
792855
self,

0 commit comments

Comments
 (0)