3030from sagemaker .local import LocalSession
3131from sagemaker .utils import base_name_from_image , name_from_base
3232from sagemaker .session import Session
33- from sagemaker .network import NetworkConfig # noqa: F401 # pylint: disable=unused-import
3433from sagemaker .workflow .properties import Properties
3534from sagemaker .workflow .parameters import Parameter
3635from sagemaker .workflow .entities import Expression
@@ -219,14 +218,14 @@ def _normalize_args(
219218 """
220219 self ._current_job_name = self ._generate_current_job_name (job_name = job_name )
221220
222- inputs_with_code = self ._include_code_in_inputs (inputs , code )
221+ inputs_with_code = self ._include_code_in_inputs (inputs , code , kms_key )
223222 normalized_inputs = self ._normalize_inputs (inputs_with_code , kms_key )
224223 normalized_outputs = self ._normalize_outputs (outputs )
225224 self .arguments = arguments
226225
227226 return normalized_inputs , normalized_outputs
228227
229- def _include_code_in_inputs (self , inputs , _code ):
228+ def _include_code_in_inputs (self , inputs , _code , _kms_key ):
230229 """A no op in the base class to include code in the processing job inputs.
231230
232231 Args:
@@ -235,6 +234,8 @@ def _include_code_in_inputs(self, inputs, _code):
235234 :class:`~sagemaker.processing.ProcessingInput` objects.
236235 _code (str): This can be an S3 URI or a local path to a file with the framework
237236 script to run (default: None). A no op in the base class.
237+ kms_key (str): The ARN of the KMS key that is used to encrypt the
238+ user code file (default: None).
238239
239240 Returns:
240241 list[:class:`~sagemaker.processing.ProcessingInput`]: inputs
@@ -528,7 +529,7 @@ def run(
528529 if wait :
529530 self .latest_job .wait (logs = logs )
530531
531- def _include_code_in_inputs (self , inputs , code ):
532+ def _include_code_in_inputs (self , inputs , code , kms_key = None ):
532533 """Converts code to appropriate input and includes in input list.
533534
534535 Side effects include:
@@ -541,12 +542,14 @@ def _include_code_in_inputs(self, inputs, code):
541542 :class:`~sagemaker.processing.ProcessingInput` objects.
542543 code (str): This can be an S3 URI or a local path to a file with the framework
543544 script to run (default: None).
545+ kms_key (str): The ARN of the KMS key that is used to encrypt the
546+ user code file (default: None).
544547
545548 Returns:
546549 list[:class:`~sagemaker.processing.ProcessingInput`]: inputs together with the
547550 code as `ProcessingInput`.
548551 """
549- user_code_s3_uri = self ._handle_user_code_url (code )
552+ user_code_s3_uri = self ._handle_user_code_url (code , kms_key )
550553 user_script_name = self ._get_user_code_name (code )
551554
552555 inputs_with_code = self ._convert_code_and_add_to_inputs (inputs , user_code_s3_uri )
@@ -567,14 +570,16 @@ def _get_user_code_name(self, code):
567570 code_url = urlparse (code )
568571 return os .path .basename (code_url .path )
569572
570- def _handle_user_code_url (self , code ):
573+ def _handle_user_code_url (self , code , kms_key = None ):
571574 """Gets the S3 URL containing the user's code.
572575
573576 Inspects the scheme the customer passed in ("s3://" for code in S3, "file://" or nothing
574577 for absolute or local file paths. Uploads the code to S3 if the code is a local file.
575578
576579 Args:
577580 code (str): A URL to the customer's code.
581+ kms_key (str): The ARN of the KMS key that is used to encrypt the
582+ user code file (default: None).
578583
579584 Returns:
580585 str: The S3 URL to the customer's code.
@@ -603,7 +608,7 @@ def _handle_user_code_url(self, code):
603608 code
604609 )
605610 )
606- user_code_s3_uri = self ._upload_code (code_path )
611+ user_code_s3_uri = self ._upload_code (code_path , kms_key )
607612 else :
608613 raise ValueError (
609614 "code {} url scheme {} is not recognized. Please pass a file path or S3 url" .format (
@@ -612,11 +617,13 @@ def _handle_user_code_url(self, code):
612617 )
613618 return user_code_s3_uri
614619
615- def _upload_code (self , code ):
620+ def _upload_code (self , code , kms_key = None ):
616621 """Uploads a code file or directory specified as a string and returns the S3 URI.
617622
618623 Args:
619624 code (str): A file or directory to be uploaded to S3.
625+ kms_key (str): The ARN of the KMS key that is used to encrypt the
626+ user code file (default: None).
620627
621628 Returns:
622629 str: The S3 URI of the uploaded file or directory.
@@ -630,7 +637,10 @@ def _upload_code(self, code):
630637 self ._CODE_CONTAINER_INPUT_NAME ,
631638 )
632639 return s3 .S3Uploader .upload (
633- local_path = code , desired_s3_uri = desired_s3_uri , sagemaker_session = self .sagemaker_session
640+ local_path = code ,
641+ desired_s3_uri = desired_s3_uri ,
642+ kms_key = kms_key ,
643+ sagemaker_session = self .sagemaker_session ,
634644 )
635645
636646 def _convert_code_and_add_to_inputs (self , inputs , s3_uri ):
@@ -666,7 +676,9 @@ def _set_entrypoint(self, command, user_script_name):
666676 """
667677 user_script_location = str (
668678 pathlib .PurePosixPath (
669- self ._CODE_CONTAINER_BASE_PATH , self ._CODE_CONTAINER_INPUT_NAME , user_script_name
679+ self ._CODE_CONTAINER_BASE_PATH ,
680+ self ._CODE_CONTAINER_INPUT_NAME ,
681+ user_script_name ,
670682 )
671683 )
672684 self .entrypoint = command + [user_script_location ]
@@ -1066,7 +1078,10 @@ def _to_request_dict(self):
10661078 """Generates a request dictionary using the parameters provided to the class."""
10671079
10681080 # Create the request dictionary.
1069- s3_input_request = {"InputName" : self .input_name , "AppManaged" : self .app_managed }
1081+ s3_input_request = {
1082+ "InputName" : self .input_name ,
1083+ "AppManaged" : self .app_managed ,
1084+ }
10701085
10711086 if self .s3_input :
10721087 # Check the compression type, then add it to the dictionary.
0 commit comments