@@ -181,7 +181,7 @@ class ModelTrainer(BaseModel):
181181 The output data configuration. This is used to specify the output data location
182182 for the training job.
183183 If not specified in the session, will default to
184- s3://<default_bucket>/<default_prefix>/<base_job_name>/.
184+ `` s3://<default_bucket>/<default_prefix>/<base_job_name>/`` .
185185 input_data_config (Optional[List[Union[Channel, InputData]]]):
186186 The input data config for the training job.
187187 Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
@@ -595,8 +595,7 @@ def train(
595595 """
596596 self ._populate_intelligent_defaults ()
597597 current_training_job_name = _get_unique_name (self .base_job_name )
598- default_artifact_path = f"{ self .base_job_name } /{ current_training_job_name } "
599- input_data_key_prefix = f"{ default_artifact_path } /input"
598+ input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
600599 if input_data_config and self .input_data_config :
601600 self .input_data_config = input_data_config
602601 # Add missing input data channels to the existing input_data_config
@@ -613,9 +612,15 @@ def train(
613612 )
614613
615614 if self .checkpoint_config and not self .checkpoint_config .s3_uri :
616- self .checkpoint_config .s3_uri = f"s3://{ self ._fetch_bucket_name_and_prefix (self .sagemaker_session )} /{ default_artifact_path } "
617- if self ._tensorboard_output_config and not self ._tensorboard_output_config .s3_uri :
618- self ._tensorboard_output_config .s3_uri = f"s3://{ self ._fetch_bucket_name_and_prefix (self .sagemaker_session )} /{ default_artifact_path } "
615+ self .checkpoint_config .s3_uri = (
616+ f"s3://{ self ._fetch_bucket_name_and_prefix (self .sagemaker_session )} /"
617+ f"{ self .base_job_name } /{ current_training_job_name } /checkpoints"
618+ )
619+ if self ._tensorboard_output_config and not self ._tensorboard_output_config .s3_output_path :
620+ self ._tensorboard_output_config .s3_output_path = (
621+ f"s3://{ self ._fetch_bucket_name_and_prefix (self .sagemaker_session )} /"
622+ f"{ self .base_job_name } "
623+ )
619624
620625 string_hyper_parameters = {}
621626 if self .hyperparameters :
@@ -646,7 +651,7 @@ def train(
646651 data_source = self .source_code .source_dir ,
647652 key_prefix = input_data_key_prefix ,
648653 )
649- input_data_config .append (source_code_channel )
654+ self . input_data_config .append (source_code_channel )
650655
651656 self ._prepare_train_script (
652657 tmp_dir = tmp_dir ,
@@ -667,7 +672,7 @@ def train(
667672 data_source = tmp_dir .name ,
668673 key_prefix = input_data_key_prefix ,
669674 )
670- input_data_config .append (sm_drivers_channel )
675+ self . input_data_config .append (sm_drivers_channel )
671676
672677 # If source_code is provided, we will always use
673678 # the default container entrypoint and arguments
@@ -970,10 +975,43 @@ def from_recipe(
970975 ) -> "ModelTrainer" :
971976 """Create a ModelTrainer from a training recipe.
972977
978+ Example:
979+
980+ .. code:: python
981+
982+ from sagemaker.modules.train import ModelTrainer
983+ from sagemaker.modules.configs import Compute
984+
985+ recipe_overrides = {
986+ "run": {
987+ "results_dir": "/opt/ml/model",
988+ },
989+ "model": {
990+ "data": {
991+ "use_synthetic_data": True
992+ }
993+ }
994+ }
995+
996+ compute = Compute(
997+ instance_type="ml.p5.48xlarge",
998+ keep_alive_period_in_seconds=3600
999+ )
1000+
1001+ model_trainer = ModelTrainer.from_recipe(
1002+ training_recipe="fine-tuning/deepseek/hf_deepseek_r1_distilled_llama_8b_seq8k_gpu_fine_tuning",
1003+ recipe_overrides=recipe_overrides,
1004+ compute=compute,
1005+ )
1006+
1007+ model_trainer.train(wait=False)
1008+
1009+
9731010 Args:
9741011 training_recipe (str):
9751012 The training recipe to use for training the model. This must be the name of
9761013 a sagemaker training recipe or a path to a local training recipe .yaml file.
1014+ For available training recipes, see: https://github.com/aws/sagemaker-hyperpod-recipes/
9771015 compute (Compute):
9781016 The compute configuration. This is used to specify the compute resources for
9791017 the training job. If not specified, will default to 1 instance of ml.m5.xlarge.
@@ -1081,55 +1119,116 @@ def from_recipe(
10811119 return model_trainer
10821120
10831121 def with_tensorboard_output_config (
1084- self , tensorboard_output_config : TensorBoardOutputConfig
1122+ self , tensorboard_output_config : Optional [ TensorBoardOutputConfig ] = None
10851123 ) -> "ModelTrainer" :
10861124 """Set the TensorBoard output configuration.
10871125
1126+ Example:
1127+
1128+ .. code:: python
1129+
1130+ from sagemaker.modules.train import ModelTrainer
1131+
1132+ model_trainer = ModelTrainer(
1133+ ...
1134+ ).with_tensorboard_output_config()
1135+
10881136 Args:
10891137 tensorboard_output_config (sagemaker.modules.configs.TensorBoardOutputConfig):
10901138 The TensorBoard output configuration.
10911139 """
1092- self ._tensorboard_output_config = tensorboard_output_config
1140+ self ._tensorboard_output_config = tensorboard_output_config or TensorBoardOutputConfig ()
10931141 return self
10941142
10951143 def with_retry_strategy (self , retry_strategy : RetryStrategy ) -> "ModelTrainer" :
10961144 """Set the retry strategy for the training job.
10971145
1146+ Example:
1147+
1148+ .. code:: python
1149+
1150+ from sagemaker.modules.train import ModelTrainer
1151+ from sagemaker.modules.configs import RetryStrategy
1152+
1153+ retry_strategy = RetryStrategy(maximum_retry_attempts=3)
1154+
1155+ model_trainer = ModelTrainer(
1156+ ...
1157+ ).with_retry_strategy(retry_strategy)
1158+
10981159 Args:
1099- retry_strategy (RetryStrategy):
1160+ retry_strategy (sagemaker.modules.configs. RetryStrategy):
11001161 The retry strategy for the training job.
11011162 """
11021163 self ._retry_strategy = retry_strategy
11031164 return self
11041165
1105- def with_infra_check_config (self , infra_check_config : InfraCheckConfig ) -> "ModelTrainer" :
1166+ def with_infra_check_config (
1167+ self , infra_check_config : Optional [InfraCheckConfig ] = None
1168+ ) -> "ModelTrainer" :
11061169 """Set the infra check configuration for the training job.
11071170
1171+ Example:
1172+
1173+ .. code:: python
1174+
1175+ from sagemaker.modules.train import ModelTrainer
1176+
1177+ model_trainer = ModelTrainer(
1178+ ...
1179+ ).with_infra_check_config()
1180+
11081181 Args:
1109- infra_check_config (InfraCheckConfig):
1182+ infra_check_config (sagemaker.modules.configs. InfraCheckConfig):
11101183 The infra check configuration for the training job.
11111184 """
1112- self ._infra_check_config = infra_check_config
1185+ self ._infra_check_config = infra_check_config or InfraCheckConfig ( enable_infra_check = True )
11131186 return self
11141187
11151188 def with_session_chaining_config (
1116- self , session_chaining_config : SessionChainingConfig
1189+ self , session_chaining_config : Optional [ SessionChainingConfig ] = None
11171190 ) -> "ModelTrainer" :
11181191 """Set the session chaining configuration for the training job.
11191192
1193+ Example:
1194+
1195+ .. code:: python
1196+
1197+ from sagemaker.modules.train import ModelTrainer
1198+
1199+ model_trainer = ModelTrainer(
1200+ ...
1201+ ).with_session_chaining_config()
1202+
11201203 Args:
1121- session_chaining_config (SessionChainingConfig):
1204+ session_chaining_config (sagemaker.modules.configs. SessionChainingConfig):
11221205 The session chaining configuration for the training job.
11231206 """
1124- self ._session_chaining_config = session_chaining_config
1207+ self ._session_chaining_config = session_chaining_config or SessionChainingConfig (
1208+ enable_session_tag_chaining = True
1209+ )
11251210 return self
11261211
1127- def with_remote_debug_config (self , remote_debug_config : RemoteDebugConfig ) -> "ModelTrainer" :
1212+ def with_remote_debug_config (
1213+ self , remote_debug_config : Optional [RemoteDebugConfig ] = None
1214+ ) -> "ModelTrainer" :
11281215 """Set the remote debug configuration for the training job.
11291216
1217+ Example:
1218+
1219+ .. code:: python
1220+
1221+ from sagemaker.modules.train import ModelTrainer
1222+
1223+ model_trainer = ModelTrainer(
1224+ ...
1225+ ).with_remote_debug_config()
1226+
11301227 Args:
1131- remote_debug_config (RemoteDebugConfig):
1228+ remote_debug_config (sagemaker.modules.configs. RemoteDebugConfig):
11321229 The remote debug configuration for the training job.
11331230 """
1134- self ._remote_debug_config = remote_debug_config
1231+ self ._remote_debug_config = remote_debug_config or RemoteDebugConfig (
1232+ enable_remote_debug = True
1233+ )
11351234 return self
0 commit comments