Skip to content

Commit da684b8

Browse files
authored
Merge branch 'aws:master' into numpy-upgrade
2 parents 98719d7 + 844b558 commit da684b8

File tree

7 files changed

+71
-18
lines changed

7 files changed

+71
-18
lines changed

src/sagemaker/modules/configs.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from __future__ import absolute_import
2323

24-
from typing import Optional, Union
24+
from typing import Optional, Union, List
2525
from pydantic import BaseModel, model_validator, ConfigDict
2626

2727
import sagemaker_core.shapes as shapes
@@ -96,12 +96,23 @@ class SourceCode(BaseConfig):
9696
command (Optional[str]):
9797
The command(s) to execute in the training job container. Example: "python my_script.py".
9898
If not specified, entry_script must be provided.
99+
ignore_patterns (Optional[List[str]]):
100+
Patterns of files/folders to exclude when uploading the source code to S3. If not specified,
101+
defaults to: ['.env', '.git', '__pycache__', '.DS_Store', '.cache', '.ipynb_checkpoints'].
99102
"""
100103

101104
source_dir: Optional[str] = None
102105
requirements: Optional[str] = None
103106
entry_script: Optional[str] = None
104107
command: Optional[str] = None
108+
ignore_patterns: Optional[List[str]] = [
109+
".env",
110+
".git",
111+
"__pycache__",
112+
".DS_Store",
113+
".cache",
114+
".ipynb_checkpoints",
115+
]
105116

106117

107118
class Compute(shapes.ResourceConfig):

src/sagemaker/modules/train/model_trainer.py

Lines changed: 38 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,8 @@ class ModelTrainer(BaseModel):
119119
from sagemaker.modules.train import ModelTrainer
120120
from sagemaker.modules.configs import SourceCode, Compute, InputData
121121
122-
source_code = SourceCode(source_dir="source", entry_script="train.py")
122+
ignore_patterns = ['.env', '.git', '__pycache__', '.DS_Store', 'data']
123+
source_code = SourceCode(source_dir="source", entry_script="train.py", ignore_patterns=ignore_patterns)
123124
training_image = "123456789012.dkr.ecr.us-west-2.amazonaws.com/my-training-image"
124125
model_trainer = ModelTrainer(
125126
training_image=training_image,
@@ -654,6 +655,7 @@ def train(
654655
channel_name=SM_CODE,
655656
data_source=self.source_code.source_dir,
656657
key_prefix=input_data_key_prefix,
658+
ignore_patterns=self.source_code.ignore_patterns,
657659
)
658660
final_input_data_config.append(source_code_channel)
659661

@@ -675,6 +677,7 @@ def train(
675677
channel_name=SM_DRIVERS,
676678
data_source=tmp_dir.name,
677679
key_prefix=input_data_key_prefix,
680+
ignore_patterns=self.source_code.ignore_patterns,
678681
)
679682
final_input_data_config.append(sm_drivers_channel)
680683

@@ -755,7 +758,11 @@ def train(
755758
local_container.train(wait)
756759

757760
def create_input_data_channel(
758-
self, channel_name: str, data_source: DataSourceType, key_prefix: Optional[str] = None
761+
self,
762+
channel_name: str,
763+
data_source: DataSourceType,
764+
key_prefix: Optional[str] = None,
765+
ignore_patterns: Optional[List[str]] = None,
759766
) -> Channel:
760767
"""Create an input data channel for the training job.
761768
@@ -771,6 +778,10 @@ def create_input_data_channel(
771778
772779
If specified, local data will be uploaded to:
773780
``s3://<default_bucket_path>/<key_prefix>/<channel_name>/``
781+
ignore_patterns (Optional[List[str]]):
782+
Patterns of files/folders to exclude when uploading a local directory to S3.
783+
If not specified (``None``) or empty, nothing is excluded and ``data_source``
784+
is uploaded unchanged.
774785
"""
775786
channel = None
776787
if isinstance(data_source, str):
@@ -810,11 +821,28 @@ def create_input_data_channel(
810821
)
811822
if self.sagemaker_session.default_bucket_prefix:
812823
key_prefix = f"{self.sagemaker_session.default_bucket_prefix}/{key_prefix}"
813-
s3_uri = self.sagemaker_session.upload_data(
814-
path=data_source,
815-
bucket=self.sagemaker_session.default_bucket(),
816-
key_prefix=key_prefix,
817-
)
824+
if ignore_patterns and _is_valid_path(data_source, path_type="Directory"):
825+
tmp_dir = TemporaryDirectory()
826+
copied_path = os.path.join(
827+
tmp_dir.name, os.path.basename(os.path.normpath(data_source))
828+
)
829+
shutil.copytree(
830+
data_source,
831+
copied_path,
832+
dirs_exist_ok=True,
833+
ignore=shutil.ignore_patterns(*ignore_patterns),
834+
)
835+
s3_uri = self.sagemaker_session.upload_data(
836+
path=copied_path,
837+
bucket=self.sagemaker_session.default_bucket(),
838+
key_prefix=key_prefix,
839+
)
840+
else:
841+
s3_uri = self.sagemaker_session.upload_data(
842+
path=data_source,
843+
bucket=self.sagemaker_session.default_bucket(),
844+
key_prefix=key_prefix,
845+
)
818846
channel = Channel(
819847
channel_name=channel_name,
820848
data_source=DataSource(
@@ -861,7 +889,9 @@ def _get_input_data_config(
861889
channels.append(input_data)
862890
elif isinstance(input_data, InputData):
863891
channel = self.create_input_data_channel(
864-
input_data.channel_name, input_data.data_source, key_prefix=key_prefix
892+
input_data.channel_name,
893+
input_data.data_source,
894+
key_prefix=key_prefix,
865895
)
866896
channels.append(channel)
867897
else:

src/sagemaker/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7509,7 +7509,7 @@ def get_model_package_args(
75097509
if source_uri is not None:
75107510
model_package_args["source_uri"] = source_uri
75117511
if model_life_cycle is not None:
7512-
model_package_args["model_life_cycle"] = model_life_cycle
7512+
model_package_args["model_life_cycle"] = model_life_cycle._to_request_dict()
75137513
if model_card is not None:
75147514
original_req = model_card._create_request_args()
75157515
if original_req.get("ModelCardName") is not None:

tests/integ/sagemaker/workflow/test_model_create_and_registration.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
from sagemaker.s3 import S3Uploader
4949
from sagemaker.sklearn import SKLearnModel, SKLearnProcessor
5050
from sagemaker.mxnet.model import MXNetModel
51+
from sagemaker.model_life_cycle import ModelLifeCycle
5152
from sagemaker.workflow.condition_step import ConditionStep
5253
from sagemaker.workflow.parameters import ParameterInteger, ParameterString
5354
from sagemaker.workflow.pipeline import Pipeline
@@ -1005,11 +1006,11 @@ def test_model_registration_with_model_life_cycle_object(
10051006
py_version="py3",
10061007
role=role,
10071008
)
1008-
create_model_life_cycle = {
1009-
"Stage": "Development",
1010-
"StageStatus": "In-Progress",
1011-
"StageDescription": "Development In Progress",
1012-
}
1009+
create_model_life_cycle = ModelLifeCycle(
1010+
stage="Development",
1011+
stage_status="In-Progress",
1012+
stage_description="Development In Progress",
1013+
)
10131014

10141015
step_register = RegisterModel(
10151016
name="MyRegisterModelStep",

tests/integ/test_model_package.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def test_update_model_life_cycle_model_package(sagemaker_session):
103103
inference_instances=["ml.m5.large"],
104104
transform_instances=["ml.m5.large"],
105105
model_package_group_name=model_group_name,
106-
model_life_cycle=create_model_life_cycle._to_request_dict(),
106+
model_life_cycle=create_model_life_cycle,
107107
)
108108

109109
desc_model_package = sagemaker_session.sagemaker_client.describe_model_package(

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,17 @@ def model_trainer():
202202
},
203203
"should_throw": False,
204204
},
205+
{
206+
"init_params": {
207+
"training_image": DEFAULT_IMAGE,
208+
"source_code": SourceCode(
209+
source_dir=DEFAULT_SOURCE_DIR,
210+
command="python custom_script.py",
211+
ignore_patterns=["data"],
212+
),
213+
},
214+
"should_throw": False,
215+
},
205216
],
206217
ids=[
207218
"no_params",
@@ -213,6 +224,7 @@ def model_trainer():
213224
"supported_source_code_local_tar_file",
214225
"supported_source_code_s3_dir",
215226
"supported_source_code_s3_tar_file",
227+
"supported_source_code_ignore_patterns",
216228
],
217229
)
218230
def test_model_trainer_param_validation(test_case, modules_session):

tests/unit/test_estimator.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4369,7 +4369,6 @@ def test_register_default_image(sagemaker_session):
43694369
stage_status="In-Progress",
43704370
stage_description="Sending for Staging Verification",
43714371
)
4372-
update_model_life_cycle_req = update_model_life_cycle._to_request_dict()
43734372

43744373
estimator.register(
43754374
content_types=content_types,
@@ -4384,7 +4383,7 @@ def test_register_default_image(sagemaker_session):
43844383
nearest_model_name=nearest_model_name,
43854384
data_input_configuration=data_input_config,
43864385
model_card=model_card,
4387-
model_life_cycle=update_model_life_cycle_req,
4386+
model_life_cycle=update_model_life_cycle,
43884387
)
43894388
sagemaker_session.create_model.assert_not_called()
43904389
exp_model_card = {

0 commit comments

Comments
 (0)