Skip to content

Commit 405dfd9

Browse files
committed
add tests & update docs
1 parent e265d78 commit 405dfd9

File tree

4 files changed

+161
-25
lines changed

4 files changed

+161
-25
lines changed

pyproject.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,8 @@ dependencies = [
5555
"tblib>=1.7.0,<4",
5656
"tqdm",
5757
"urllib3>=1.26.8,<3.0.0",
58-
"uvicorn"
58+
"uvicorn",
59+
"graphene>=3,<4"
5960
]
6061

6162
[project.scripts]

src/sagemaker/modules/configs.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -258,8 +258,9 @@ class TensorBoardOutputConfig(shapes.TensorBoardOutputConfig):
258258
259259
Parameters:
260260
s3_output_path (Optional[str]):
261-
Path to Amazon S3 storage location for TensorBoard output. If not specified, will default to the default artifact location for the training job.
262-
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/``
261+
Path to Amazon S3 storage location for TensorBoard output. If not specified, will
262+
default to
263+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
263264
local_path (Optional[str]):
264265
Path to local storage location for TensorBoard output. Defaults to /opt/ml/output/tensorboard.
265266
"""
@@ -276,8 +277,9 @@ class CheckpointConfig(shapes.CheckpointConfig):
276277
277278
Parameters:
278279
s3_uri (Optional[str]):
279-
Path to Amazon S3 storage location for the Checkpoint data. If not specified, will default to the default artifact location for the training job.
280-
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/``
280+
Path to Amazon S3 storage location for the Checkpoint data. If not specified, will
281+
default to
282+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
281283
local_path (Optional[str]):
282284
The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
283285
"""

src/sagemaker/modules/train/model_trainer.py

