Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@ dependencies = [
"tblib>=1.7.0,<4",
"tqdm",
"urllib3>=1.26.8,<3.0.0",
"uvicorn"
"uvicorn",
"graphene>=3,<4"
]

[project.scripts]
Expand Down
68 changes: 65 additions & 3 deletions src/sagemaker/modules/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
from sagemaker_core.shapes import (
StoppingCondition,
RetryStrategy,
OutputDataConfig,
Channel,
ShuffleConfig,
DataSource,
Expand All @@ -43,8 +42,6 @@
RemoteDebugConfig,
SessionChainingConfig,
InstanceGroup,
TensorBoardOutputConfig,
CheckpointConfig,
)

from sagemaker.modules.utils import convert_unassigned_to_none
Expand Down Expand Up @@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig):
subsequent training jobs.
instance_groups (Optional[List[InstanceGroup]]):
A list of instance groups for heterogeneous clusters to be used in the training job.
training_plan_arn (Optional[str]):
The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
enable_managed_spot_training (Optional[bool]):
To train models using managed spot training, choose True. Managed spot training
provides a fully managed and scalable infrastructure for training machine learning
Expand Down Expand Up @@ -224,3 +223,66 @@ class InputData(BaseConfig):

channel_name: str = None
data_source: Union[str, FileSystemDataSource, S3DataSource] = None


class OutputDataConfig(shapes.OutputDataConfig):
    """Output data settings for a training job.

    Extends ``sagemaker_core.shapes.OutputDataConfig`` so that users can describe where
    and how the output of the training job is stored.

    Parameters:
        s3_output_path (Optional[str]):
            S3 URI under which the training job saves its output data, such as model
            artifacts and logs.
        kms_key_id (Optional[str]):
            The AWS Key Management Service (AWS KMS) key that SageMaker uses to encrypt
            the model artifacts at rest with Amazon S3 server-side encryption.
        compression_type (Optional[str]):
            Compression type of the model output. Choose `NONE` to emit an uncompressed
            model, which is recommended for large model outputs. Defaults to `GZIP`
            (the attribute itself stays ``None`` until it is set explicitly).
    """

    # Every field is optional and left unset (None) unless the caller provides a value.
    s3_output_path: Optional[str] = None
    kms_key_id: Optional[str] = None
    compression_type: Optional[str] = None


class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
    """Storage locations for Amazon SageMaker Debugger TensorBoard output.

    Extends ``sagemaker_core.shapes.TensorBoardOutputConfig`` so that users can specify
    where the TensorBoard data of a training job is kept.

    Parameters:
        s3_output_path (Optional[str]):
            Amazon S3 location for the TensorBoard output. If not specified, will
            default to
            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
        local_path (Optional[str]):
            Local storage location for the TensorBoard output. Defaults to
            ``/opt/ml/output/tensorboard``.
    """

    # Unset (None) unless the caller provides a value.
    s3_output_path: Optional[str] = None
    # Pre-populated with the local TensorBoard output directory.
    local_path: Optional[str] = "/opt/ml/output/tensorboard"


class CheckpointConfig(shapes.CheckpointConfig):
    """Checkpoint settings for a training job.

    Extends ``sagemaker_core.shapes.CheckpointConfig`` so that users can specify where
    the checkpoints of a training job are stored.

    Parameters:
        s3_uri (Optional[str]):
            Amazon S3 location for the checkpoint data. If not specified, will
            default to
            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
        local_path (Optional[str]):
            Local directory into which checkpoints are written. Defaults to
            ``/opt/ml/checkpoints``.
    """

    # Unset (None) unless the caller provides a value.
    s3_uri: Optional[str] = None
    # Pre-populated with the local checkpoint directory.
    local_path: Optional[str] = "/opt/ml/checkpoints"
Loading
Loading