|
21 | 21 | import tempfile |
22 | 22 |
|
23 | 23 | from sagemaker.processing import ProcessingInput, ProcessingOutput, Processor |
24 | | -from sagemaker import image_uris, utils |
| 24 | +from sagemaker import image_uris, s3, utils |
25 | 25 |
|
26 | 26 |
|
27 | 27 | class DataConfig: |
@@ -405,9 +405,15 @@ def _run( |
405 | 405 | analysis_config_file = os.path.join(tmpdirname, "analysis_config.json") |
406 | 406 | with open(analysis_config_file, "w") as f: |
407 | 407 | json.dump(analysis_config, f) |
| 408 | + s3_analysis_config_file = _upload_analysis_config( |
| 409 | + analysis_config_file, |
| 410 | + data_config.s3_output_path, |
| 411 | + self.sagemaker_session, |
| 412 | + kms_key, |
| 413 | + ) |
408 | 414 | config_input = ProcessingInput( |
409 | 415 | input_name="analysis_config", |
410 | | - source=analysis_config_file, |
| 416 | + source=s3_analysis_config_file, |
411 | 417 | destination=self._CLARIFY_CONFIG_INPUT, |
412 | 418 | s3_data_type="S3Prefix", |
413 | 419 | s3_input_mode="File", |
@@ -638,6 +644,30 @@ def run_explainability( |
638 | 644 | self._run(data_config, analysis_config, wait, logs, job_name, kms_key) |
639 | 645 |
|
640 | 646 |
|
| 647 | +def _upload_analysis_config(analysis_config_file, s3_output_path, sagemaker_session, kms_key): |
| 648 | + """Uploads the local analysis_config_file to the s3_output_path. |
| 649 | +
|
| 650 | + Args: |
| 651 | + analysis_config_file (str): File path to the local analysis config file. |
| 652 | + s3_output_path (str): S3 prefix to store the analysis config file. |
| 653 | + sagemaker_session (:class:`~sagemaker.session.Session`): |
| 654 | + Session object which manages interactions with Amazon SageMaker and |
| 655 | + any other AWS services needed. If not specified, the processor creates |
| 656 | + one using the default AWS configuration chain. |
| 657 | + kms_key (str): The ARN of the KMS key that is used to encrypt the |
| 658 | + user code file (default: None). |
| 659 | +
|
| 660 | + Returns: |
| 661 | + The S3 uri of the uploaded file. |
| 662 | + """ |
| 663 | + return s3.S3Uploader.upload( |
| 664 | + local_path=analysis_config_file, |
| 665 | + desired_s3_uri=s3_output_path, |
| 666 | + sagemaker_session=sagemaker_session, |
| 667 | + kms_key=kms_key, |
| 668 | + ) |
| 669 | + |
| 670 | + |
641 | 671 | def _set(value, key, dictionary): |
642 | 672 | """Sets dictionary[key] = value if value is not None.""" |
643 | 673 | if value is not None: |
|
0 commit comments