@@ -119,7 +119,7 @@ class ModelTrainer(BaseModel):
119
119
from sagemaker.modules.train import ModelTrainer
120
120
from sagemaker.modules.configs import SourceCode, Compute, InputData
121
121
122
- ignore_patterns = ['.env', '.git', 'data ', '__pycache__ ']
122
+ ignore_patterns = ['.env', '.git', '__pycache__ ', '.DS_Store', 'data ']
123
123
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
124
124
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
125
125
model_trainer = ModelTrainer(
@@ -677,6 +677,7 @@ def train(
677
677
channel_name = SM_DRIVERS ,
678
678
data_source = tmp_dir .name ,
679
679
key_prefix = input_data_key_prefix ,
680
+ ignore_patterns = self .source_code .ignore_patterns ,
680
681
)
681
682
final_input_data_config .append (sm_drivers_channel )
682
683
@@ -779,7 +780,7 @@ def create_input_data_channel(
779
780
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
780
781
ignore_patterns: (Optional[List[str]]) :
781
782
The ignore patterns to ignore specific files/folders when uploading to S3.
782
- Example : ['.env', '.git', 'data ', '__pycache__ '].
783
+ If not specified, default to : ['.env', '.git', '__pycache__ ', '.DS_Store '].
783
784
"""
784
785
channel = None
785
786
if isinstance (data_source , str ):
@@ -819,16 +820,19 @@ def create_input_data_channel(
819
820
)
820
821
if self .sagemaker_session .default_bucket_prefix :
821
822
key_prefix = f"{ self .sagemaker_session .default_bucket_prefix } /{ key_prefix } "
822
- if ignore_patterns :
823
+ if ignore_patterns and _is_valid_path ( data_source , path_type = "Directory" ) :
823
824
tmp_dir = TemporaryDirectory ()
825
+ copied_path = os .path .join (
826
+ tmp_dir .name , os .path .basename (os .path .normpath (data_source ))
827
+ )
824
828
shutil .copytree (
825
829
data_source ,
826
- os . path . join ( tmp_dir . name , os . path . basename ( data_source )) ,
830
+ copied_path ,
827
831
dirs_exist_ok = True ,
828
832
ignore = shutil .ignore_patterns (* ignore_patterns ),
829
833
)
830
834
s3_uri = self .sagemaker_session .upload_data (
831
- path = tmp_dir . name ,
835
+ path = copied_path ,
832
836
bucket = self .sagemaker_session .default_bucket (),
833
837
key_prefix = key_prefix ,
834
838
)
@@ -884,7 +888,10 @@ def _get_input_data_config(
884
888
channels .append (input_data )
885
889
elif isinstance (input_data , InputData ):
886
890
channel = self .create_input_data_channel (
887
- input_data .channel_name , input_data .data_source , key_prefix = key_prefix
891
+ input_data .channel_name ,
892
+ input_data .data_source ,
893
+ key_prefix = key_prefix ,
894
+ ignore_patterns = self .source_code .ignore_patterns ,
888
895
)
889
896
channels .append (channel )
890
897
else :
0 commit comments