|
19 | 19 | import re |
20 | 20 | import sys |
21 | 21 | import time |
| 22 | +import typing |
22 | 23 | import warnings |
23 | 24 | from typing import List, Dict, Any, Sequence |
24 | 25 |
|
@@ -551,7 +552,6 @@ def train( # noqa: C901 |
551 | 552 | retry_strategy(dict): Defines RetryStrategy for InternalServerFailures. |
552 | 553 | * max_retry_attsmpts (int): Number of times a job should be retried. |
553 | 554 | The key in RetryStrategy is 'MaxRetryAttempts'. |
554 | | -
|
555 | 555 | Returns: |
556 | 556 | str: ARN of the training job, if it is created. |
557 | 557 | """ |
@@ -585,9 +585,13 @@ def train( # noqa: C901 |
585 | 585 | environment=environment, |
586 | 586 | retry_strategy=retry_strategy, |
587 | 587 | ) |
588 | | - LOGGER.info("Creating training-job with name: %s", job_name) |
589 | | - LOGGER.debug("train request: %s", json.dumps(train_request, indent=4)) |
590 | | - self.sagemaker_client.create_training_job(**train_request) |
| 588 | + |
| 589 | + def submit(request): |
| 590 | + LOGGER.info("Creating training-job with name: %s", job_name) |
| 591 | + LOGGER.debug("train request: %s", json.dumps(request, indent=4)) |
| 592 | + self.sagemaker_client.create_training_job(**request) |
| 593 | + |
| 594 | + self._intercept_create_request(train_request, submit) |
591 | 595 |
|
592 | 596 | def _get_train_request( # noqa: C901 |
593 | 597 | self, |
@@ -912,9 +916,13 @@ def process( |
912 | 916 | tags=tags, |
913 | 917 | experiment_config=experiment_config, |
914 | 918 | ) |
915 | | - LOGGER.info("Creating processing-job with name %s", job_name) |
916 | | - LOGGER.debug("process request: %s", json.dumps(process_request, indent=4)) |
917 | | - self.sagemaker_client.create_processing_job(**process_request) |
| 919 | + |
| 920 | + def submit(request): |
| 921 | + LOGGER.info("Creating processing-job with name %s", job_name) |
| 922 | + LOGGER.debug("process request: %s", json.dumps(request, indent=4)) |
| 923 | + self.sagemaker_client.create_processing_job(**request) |
| 924 | + |
| 925 | + self._intercept_create_request(process_request, submit) |
918 | 926 |
|
919 | 927 | def _get_process_request( |
920 | 928 | self, |
@@ -2086,9 +2094,12 @@ def create_tuning_job( |
2086 | 2094 | tags=tags, |
2087 | 2095 | ) |
2088 | 2096 |
|
2089 | | - LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name) |
2090 | | - LOGGER.debug("tune request: %s", json.dumps(tune_request, indent=4)) |
2091 | | - self.sagemaker_client.create_hyper_parameter_tuning_job(**tune_request) |
| 2097 | + def submit(request): |
| 2098 | + LOGGER.info("Creating hyperparameter tuning job with name: %s", job_name) |
| 2099 | + LOGGER.debug("tune request: %s", json.dumps(request, indent=4)) |
| 2100 | + self.sagemaker_client.create_hyper_parameter_tuning_job(**request) |
| 2101 | + |
| 2102 | + self._intercept_create_request(tune_request, submit) |
2092 | 2103 |
|
2093 | 2104 | def _get_tuning_request( |
2094 | 2105 | self, |
@@ -2553,9 +2564,12 @@ def transform( |
2553 | 2564 | model_client_config=model_client_config, |
2554 | 2565 | ) |
2555 | 2566 |
|
2556 | | - LOGGER.info("Creating transform job with name: %s", job_name) |
2557 | | - LOGGER.debug("Transform request: %s", json.dumps(transform_request, indent=4)) |
2558 | | - self.sagemaker_client.create_transform_job(**transform_request) |
| 2567 | + def submit(request): |
| 2568 | + LOGGER.info("Creating transform job with name: %s", job_name) |
| 2569 | + LOGGER.debug("Transform request: %s", json.dumps(request, indent=4)) |
| 2570 | + self.sagemaker_client.create_transform_job(**request) |
| 2571 | + |
| 2572 | + self._intercept_create_request(transform_request, submit) |
2559 | 2573 |
|
2560 | 2574 | def _create_model_request( |
2561 | 2575 | self, |
@@ -4161,6 +4175,18 @@ def account_id(self) -> str: |
4161 | 4175 | ) |
4162 | 4176 | return sts_client.get_caller_identity()["Account"] |
4163 | 4177 |
|
| 4178 | + def _intercept_create_request(self, request: typing.Dict, create): |
| 4179 | + """This function intercepts the create job request. |
| 4180 | +
|
| 4181 | + PipelineSession inherits this Session class and will override |
| 4182 | + this function to intercept the create request. |
| 4183 | +
|
| 4184 | + Args: |
| 4185 | + request (dict): the create job request |
| 4186 | + create (functor): a functor calls the sagemaker client create method |
| 4187 | + """ |
| 4188 | + create(request) |
| 4189 | + |
4164 | 4190 |
|
4165 | 4191 | def get_model_package_args( |
4166 | 4192 | content_types, |
|
0 commit comments