Skip to content

Commit 22a119d

Browse files
committed
Add default ignore_patterns; fix minor path issue when uploading to S3
1 parent c521d19 commit 22a119d

File tree

3 files changed

+18
-11
lines changed

3 files changed

+18
-11
lines changed

src/sagemaker/modules/configs.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,15 +97,15 @@ class SourceCode(BaseConfig):
9797
The command(s) to execute in the training job container. Example: "python my_script.py".
9898
If not specified, entry_script must be provided.
9999
ignore_patterns: (Optional[List[str]]) :
100-
The ignore patterns to ignore specific files/folders when uploading to S3. Example:
101-
['.env', '.git', 'data', '__pycache__'].
100+
The ignore patterns to ignore specific files/folders when uploading to S3. If not specified,
101+
defaults to: ['.env', '.git', '__pycache__', '.DS_Store'].
102102
"""
103103

104104
source_dir: Optional[str] = None
105105
requirements: Optional[str] = None
106106
entry_script: Optional[str] = None
107107
command: Optional[str] = None
108-
ignore_patterns: Optional[List[str]] = None
108+
ignore_patterns: Optional[List[str]] = [".env", ".git", "__pycache__", ".DS_Store"]
109109

110110

111111
class Compute(shapes.ResourceConfig):

src/sagemaker/modules/train/model_trainer.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ class ModelTrainer(BaseModel):
119119
from sagemaker.modules.train import ModelTrainer
120120
from sagemaker.modules.configs import SourceCode, Compute, InputData
121121
122-
ignore_patterns = ['.env', '.git', 'data', '__pycache__']
122+
ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
123123
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
124124
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
125125
model_trainer = ModelTrainer(
@@ -677,6 +677,7 @@ def train(
677677
channel_name=SM_DRIVERS,
678678
data_source=tmp_dir.name,
679679
key_prefix=input_data_key_prefix,
680+
ignore_patterns=self.source_code.ignore_patterns,
680681
)
681682
final_input_data_config.append(sm_drivers_channel)
682683

@@ -779,7 +780,7 @@ def create_input_data_channel(
779780
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
780781
ignore_patterns: (Optional[List[str]]) :
781782
The ignore patterns to ignore specific files/folders when uploading to S3.
782-
Example: ['.env', '.git', 'data', '__pycache__'].
783+
If not specified, defaults to: ['.env', '.git', '__pycache__', '.DS_Store'].
783784
"""
784785
channel = None
785786
if isinstance(data_source, str):
@@ -819,16 +820,19 @@ def create_input_data_channel(
819820
)
820821
if self.sagemaker_session.default_bucket_prefix:
821822
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
822-
if ignore_patterns:
823+
if ignore_patterns and _is_valid_path(data_source, path_type="Directory"):
823824
tmp_dir = TemporaryDirectory()
825+
copied_path = os.path.join(
826+
tmp_dir.name, os.path.basename(os.path.normpath(data_source))
827+
)
824828
shutil.copytree(
825829
data_source,
826-
os.path.join(tmp_dir.name, os.path.basename(data_source)),
830+
copied_path,
827831
dirs_exist_ok=True,
828832
ignore=shutil.ignore_patterns(*ignore_patterns),
829833
)
830834
s3_uri = self.sagemaker_session.upload_data(
831-
path=tmp_dir.name,
835+
path=copied_path,
832836
bucket=self.sagemaker_session.default_bucket(),
833837
key_prefix=key_prefix,
834838
)
@@ -884,7 +888,10 @@ def _get_input_data_config(
884888
channels.append(input_data)
885889
elif isinstance(input_data, InputData):
886890
channel = self.create_input_data_channel(
887-
input_data.channel_name, input_data.data_source, key_prefix=key_prefix
891+
input_data.channel_name,
892+
input_data.data_source,
893+
key_prefix=key_prefix,
894+
ignore_patterns=self.source_code.ignore_patterns,
888895
)
889896
channels.append(channel)
890897
else:

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def model_trainer():
208208
"source_code": SourceCode(
209209
source_dir=DEFAULT_SOURCE_DIR,
210210
command="python custom_script.py",
211-
ignore_patterns=["data"]
211+
ignore_patterns=["data"],
212212
),
213213
},
214214
"should_throw": False,
@@ -224,7 +224,7 @@ def model_trainer():
224224
"supported_source_code_local_tar_file",
225225
"supported_source_code_s3_dir",
226226
"supported_source_code_s3_tar_file",
227-
"supported_source_code_ignore_patterns"
227+
"supported_source_code_ignore_patterns",
228228
],
229229
)
230230
def test_model_trainer_param_validation(test_case, modules_session):

0 commit comments

Comments
 (0)