Skip to content

Commit c9981f9

Browse files
authored
Merge branch 'master' into master-distributed-config-extensible
2 parents 3ef2f55 + b116e2f commit c9981f9

File tree

7 files changed

+83
-51
lines changed

7 files changed

+83
-51
lines changed

doc/frameworks/pytorch/using_pytorch.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@ To train a PyTorch model by using the SageMaker Python SDK:
2828
Prepare a PyTorch Training Script
2929
=================================
3030

31-
Your PyTorch training script must be a Python 3.6 compatible source file.
32-
3331
Prepare your script in a separate source file than the notebook, terminal session, or source file you're
3432
using to submit the script to SageMaker via a ``PyTorch`` Estimator. This will be discussed in further detail below.
3533

src/sagemaker/modules/configs.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from __future__ import absolute_import
2323

2424
from typing import Optional, Union
25-
from pydantic import BaseModel, model_validator
25+
from pydantic import BaseModel, model_validator, ConfigDict
2626

2727
import sagemaker_core.shapes as shapes
2828

@@ -74,7 +74,13 @@
7474
]
7575

7676

77-
class SourceCode(BaseModel):
77+
class BaseConfig(BaseModel):
78+
"""BaseConfig"""
79+
80+
model_config = ConfigDict(validate_assignment=True, extra="forbid")
81+
82+
83+
class SourceCode(BaseConfig):
7884
"""SourceCode.
7985
8086
The SourceCode class allows the user to specify the source code location, dependencies,
@@ -194,7 +200,7 @@ def _to_vpc_config(self) -> shapes.VpcConfig:
194200
return shapes.VpcConfig(**filtered_dict)
195201

196202

197-
class InputData(BaseModel):
203+
class InputData(BaseConfig):
198204
"""InputData.
199205
200206
This config allows the user to specify an input data source for the training job.

src/sagemaker/modules/distributed.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@
1717

1818
from abc import ABC, abstractmethod
1919
from typing import Optional, Dict, Any, List
20-
from pydantic import BaseModel
20+
from pydantic import PrivateAttr
21+
2122
from sagemaker.modules.utils import safe_serialize
2223
from sagemaker.modules.constants import SM_DRIVERS_LOCAL_PATH
24+
from sagemaker.modules.configs import BaseConfig
2325

2426

25-
class SMP(BaseModel):
27+
class SMP(BaseConfig):
2628
"""SMP.
2729
2830
This class is used for configuring the SageMaker Model Parallelism v2 parameters.
@@ -76,7 +78,7 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
7678
return hyperparameters
7779

7880

79-
class DistributedConfig(BaseModel, ABC):
81+
class DistributedConfig(BaseConfig, ABC):
8082
"""Abstract base class for distributed training configurations.
8183
8284
This class defines the interface that all distributed training configurations