Lines changed: 119 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ class ModelTrainer(BaseModel):
181181
The output data configuration. This is used to specify the output data location
182182
for the training job.
183183
If not specified in the session, will default to
184-
s3://<default_bucket>/<default_prefix>/<base_job_name>/.
184+
``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
185185
input_data_config (Optional[List[Union[Channel, InputData]]]):
186186
The input data config for the training job.
187187
Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
@@ -595,8 +595,7 @@ def train(
595595
"""
596596
self._populate_intelligent_defaults()
597597
current_training_job_name = _get_unique_name(self.base_job_name)
598-
default_artifact_path = f"{self.base_job_name}/{current_training_job_name}"
599-
input_data_key_prefix = f"{default_artifact_path}/input"
598+
input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
600599
if input_data_config and self.input_data_config:
601600
self.input_data_config = input_data_config
602601
# Add missing input data channels to the existing input_data_config
@@ -613,9 +612,15 @@ def train(
613612
)
614613

615614
if self.checkpoint_config and not self.checkpoint_config.s3_uri:
616-
self.checkpoint_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
617-
if self._tensorboard_output_config and not self._tensorboard_output_config.s3_uri:
618-
self._tensorboard_output_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
615+
self.checkpoint_config.s3_uri = (
616+
f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/"
617+
f"{self.base_job_name}/{current_training_job_name}/checkpoints"
618+
)
619+
if self._tensorboard_output_config and not self._tensorboard_output_config.s3_output_path:
620+
self._tensorboard_output_config.s3_output_path = (
621+
f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/"
622+
f"{self.base_job_name}"
623+
)
619624

620625
string_hyper_parameters = {}
621626
if self.hyperparameters:
@@ -646,7 +651,7 @@ def train(
646651
data_source=self.source_code.source_dir,
647652
key_prefix=input_data_key_prefix,
648653
)
649-
input_data_config.append(source_code_channel)
654+
self.input_data_config.append(source_code_channel)
650655

651656
self._prepare_train_script(
652657
tmp_dir=tmp_dir,
@@ -667,7 +672,7 @@ def train(
667672
data_source=tmp_dir.name,
668673
key_prefix=input_data_key_prefix,
669674
)
670-
input_data_config.append(sm_drivers_channel)
675+
self.input_data_config.append(sm_drivers_channel)
671676

672677
# If source_code is provided, we will always use
673678
# the default container entrypoint and arguments
@@ -970,10 +975,43 @@ def from_recipe(
970975
) -> "ModelTrainer":
971976
"""Create a ModelTrainer from a training recipe.
972977
978+
Example:
979+
980+
.. code:: python
981+
982+
from sagemaker.modules.train import ModelTrainer
983+
from sagemaker.modules.configs import Compute
984+
985+
recipe_overrides = {
986+
"run": {
987+
"results_dir": "/opt/ml/model",
988+
},
989+
"model": {
990+
"data": {
991+
"use_synthetic_data": True
992+
}
993+
}
994+
}
995+
996+
compute = Compute(
997+
instance_type="ml.p5.48xlarge",
998+
keep_alive_period_in_seconds=3600
999+
)
1000+
1001+
model_trainer = ModelTrainer.from_recipe(
1002+
training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
1003+
recipe_overrides=recipe_overrides,
1004+
compute=compute,
1005+
)
1006+
1007+
model_trainer.train(wait=False)
1008+
1009+
9731010
Args:
9741011
training_recipe (str):
9751012
The training recipe to use for training the model. This must be the name of
9761013
a sagemaker training recipe or a path to a local training recipe .yaml file.
1014+
For available training recipes, see: https://github.com/aws/sagemaker-hyperpod-recipes/
9771015
compute (Compute):
9781016
The compute configuration. This is used to specify the compute resources for
9791017
the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
@@ -1081,55 +1119,116 @@ def from_recipe(
10811119
return model_trainer
10821120

10831121
def with_tensorboard_output_config(
1084-
self, tensorboard_output_config: TensorBoardOutputConfig
1122+
self, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None
10851123
) -> "ModelTrainer":
10861124
"""Set the TensorBoard output configuration.
10871125
1126+
Example:
1127+
1128+
.. code:: python
1129+
1130+
from sagemaker.modules.train import ModelTrainer
1131+
1132+
model_trainer = ModelTrainer(
1133+
...
1134+
).with_tensorboard_output_config()
1135+
10881136
Args:
10891137
tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
10901138
The TensorBoard output configuration.
10911139
"""
1092-
self._tensorboard_output_config = tensorboard_output_config
1140+
self._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig()
10931141
return self
10941142

10951143
def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
10961144
"""Set the retry strategy for the training job.
10971145
1146+
Example:
1147+
1148+
.. code:: python
1149+
1150+
from sagemaker.modules.train import ModelTrainer
1151+
from sagemaker.modules.configs import RetryStrategy
1152+
1153+
retry_strategy = RetryStrategy(maximum_retry_attempts=3)
1154+
1155+
model_trainer = ModelTrainer(
1156+
...
1157+
).with_retry_strategy(retry_strategy)
1158+
10981159
Args:
1099-
retry_strategy (RetryStrategy):
1160+
retry_strategy (sagemaker.modules.configs.RetryStrategy):
11001161
The retry strategy for the training job.
11011162
"""
11021163
self._retry_strategy = retry_strategy
11031164
return self
11041165

1105-
def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer":
1166+
def with_infra_check_config(
1167+
self, infra_check_config: Optional[InfraCheckConfig] = None
1168+
) -> "ModelTrainer":
11061169
"""Set the infra check configuration for the training job.
11071170
1171+
Example:
1172+
1173+
.. code:: python
1174+
1175+
from sagemaker.modules.train import ModelTrainer
1176+
1177+
model_trainer = ModelTrainer(
1178+
...
1179+
).with_infra_check_config()
1180+
11081181
Args:
1109-
infra_check_config (InfraCheckConfig):
1182+
infra_check_config (sagemaker.modules.configs.InfraCheckConfig):
11101183
The infra check configuration for the training job.
11111184
"""
1112-
self._infra_check_config = infra_check_config
1185+
self._infra_check_config = infra_check_config or InfraCheckConfig(enable_infra_check=True)
11131186
return self
11141187

11151188
def with_session_chaining_config(
1116-
self, session_chaining_config: SessionChainingConfig
1189+
self, session_chaining_config: Optional[SessionChainingConfig] = None
11171190
) -> "ModelTrainer":
11181191
"""Set the session chaining configuration for the training job.
11191192
1193+
Example:
1194+
1195+
.. code:: python
1196+
1197+
from sagemaker.modules.train import ModelTrainer
1198+
1199+
model_trainer = ModelTrainer(
1200+
...
1201+
).with_session_chaining_config()
1202+
11201203
Args:
1121-
session_chaining_config (SessionChainingConfig):
1204+
session_chaining_config (sagemaker.modules.configs.SessionChainingConfig):
11221205
The session chaining configuration for the training job.
11231206
"""
1124-
self._session_chaining_config = session_chaining_config
1207+
self._session_chaining_config = session_chaining_config or SessionChainingConfig(
1208+
enable_session_tag_chaining=True
1209+
)
11251210
return self
11261211

1127-
def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "ModelTrainer":
1212+
def with_remote_debug_config(
1213+
self, remote_debug_config: Optional[RemoteDebugConfig] = None
1214+
) -> "ModelTrainer":
11281215
"""Set the remote debug configuration for the training job.
11291216
1217+
Example:
1218+
1219+
.. code:: python
1220+
1221+
from sagemaker.modules.train import ModelTrainer
1222+
1223+
model_trainer = ModelTrainer(
1224+
...
1225+
).with_remote_debug_config()
1226+
11301227
Args:
1131-
remote_debug_config (RemoteDebugConfig):
1228+
remote_debug_config (sagemaker.modules.configs.RemoteDebugConfig):
11321229
The remote debug configuration for the training job.
11331230
"""
1134-
self._remote_debug_config = remote_debug_config
1231+
self._remote_debug_config = remote_debug_config or RemoteDebugConfig(
1232+
enable_remote_debug=True
1233+
)
11351234
return self

tests/unit/sagemaker/modules/train/test_model_trainer.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,3 +1228,37 @@ def test_hyperparameters_invalid(mock_exists, modules_session):
12281228
compute=DEFAULT_COMPUTE_CONFIG,
12291229
hyperparameters="hyperparameters.yaml",
12301230
)
1231+
1232+
1233+
@patch("sagemaker.modules.train.model_trainer._get_unique_name")
1234+
@patch("sagemaker.modules.train.model_trainer.TrainingJob")
1235+
def test_model_trainer_default_paths(mock_training_job, mock_unique_name, modules_session):
1236+
def mock_upload_data(path, bucket, key_prefix):
1237+
return f"s3://{bucket}/{key_prefix}"
1238+
1239+
unique_name = "base-job-0123456789"
1240+
base_name = "base-job"
1241+
1242+
modules_session.upload_data.side_effect = mock_upload_data
1243+
mock_unique_name.return_value = unique_name
1244+
1245+
model_trainer = ModelTrainer(
1246+
training_image=DEFAULT_IMAGE,
1247+
sagemaker_session=modules_session,
1248+
checkpoint_config=CheckpointConfig(),
1249+
base_job_name=base_name,
1250+
).with_tensorboard_output_config(TensorBoardOutputConfig())
1251+
model_trainer.train()
1252+
1253+
_, kwargs = mock_training_job.create.call_args
1254+
1255+
default_base_path = f"s3://{DEFAULT_BUCKET}/{DEFAULT_BUCKET_PREFIX}/{base_name}"
1256+
1257+
assert kwargs["output_data_config"].s3_output_path == default_base_path
1258+
assert kwargs["output_data_config"].compression_type == "GZIP"
1259+
1260+
assert kwargs["checkpoint_config"].s3_uri == f"{default_base_path}/{unique_name}/checkpoints"
1261+
assert kwargs["checkpoint_config"].local_path == "/opt/ml/checkpoints"
1262+
1263+
assert kwargs["tensor_board_output_config"].s3_output_path == default_base_path
1264+
assert kwargs["tensor_board_output_config"].local_path == "/opt/ml/output/tensorboard"

0 commit comments

Comments
 (0)