Skip to content

change: Improve defaults handling in ModelTrainer #5170

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 20 commits into from
May 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ dependencies = [
"tblib>=1.7.0,<4",
"tqdm",
"urllib3>=1.26.8,<3.0.0",
"uvicorn"
"uvicorn",
"graphene>=3,<4"
]

[project.scripts]
Expand Down
80 changes: 74 additions & 6 deletions src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from sagemaker_core.shapes import (
StoppingCondition,
RetryStrategy,
OutputDataConfig,
Channel,
ShuffleConfig,
DataSource,
Expand All @@ -43,8 +42,6 @@
RemoteDebugConfig,
SessionChainingConfig,
InstanceGroup,
TensorBoardOutputConfig,
CheckpointConfig,
)

from sagemaker.modules.utils import convert_unassigned_to_none
Expand Down Expand Up @@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig):
subsequent training jobs.
instance_groups (Optional[List[InstanceGroup]]):
A list of instance groups for heterogeneous clusters to be used in the training job.
training_plan_arn (Optional[str]):
The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
enable_managed_spot_training (Optional[bool]):
To train models using managed spot training, choose True. Managed spot training
provides a fully managed and scalable infrastructure for training machine learning
Expand All @@ -151,8 +150,12 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
compute_config_dict = self.model_dump()
resource_config_fields = set(shapes.ResourceConfig.__annotations__.keys())
filtered_dict = {
k: v for k, v in compute_config_dict.items() if k in resource_config_fields
k: v
for k, v in compute_config_dict.items()
if k in resource_config_fields and v is not None
}
if not filtered_dict:
return None
return shapes.ResourceConfig(**filtered_dict)


Expand Down Expand Up @@ -194,10 +197,12 @@ def _model_validator(self) -> "Networking":
def _to_vpc_config(self) -> shapes.VpcConfig:
"""Convert to a sagemaker_core.shapes.VpcConfig object."""
compute_config_dict = self.model_dump()
resource_config_fields = set(shapes.VpcConfig.__annotations__.keys())
vpc_config_fields = set(shapes.VpcConfig.__annotations__.keys())
filtered_dict = {
k: v for k, v in compute_config_dict.items() if k in resource_config_fields
k: v for k, v in compute_config_dict.items() if k in vpc_config_fields and v is not None
}
if not filtered_dict:
return None
return shapes.VpcConfig(**filtered_dict)


Expand All @@ -224,3 +229,66 @@ class InputData(BaseConfig):

channel_name: str = None
data_source: Union[str, FileSystemDataSource, S3DataSource] = None


class OutputDataConfig(shapes.OutputDataConfig):
    """Output data settings for a training job.

    Subclasses ``sagemaker_core.shapes.OutputDataConfig`` and redeclares its
    fields so that each one is optional and defaults to ``None``.

    Parameters:
        s3_output_path (Optional[str]):
            S3 URI under which the training job stores its output data, such as
            model artifacts and logs.
        kms_key_id (Optional[str]):
            The AWS Key Management Service (AWS KMS) key that SageMaker uses to
            encrypt the model artifacts at rest with Amazon S3 server-side
            encryption.
        compression_type (Optional[str]):
            Compression applied to the model output. Choose `NONE` to emit an
            uncompressed model, which is recommended for large model outputs.
            The documented default is `GZIP`; the field itself defaults to
            ``None`` on this class.
    """

    s3_output_path: Optional[str] = None
    kms_key_id: Optional[str] = None
    compression_type: Optional[str] = None


class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
    """Storage locations for Amazon SageMaker Debugger TensorBoard output.

    Subclasses ``sagemaker_core.shapes.TensorBoardOutputConfig`` and redeclares
    its fields with defaults, so that neither has to be supplied explicitly.

    Parameters:
        s3_output_path (Optional[str]):
            Amazon S3 location for the TensorBoard output. If not specified,
            will default to
            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
        local_path (Optional[str]):
            Local storage location for the TensorBoard output.
            Defaults to /opt/ml/output/tensorboard.
    """

    s3_output_path: Optional[str] = None
    local_path: Optional[str] = "/opt/ml/output/tensorboard"


class CheckpointConfig(shapes.CheckpointConfig):
    """Checkpoint settings for a training job.

    Subclasses ``sagemaker_core.shapes.CheckpointConfig`` and redeclares its
    fields with defaults, so that neither has to be supplied explicitly.

    Parameters:
        s3_uri (Optional[str]):
            Amazon S3 location for the checkpoint data. If not specified, will
            default to
            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
        local_path (Optional[str]):
            Local directory where checkpoints are written.
            Defaults to /opt/ml/checkpoints.
    """

    s3_uri: Optional[str] = None
    local_path: Optional[str] = "/opt/ml/checkpoints"
Loading