fix: forbid extras in Configs #5042

Merged · 7 commits · Feb 24, 2025
Changes from 3 commits
6 changes: 5 additions & 1 deletion src/sagemaker/modules/configs.py
@@ -22,7 +22,7 @@
from __future__ import absolute_import

from typing import Optional, Union
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, model_validator, ConfigDict

import sagemaker_core.shapes as shapes

@@ -94,6 +94,8 @@ class SourceCode(BaseModel):
If not specified, entry_script must be provided.
"""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

source_dir: Optional[str] = None
requirements: Optional[str] = None
entry_script: Optional[str] = None
@@ -215,5 +217,7 @@ class InputData(BaseModel):
S3DataSource object, or FileSystemDataSource object.
"""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

channel_name: str = None
data_source: Union[str, FileSystemDataSource, S3DataSource] = None
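
With extra="forbid" and validate_assignment=True in the ConfigDict, a misspelled keyword argument or an ill-typed assignment now raises a pydantic ValidationError instead of being silently ignored. A minimal sketch of the resulting behavior, assuming pydantic v2 and SourceCode as defined above:

from pydantic import ValidationError

from sagemaker.modules.configs import SourceCode

# Unknown fields are rejected at construction time.
try:
    SourceCode(entry_point="train.py")  # typo: the real field is entry_script
except ValidationError as err:
    print(err)  # "Extra inputs are not permitted"

# validate_assignment=True also re-validates fields set after construction.
code = SourceCode(entry_script="train.py")
try:
    code.entry_script = 1  # wrong type on assignment
except ValidationError as err:
    print(err)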
6 changes: 5 additions & 1 deletion src/sagemaker/modules/distributed.py
@@ -14,7 +14,7 @@
from __future__ import absolute_import

from typing import Optional, Dict, Any, List
from pydantic import BaseModel, PrivateAttr
from pydantic import BaseModel, PrivateAttr, ConfigDict
from sagemaker.modules.utils import safe_serialize


@@ -53,6 +53,8 @@ class SMP(BaseModel):
parallelism or expert parallelism.
"""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

hybrid_shard_degree: Optional[int] = None
sm_activation_offloading: Optional[bool] = None
activation_loading_horizon: Optional[int] = None
@@ -75,6 +77,8 @@ def _to_mp_hyperparameters(self) -> Dict[str, Any]:
class DistributedConfig(BaseModel):
"""Base class for distributed training configurations."""

model_config = ConfigDict(validate_assignment=True, extra="forbid")

_type: str = PrivateAttr()

def model_dump(self, *args, **kwargs):
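
The same ConfigDict is added to SMP and to the DistributedConfig base class, so subclasses such as MPI inherit it as well (assuming they do not override model_config). The updated test further below shows the practical effect: the former custom_mpi_options keyword was an extra field that pydantic used to drop silently; it now fails validation, and the test switches to the real mpi_additional_options field. A small sketch of that behavior:

from pydantic import ValidationError

from sagemaker.modules.distributed import MPI

try:
    MPI(custom_mpi_options=["-x", "VAR1"])  # unknown field -> ValidationError
except ValidationError as err:
    print(err)

mpi = MPI(mpi_additional_options=["-x", "VAR1", "-x", "VAR2"])  # accepted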
19 changes: 13 additions & 6 deletions src/sagemaker/modules/train/model_trainer.py
@@ -205,7 +205,9 @@ class ModelTrainer(BaseModel):
"LOCAL_CONTAINER" mode.
"""

model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid")
model_config = ConfigDict(
arbitrary_types_allowed=True, validate_assignment=True, extra="forbid"
)

training_mode: Mode = Mode.SAGEMAKER_TRAINING_JOB
sagemaker_session: Optional[Session] = None
@@ -363,9 +365,10 @@ def _populate_intelligent_defaults_from_model_trainer_space(self):

def __del__(self):
"""Destructor method to clean up the temporary directory."""
# Clean up the temporary directory if it exists
if self._temp_recipe_train_dir is not None:
self._temp_recipe_train_dir.cleanup()
# Clean up the temporary directory if it exists and class was initialized
if hasattr(self, "__pydantic_fields_set__"):
if self._temp_recipe_train_dir is not None:
self._temp_recipe_train_dir.cleanup()

def _validate_training_image_and_algorithm_name(
self, training_image: Optional[str], algorithm_name: Optional[str]
@@ -792,14 +795,14 @@ def _prepare_train_script(
"""Prepare the training script to be executed in the training job container.

Args:
source_code (SourceCodeConfig): The source code configuration.
source_code (SourceCode): The source code configuration.
"""

base_command = ""
if source_code.command:
if source_code.entry_script:
logger.warning(
"Both 'command' and 'entry_script' are provided in the SourceCodeConfig. "
"Both 'command' and 'entry_script' are provided in the SourceCode. "
+ "Defaulting to 'command'."
)
base_command = source_code.command.split()
@@ -831,6 +834,10 @@ def _prepare_train_script(
+ "Only .py and .sh scripts are supported."
)
execute_driver = EXECUTE_BASIC_SCRIPT_DRIVER
else:
raise ValueError(
f"Invalid configuration, please provide a valid SourceCode: {source_code}"
)

train_script = TRAIN_SCRIPT_TEMPLATE.format(
working_dir=working_dir,
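
Two behavioral changes in model_trainer.py are worth a note. First, _prepare_train_script now raises a ValueError when the SourceCode provides neither a usable command nor an entry_script, instead of silently emitting an empty training command. Second, the __del__ guard: with extra="forbid", ModelTrainer construction can now fail validation, and CPython may still invoke __del__ on the partially initialized instance; checking for __pydantic_fields_set__, which is only present once the model has been successfully initialized, avoids an AttributeError on _temp_recipe_train_dir during that teardown. A self-contained sketch of the pattern, with a hypothetical Example model standing in for ModelTrainer:

from pydantic import BaseModel, ConfigDict, PrivateAttr

class Example(BaseModel):
    model_config = ConfigDict(validate_assignment=True, extra="forbid")

    name: str
    _tmp: object = PrivateAttr(default=None)

    def __del__(self):
        # Only clean up if validation completed and the instance was fully built.
        if hasattr(self, "__pydantic_fields_set__"):
            if self._tmp is not None:
                self._tmp.cleanup()

try:
    Example(name="ok", bogus=1)  # extra field -> ValidationError; __del__ may still run
except Exception as err:
    print(err)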
36 changes: 35 additions & 1 deletion tests/unit/sagemaker/modules/train/test_model_trainer.py
@@ -18,6 +18,7 @@
import json
import os
import pytest
from pydantic import ValidationError
from unittest.mock import patch, MagicMock, ANY

from sagemaker import image_uris
@@ -438,7 +439,7 @@ def test_create_input_data_channel(mock_default_bucket, mock_upload_data, model_
{
"source_code": DEFAULT_SOURCE_CODE,
"distributed": MPI(
custom_mpi_options=["-x", "VAR1", "-x", "VAR2"],
mpi_additional_options=["-x", "VAR1", "-x", "VAR2"],
),
"expected_template": EXECUTE_MPI_DRIVER,
"expected_hyperparameters": {},
@@ -1059,3 +1060,36 @@ def mock_upload_data(path, bucket, key_prefix):
hyper_parameters=hyperparameters,
environment=environment,
)


def test_safe_configs():
# Test extra fails
with pytest.raises(ValueError):
SourceCode(entry_point="train.py")
# Test invalid type fails
with pytest.raises(ValueError):
SourceCode(entry_script=1)


@patch("sagemaker.modules.train.model_trainer.TemporaryDirectory")
def test_destructor_cleanup(mock_tmp_dir, modules_session):

with pytest.raises(ValidationError):
model_trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute="test",
)
mock_tmp_dir.cleanup.assert_not_called()

model_trainer = ModelTrainer(
training_image=DEFAULT_IMAGE,
role=DEFAULT_ROLE,
sagemaker_session=modules_session,
compute=DEFAULT_COMPUTE_CONFIG,
)
model_trainer._temp_recipe_train_dir = mock_tmp_dir
mock_tmp_dir.assert_not_called()
del model_trainer
mock_tmp_dir.cleanup.assert_called_once()