src/sagemaker/modules/train/model_trainer.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -204,7 +204,9 @@ class ModelTrainer(BaseModel):
204204
"LOCAL_CONTAINER" mode.
205205
"""
206206

207-
model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
207+
model_config = ConfigDict(
208+
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
209+
)
208210

209211
training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
210212
sagemaker_session: Optional[Session] = None
@@ -362,9 +364,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):
362364

363365
def __del__(self):
364366
"""Destructor method to clean up the temporary directory."""
365-
# Clean up the temporary directory if it exists
366-
if self._temp_recipe_train_dir is not None:
367-
self._temp_recipe_train_dir.cleanup()
367+
# Clean up the temporary directory if it exists and class was initialized
368+
if hasattr(self, "__pydantic_fields_set__"):
369+
if self._temp_recipe_train_dir is not None:
370+
self._temp_recipe_train_dir.cleanup()
368371

369372
def _validate_training_image_and_algorithm_name(
370373
self, training_image: Optional[str], algorithm_name: Optional[str]
@@ -796,14 +799,14 @@ def _prepare_train_script(
796799
"""Prepare the training script to be executed in the training job container.
797800
798801
Args:
799-
source_code (SourceCodeConfig): The source code configuration.
802+
source_code (SourceCode): The source code configuration.
800803
"""
801804

802805
base_command = ""
803806
if source_code.command:
804807
if source_code.entry_script:
805808
logger.warning(
806-
"Both 'command' and 'entry_script' are provided in the SourceCodeConfig. "
809+
"Both 'command' and 'entry_script' are provided in the SourceCode. "
807810
+ "Defaulting to 'command'."
808811
)
809812
base_command = source_code.command.split()
@@ -832,6 +835,13 @@ def _prepare_train_script(
832835
+ "Only .py and .sh scripts are supported."
833836
)
834837
execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
838+
else:
839+
# This should never be reached, as the source_code should have been validated.
840+
raise ValueError(
841+
f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}."
842+
+ "Please provide a valid configuration with atleast one of 'command'"
843+
+ " or entry_script'."
844+
)
835845

836846
train_script = TRAIN_SCRIPT_TEMPLATE.format(
837847
working_dir=working_dir,

src/sagemaker/serve/detector/dependency_manager.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,22 +34,34 @@ def capture_dependencies(dependencies: dict, work_dir: Path, capture_all: bool =
3434
"""Placeholder docstring"""
3535
path = work_dir.joinpath("requirements.txt")
3636
if "auto" in dependencies and dependencies["auto"]:
37+
import site
38+
39+
pkl_path = work_dir.joinpath(PKL_FILE_NAME)
40+
dest_path = path
41+
site_packages_dir = site.getsitepackages()[0]
42+
pickle_command_dir = "/sagemaker/serve/detector"
43+
3744
command = [
3845
sys.executable,
39-
Path(__file__).parent.joinpath("pickle_dependencies.py"),
40-
"--pkl_path",
41-
work_dir.joinpath(PKL_FILE_NAME),
42-
"--dest",
43-
path,
46+
"-c",
4447
]
4548

4649
if capture_all:
47-
command.append("--capture_all")
50+
command.append(
51+
f"from pickle_dependencies import get_all_requirements;"
52+
f'get_all_requirements("{dest_path}")'
53+
)
54+
else:
55+
command.append(
56+
f"from pickle_dependencies import get_requirements_for_pkl_file;"
57+
f'get_requirements_for_pkl_file("{pkl_path}", "{dest_path}")'
58+
)
4859

4960
subprocess.run(
5061
command,
5162
env={"SETUPTOOLS_USE_DISTUTILS": "stdlib"},
5263
check=True,
64+
cwd=site_packages_dir + pickle_command_dir,
5365
)
5466

5567
with open(path, "r") as f:

src/sagemaker/serve/detector/pickle_dependencies.py

Lines changed: 0 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import absolute_import
44
from pathlib import Path
55
from typing import List
6-
import argparse
76
import email.parser
87
import email.policy
98
import json
@@ -129,32 +128,3 @@ def get_all_requirements(dest: Path):
129128
version = package_info.get("version")
130129

131130
out.write(f"{name}=={version}\n")
132-
133-
134-
def parse_args():
135-
"""Placeholder docstring"""
136-
parser = argparse.ArgumentParser(
137-
prog="pkl_requirements", description="Generates a requirements.txt for a cloudpickle file"
138-
)
139-
parser.add_argument("--pkl_path", required=True, help="path of the pkl file")
140-
parser.add_argument("--dest", required=True, help="path of the destination requirements.txt")
141-
parser.add_argument(
142-
"--capture_all",
143-
action="store_true",
144-
help="capture all dependencies in current environment",
145-
)
146-
args = parser.parse_args()
147-
return (Path(args.pkl_path), Path(args.dest), args.capture_all)
148-
149-
150-
def main():
151-
"""Placeholder docstring"""
152-
pkl_path, dest, capture_all = parse_args()
153-
if capture_all:
154-
get_all_requirements(dest)
155-
else:
156-
get_requirements_for_pkl_file(pkl_path, dest)
157-
158-
159-
if __name__ == "__main__":
160-
main()

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import json
1919
import os
2020
import pytest
21+
from pydantic import ValidationError
2122
from unittest.mock import patch, MagicMock, ANY
2223

2324
from sagemaker import image_uris
@@ -442,7 +443,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
442443
{
443444
"source_code": DEFAULT_SOURCE_CODE,
444445
"distributed": MPI(
445-
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
446+
mpi_additional_options=["-x", "VAR1", "-x", "VAR2"],
446447
),
447448
"expected_template": EXEUCTE_DISTRIBUTED_DRIVER.format(
448449
driver_name="MPI", driver_script="mpi_driver.py"
@@ -1059,3 +1060,36 @@ def mock_upload_data(path, bucket, key_prefix):
10591060
hyper_parameters=hyperparameters,
10601061
environment=environment,
10611062
)
1063+
1064+
1065+
def test_safe_configs():
1066+
# Test extra fails
1067+
with pytest.raises(ValueError):
1068+
SourceCode(entry_point="train.py")
1069+
# Test invalid type fails
1070+
with pytest.raises(ValueError):
1071+
SourceCode(entry_script=1)
1072+
1073+
1074+
@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
1075+
def test_destructor_cleanup(mock_tmp_dir, modules_session):
1076+
1077+
with pytest.raises(ValidationError):
1078+
model_trainer = ModelTrainer(
1079+
training_image=DEFAULT_IMAGE,
1080+
role=DEFAULT_ROLE,
1081+
sagemaker_session=modules_session,
1082+
compute="test",
1083+
)
1084+
mock_tmp_dir.cleanup.assert_not_called()
1085+
1086+
model_trainer = ModelTrainer(
1087+
training_image=DEFAULT_IMAGE,
1088+
role=DEFAULT_ROLE,
1089+
sagemaker_session=modules_session,
1090+
compute=DEFAULT_COMPUTE_CONFIG,
1091+
)
1092+
model_trainer._temp_recipe_train_dir = mock_tmp_dir
1093+
mock_tmp_dir.assert_not_called()
1094+
del model_trainer
1095+
mock_tmp_dir.cleanup.assert_called_once()

0 commit comments

Comments
 (0)