Skip to content

Commit 4b11a35

Browse files
committed
Add ignore_patterns in ModelTrainer to ignore specific files/folders
1 parent baf1601 commit 4b11a35

File tree

2 files changed

+35
-8
lines changed

2 files changed

+35
-8
lines changed

src/sagemaker/modules/configs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import absolute_import
2323

24-
from typing import Optional, Union
24+
from typing import Optional, Union, List
2525
from pydantic import BaseModel, model_validator, ConfigDict
2626

2727
import sagemaker_core.shapes as shapes
@@ -96,12 +96,16 @@ class SourceCode(BaseConfig):
9696
command (Optional[str]):
9797
The command(s) to execute in the training job container. Example: "python my_script.py".
9898
If not specified, entry_script must be provided.
99+
ignore_patterns: (Optional[List[str]]) :
100+
The ignore patterns to ignore specific files/folders when uploading to S3. Example:
101+
['.env', '.git', 'data', '__pycache__'].
99102
"""
100103

101104
source_dir: Optional[str] = None
102105
requirements: Optional[str] = None
103106
entry_script: Optional[str] = None
104107
command: Optional[str] = None
108+
ignore_patterns: Optional[List[str]] = None
105109

106110

107111
class Compute(shapes.ResourceConfig):

src/sagemaker/modules/train/model_trainer.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ class ModelTrainer(BaseModel):
119119
from sagemaker.modules.train import ModelTrainer
120120
from sagemaker.modules.configs import SourceCode, Compute, InputData
121121
122-
source_code = SourceCode(source_dir="source", entry_script="train.py")
122+
ignore_patterns = ['.env', '.git', 'data', '__pycache__']
123+
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
123124
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
124125
model_trainer = ModelTrainer(
125126
training_image=training_image,
@@ -654,6 +655,7 @@ def train(
654655
channel_name=SM_CODE,
655656
data_source=self.source_code.source_dir,
656657
key_prefix=input_data_key_prefix,
658+
ignore_patterns=self.source_code.ignore_patterns,
657659
)
658660
final_input_data_config.append(source_code_channel)
659661

@@ -755,7 +757,11 @@ def train(
755757
local_container.train(wait)
756758

757759
def create_input_data_channel(
758-
self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None
760+
self,
761+
channel_name: str,
762+
data_source: DataSourceType,
763+
key_prefix: Optional[str] = None,
764+
ignore_patterns: Optional[List[str]] = None,
759765
) -> Channel:
760766
"""Create an input data channel for the training job.
761767
@@ -771,6 +777,9 @@ def create_input_data_channel(
771777
772778
If specified, local data will be uploaded to:
773779
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
780+
ignore_patterns: (Optional[List[str]]) :
781+
The ignore patterns to ignore specific files/folders when uploading to S3.
782+
Example: ['.env', '.git', 'data', '__pycache__'].
774783
"""
775784
channel = None
776785
if isinstance(data_source, str):
@@ -810,11 +819,25 @@ def create_input_data_channel(
810819
)
811820
if self.sagemaker_session.default_bucket_prefix:
812821
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
813-
s3_uri = self.sagemaker_session.upload_data(
814-
path=data_source,
815-
bucket=self.sagemaker_session.default_bucket(),
816-
key_prefix=key_prefix,
817-
)
822+
if ignore_patterns:
823+
tmp_dir = TemporaryDirectory()
824+
shutil.copytree(
825+
data_source,
826+
os.path.join(tmp_dir.name, os.path.basename(data_source)),
827+
dirs_exist_ok=True,
828+
ignore=shutil.ignore_patterns(*ignore_patterns)
829+
)
830+
s3_uri = self.sagemaker_session.upload_data(
831+
path=tmp_dir.name,
832+
bucket=self.sagemaker_session.default_bucket(),
833+
key_prefix=key_prefix,
834+
)
835+
else:
836+
s3_uri = self.sagemaker_session.upload_data(
837+
path=data_source,
838+
bucket=self.sagemaker_session.default_bucket(),
839+
key_prefix=key_prefix,
840+
)
818841
channel = Channel(
819842
channel_name=channel_name,
820843
data_source=DataSource(

0 commit comments

Comments
 (0)