Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from __future__ import absolute_import

from typing import Optional, Union
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, model_validator, ConfigDict

import sagemaker_core.shapes as shapes

Expand Down Expand Up @@ -94,6 +94,8 @@ class SourceCode(BaseModel):
If not specified, entry_script must be provided.
"""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

source_dir: Optional[str] = None
requirements: Optional[str] = None
entry_script: Optional[str] = None
Expand Down Expand Up @@ -215,5 +217,7 @@ class InputData(BaseModel):
S3DataSource object, or FileSystemDataSource object.
"""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

channel_name: str = None
data_source: Union[str, FileSystemDataSource, S3DataSource] = None
6 changes: 5 additions & 1 deletion src/sagemaker/modules/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import absolute_import

from typing import Optional, Dict, Any, List
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel, PrivateAttr, ConfigDict
from sagemaker.modules.utils import safe_serialize


Expand Down Expand Up @@ -53,6 +53,8 @@ class SMP(BaseModel):
parallelism or expert parallelism.
"""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

hybrid_shard_degree: Optional[int] = None
sm_activation_offloading: Optional[bool] = None
activation_loading_horizon: Optional[int] = None
Expand All @@ -75,6 +77,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
class DistributedConfig(BaseModel):
"""Base class for distributed training configurations."""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

_type: str = PrivateAttr()

def model_dump(self, *args, **kwargs):
Expand Down
22 changes: 16 additions & 6 deletions src/sagemaker/modules/train/model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
"LOCAL_CONTAINER" mode.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
)

training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
sagemaker_session: Optional[Session] = None
Expand Down Expand Up @@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):

def __del__(self):
    """Destructor method to clean up the temporary directory.

    The finalizer can run on an instance whose construction did not
    complete (for example when validation raised), so it must not assume
    that private attributes exist.
    """
    # ``__pydantic_fields_set__`` is used as the marker that the class was
    # initialized; without it ``_temp_recipe_train_dir`` may not be
    # available, so there is nothing to clean up.
    if not hasattr(self, "__pydantic_fields_set__"):
        return

    # Clean up the temporary directory if it exists
    if self._temp_recipe_train_dir is not None:
        self._temp_recipe_train_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
Expand Down Expand Up @@ -792,14 +795,14 @@ def _prepare_train_script(
"""Prepare the training script to be executed in the training job container.

Args:
source_code (SourceCodeConfig): The source code configuration.
source_code (SourceCode): The source code configuration.
"""

base_command = ""
if source_code.command:
if source_code.entry_script:
logger.warning(
"Both 'command' and 'entry_script' are provided in the SourceCodeConfig. "
"Both 'command' and 'entry_script' are provided in the SourceCode. "
+ "Defaulting to 'command'."
)
base_command = source_code.command.split()
Expand Down Expand Up @@ -831,6 +834,13 @@ def _prepare_train_script(
+ "Only .py and .sh scripts are supported."
)
execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
else:
# This should never be reached, as the source_code should have been validated.
raise ValueError(
f"Unsupported SourceCode or DistributedConfig: {source_code}, {distributed}."
+ "Please provide a valid configuration with atleast one of 'command'"
+ " or entry_script'."
)

train_script = TRAIN_SCRIPT_TEMPLATE.format(
working_dir=working_dir,
Expand Down
36 changes: 35 additions & 1 deletion tests/unit/sagemaker/modules/train/test_model_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import json
import os
import pytest
from pydantic import ValidationError
from unittest.mock import patch, MagicMock, ANY

from sagemaker import image_uris
Expand Down Expand Up @@ -438,7 +439,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
{
"source_code": DEFAULT_SOURCE_CODE,
"distributed": MPI(
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
mpi_additional_options=["-x", "VAR1", "-x", "VAR2"],
),
"expected_template": EXECUTE_MPI_DRIVER,
"expected_hyperparameters": {},
Expand Down Expand Up @@ -1059,3 +1060,36 @@ def mock_upload_data(path, bucket, key_prefix):
hyper_parameters=hyperparameters,
environment=environment,
)


def test_safe_configs():
    """Check that SourceCode rejects unknown fields and wrongly typed values.

    Covers both construction and attribute assignment, since the model is
    configured with ``validate_assignment=True`` and ``extra="forbid"``.
    """
    # Test extra fails
    with pytest.raises(ValidationError):
        SourceCode(entry_point="train.py")
    # Test invalid type fails
    with pytest.raises(ValidationError):
        SourceCode(entry_script=1)
    # Test invalid type fails on assignment too (validate_assignment=True)
    source_code = SourceCode(entry_script="train.py")
    with pytest.raises(ValidationError):
        source_code.entry_script = 1


@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
def test_destructor_cleanup(mock_tmp_dir, modules_session):
    """Check that the ModelTrainer destructor cleans up its temporary directory.

    First builds a ModelTrainer with an invalid ``compute`` value and checks
    that no cleanup is triggered when validation fails. Then builds a valid
    ModelTrainer, attaches a mocked temporary directory, and checks that
    deleting the trainer calls ``cleanup`` exactly once.
    """
    import gc

    # Construction raises, so there is no instance to bind to a name.
    with pytest.raises(ValidationError):
        ModelTrainer(
            training_image=DEFAULT_IMAGE,
            role=DEFAULT_ROLE,
            sagemaker_session=modules_session,
            compute="test",
        )
    mock_tmp_dir.cleanup.assert_not_called()

    model_trainer = ModelTrainer(
        training_image=DEFAULT_IMAGE,
        role=DEFAULT_ROLE,
        sagemaker_session=modules_session,
        compute=DEFAULT_COMPUTE_CONFIG,
    )
    model_trainer._temp_recipe_train_dir = mock_tmp_dir
    mock_tmp_dir.assert_not_called()
    del model_trainer
    # ``del`` only drops the reference; force a collection so the finalizer
    # has run even if the instance was part of a reference cycle or the
    # interpreter does not finalize immediately.
    gc.collect()
    mock_tmp_dir.cleanup.assert_called_once()
Loading