@@ -119,7 +119,8 @@ class ModelTrainer(BaseModel):
119119 from sagemaker.modules.train import ModelTrainer
120120 from sagemaker.modules.configs import SourceCode, Compute, InputData
121121
122- source_code = SourceCode(source_dir="source", entry_script="train.py")
122+ ignore_patterns = ['.env', '.git', 'data', '__pycache__']
123+ source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
123124 training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
124125 model_trainer = ModelTrainer(
125126 training_image=training_image,
@@ -654,6 +655,7 @@ def train(
654655 channel_name = SM_CODE ,
655656 data_source = self .source_code .source_dir ,
656657 key_prefix = input_data_key_prefix ,
658+ ignore_patterns = self .source_code .ignore_patterns ,
657659 )
658660 final_input_data_config .append (source_code_channel )
659661
@@ -755,7 +757,11 @@ def train(
755757 local_container .train (wait )
756758
757759 def create_input_data_channel (
758- self , channel_name : str , data_source : DataSourceType , key_prefix : Optional [str ] = None
760+ self ,
761+ channel_name : str ,
762+ data_source : DataSourceType ,
763+ key_prefix : Optional [str ] = None ,
764+ ignore_patterns : Optional [List [str ]] = None ,
759765 ) -> Channel :
760766 """Create an input data channel for the training job.
761767
@@ -771,6 +777,9 @@ def create_input_data_channel(
771777
772778 If specified, local data will be uploaded to:
773779 ``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
780+ ignore_patterns (Optional[List[str]]):
781+ Glob-style patterns for files/folders to exclude when uploading local data to S3.
782+ Example: ['.env', '.git', 'data', '__pycache__'].
774783 """
775784 channel = None
776785 if isinstance (data_source , str ):
@@ -810,11 +819,25 @@ def create_input_data_channel(
810819 )
811820 if self .sagemaker_session .default_bucket_prefix :
812821 key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
813- s3_uri = self .sagemaker_session .upload_data (
814- path = data_source ,
815- bucket = self .sagemaker_session .default_bucket (),
816- key_prefix = key_prefix ,
817- )
822+ if ignore_patterns :
823+ tmp_dir = TemporaryDirectory ()
824+ shutil .copytree (
825+ data_source ,
826+ os .path .join (tmp_dir .name , os .path .basename (data_source )),
827+ dirs_exist_ok = True ,
828+ ignore = shutil .ignore_patterns (* ignore_patterns )
829+ )
830+ s3_uri = self .sagemaker_session .upload_data (
831+ path = tmp_dir .name ,
832+ bucket = self .sagemaker_session .default_bucket (),
833+ key_prefix = key_prefix ,
834+ )
835+ else :
836+ s3_uri = self .sagemaker_session .upload_data (
837+ path = data_source ,
838+ bucket = self .sagemaker_session .default_bucket (),
839+ key_prefix = key_prefix ,
840+ )
818841 channel = Channel (
819842 channel_name = channel_name ,
820843 data_source = DataSource (
0 commit comments