@@ -181,7 +181,7 @@ class ModelTrainer(BaseModel):
             The output data configuration. This is used to specify the output data location
             for the training job.
             If not specified in the session, will default to
-            s3://<default_bucket>/<default_prefix>/<base_job_name>/.
+            ``s3://<default_bucket>/<default_prefix>/<base_job_name>/``.
         input_data_config (Optional[List[Union[Channel, InputData]]]):
             The input data config for the training job.
             Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
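The default output location described above can be overridden by passing an explicit ``OutputDataConfig``. A minimal sketch, assuming a hypothetical bucket, prefix, and training image URI that are not taken from this change:

.. code:: python

    from sagemaker.modules.train import ModelTrainer
    from sagemaker.modules.configs import OutputDataConfig

    # Hypothetical values for illustration only.
    model_trainer = ModelTrainer(
        training_image="123456789012.dkr.ecr.us-west-2.amazonaws.com/my-image:latest",
        base_job_name="my-training-job",
        # Overrides the default s3://<default_bucket>/<default_prefix>/<base_job_name>/ location.
        output_data_config=OutputDataConfig(s3_output_path="s3://my-bucket/my-prefix"),
    )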
@@ -595,8 +595,7 @@ def train(
         """
         self._populate_intelligent_defaults()
         current_training_job_name = _get_unique_name(self.base_job_name)
-        default_artifact_path = f"{self.base_job_name}/{current_training_job_name}"
-        input_data_key_prefix = f"{default_artifact_path}/input"
+        input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
         if input_data_config and self.input_data_config:
             self.input_data_config = input_data_config
             # Add missing input data channels to the existing input_data_config
@@ -613,9 +612,15 @@ def train(
             )

         if self.checkpoint_config and not self.checkpoint_config.s3_uri:
-            self.checkpoint_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
-        if self._tensorboard_output_config and not self._tensorboard_output_config.s3_uri:
-            self._tensorboard_output_config.s3_uri = f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/{default_artifact_path}"
+            self.checkpoint_config.s3_uri = (
+                f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/"
+                f"{self.base_job_name}/{current_training_job_name}/checkpoints"
+            )
+        if self._tensorboard_output_config and not self._tensorboard_output_config.s3_output_path:
+            self._tensorboard_output_config.s3_output_path = (
+                f"s3://{self._fetch_bucket_name_and_prefix(self.sagemaker_session)}/"
+                f"{self.base_job_name}"
+            )

         string_hyper_parameters = {}
         if self.hyperparameters:
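For intuition, the new defaults above compose S3 URIs like the following. This is a standalone sketch with hypothetical bucket, prefix, and job names standing in for what ``_fetch_bucket_name_and_prefix()`` and ``_get_unique_name()`` return:

.. code:: python

    # Hypothetical values standing in for the session defaults.
    bucket_name_and_prefix = "sagemaker-us-west-2-111122223333/my-prefix"
    base_job_name = "my-training-job"
    current_training_job_name = "my-training-job-2025-01-01-00-00-00-000"

    checkpoint_s3_uri = (
        f"s3://{bucket_name_and_prefix}/"
        f"{base_job_name}/{current_training_job_name}/checkpoints"
    )
    tensorboard_s3_output_path = (
        f"s3://{bucket_name_and_prefix}/"
        f"{base_job_name}"
    )

    # s3://sagemaker-us-west-2-111122223333/my-prefix/my-training-job/my-training-job-2025-01-01-00-00-00-000/checkpoints
    print(checkpoint_s3_uri)
    # s3://sagemaker-us-west-2-111122223333/my-prefix/my-training-job
    print(tensorboard_s3_output_path)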
@@ -646,7 +651,7 @@ def train(
                     data_source=self.source_code.source_dir,
                     key_prefix=input_data_key_prefix,
                 )
-                input_data_config.append(source_code_channel)
+                self.input_data_config.append(source_code_channel)

                 self._prepare_train_script(
                     tmp_dir=tmp_dir,
@@ -667,7 +672,7 @@ def train(
                     data_source=tmp_dir.name,
                     key_prefix=input_data_key_prefix,
                 )
-                input_data_config.append(sm_drivers_channel)
+                self.input_data_config.append(sm_drivers_channel)

                 # If source_code is provided, we will always use
                 # the default container entrypoint and arguments
@@ -970,10 +975,43 @@ def from_recipe(
     ) -> "ModelTrainer":
         """Create a ModelTrainer from a training recipe.

+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+            from sagemaker.modules.configs import Compute
+
+            recipe_overrides = {
+                "run": {
+                    "results_dir": "/opt/ml/model",
+                },
+                "model": {
+                    "data": {
+                        "use_synthetic_data": True
+                    }
+                }
+            }
+
+            compute = Compute(
+                instance_type="ml.p5.48xlarge",
+                keep_alive_period_in_seconds=3600
+            )
+
+            model_trainer = ModelTrainer.from_recipe(
+                training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
+                recipe_overrides=recipe_overrides,
+                compute=compute,
+            )
+
+            model_trainer.train(wait=False)
+
+
         Args:
             training_recipe (str):
                 The training recipe to use for training the model. This must be the name of
                 a sagemaker training recipe or a path to a local training recipe .yaml file.
+                For available training recipes, see: https://github.com/aws/sagemaker-hyperpod-recipes/
             compute (Compute):
                 The compute configuration. This is used to specify the compute resources for
                 the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
@@ -1081,55 +1119,116 @@ def from_recipe(
         return model_trainer

     def with_tensorboard_output_config(
-        self, tensorboard_output_config: TensorBoardOutputConfig
+        self, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None
     ) -> "ModelTrainer":
         """Set the TensorBoard output configuration.

+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+
+            model_trainer = ModelTrainer(
+                ...
+            ).with_tensorboard_output_config()
+
         Args:
             tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
                 The TensorBoard output configuration.
         """
-        self._tensorboard_output_config = tensorboard_output_config
+        self._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig()
         return self

     def with_retry_strategy(self, retry_strategy: RetryStrategy) -> "ModelTrainer":
         """Set the retry strategy for the training job.

+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+            from sagemaker.modules.configs import RetryStrategy
+
+            retry_strategy = RetryStrategy(maximum_retry_attempts=3)
+
+            model_trainer = ModelTrainer(
+                ...
+            ).with_retry_strategy(retry_strategy)
+
         Args:
-            retry_strategy (RetryStrategy):
+            retry_strategy (sagemaker.modules.configs.RetryStrategy):
                 The retry strategy for the training job.
         """
         self._retry_strategy = retry_strategy
         return self

-    def with_infra_check_config(self, infra_check_config: InfraCheckConfig) -> "ModelTrainer":
+    def with_infra_check_config(
+        self, infra_check_config: Optional[InfraCheckConfig] = None
+    ) -> "ModelTrainer":
         """Set the infra check configuration for the training job.

+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+
+            model_trainer = ModelTrainer(
+                ...
+            ).with_infra_check_config()
+
         Args:
-            infra_check_config (InfraCheckConfig):
+            infra_check_config (sagemaker.modules.configs.InfraCheckConfig):
                 The infra check configuration for the training job.
         """
-        self._infra_check_config = infra_check_config
+        self._infra_check_config = infra_check_config or InfraCheckConfig(enable_infra_check=True)
         return self

     def with_session_chaining_config(
-        self, session_chaining_config: SessionChainingConfig
+        self, session_chaining_config: Optional[SessionChainingConfig] = None
     ) -> "ModelTrainer":
         """Set the session chaining configuration for the training job.

+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+
+            model_trainer = ModelTrainer(
+                ...
+            ).with_session_chaining_config()
+
         Args:
-            session_chaining_config (SessionChainingConfig):
+            session_chaining_config (sagemaker.modules.configs.SessionChainingConfig):
                 The session chaining configuration for the training job.
         """
-        self._session_chaining_config = session_chaining_config
+        self._session_chaining_config = session_chaining_config or SessionChainingConfig(
+            enable_session_tag_chaining=True
+        )
         return self

-    def with_remote_debug_config(self, remote_debug_config: RemoteDebugConfig) -> "ModelTrainer":
+    def with_remote_debug_config(
+        self, remote_debug_config: Optional[RemoteDebugConfig] = None
+    ) -> "ModelTrainer":
         """Set the remote debug configuration for the training job.

+        Example:
+
+        .. code:: python
+
+            from sagemaker.modules.train import ModelTrainer
+
+            model_trainer = ModelTrainer(
+                ...
+            ).with_remote_debug_config()
+
         Args:
-            remote_debug_config (RemoteDebugConfig):
+            remote_debug_config (sagemaker.modules.configs.RemoteDebugConfig):
                 The remote debug configuration for the training job.
         """
-        self._remote_debug_config = remote_debug_config
+        self._remote_debug_config = remote_debug_config or RemoteDebugConfig(
+            enable_remote_debug=True
+        )
         return self
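The ``with_*`` methods above return ``self``, so they can be chained; called with no argument, each now falls back to an enabled default config. A minimal sketch, with constructor arguments elided as in the docstring examples:

.. code:: python

    from sagemaker.modules.train import ModelTrainer

    model_trainer = (
        ModelTrainer(
            ...
        )
        .with_tensorboard_output_config()  # TensorBoardOutputConfig()
        .with_infra_check_config()         # InfraCheckConfig(enable_infra_check=True)
        .with_remote_debug_config()        # RemoteDebugConfig(enable_remote_debug=True)
        .with_session_chaining_config()    # SessionChainingConfig(enable_session_tag_chaining=True)
    )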