2525
2626from sagemaker_core .main import resources
2727from sagemaker_core .resources import TrainingJob
28+ from sagemaker_core import shapes
2829from sagemaker_core .shapes import (
29- AlgorithmSpecification ,
30- OutputDataConfig ,
31- CheckpointConfig ,
32- TensorBoardOutputConfig ,
30+ AlgorithmSpecification
3331)
3432
3533from pydantic import BaseModel , ConfigDict , PrivateAttr , validate_call
@@ -224,9 +222,9 @@ class ModelTrainer(BaseModel):
224222 training_image : Optional [str ] = None
225223 training_image_config : Optional [TrainingImageConfig ] = None
226224 algorithm_name : Optional [str ] = None
227- output_data_config : Optional [OutputDataConfig ] = None
225+ output_data_config : Optional [shapes . OutputDataConfig ] = None
228226 input_data_config : Optional [List [Union [Channel , InputData ]]] = None
229- checkpoint_config : Optional [CheckpointConfig ] = None
227+ checkpoint_config : Optional [shapes . CheckpointConfig ] = None
230228 training_input_mode : Optional [str ] = "File"
231229 environment : Optional [Dict [str , str ]] = {}
232230 hyperparameters : Optional [Union [Dict [str , Any ], str ]] = {}
@@ -237,7 +235,7 @@ class ModelTrainer(BaseModel):
237235 _latest_training_job : Optional [resources .TrainingJob ] = PrivateAttr (default = None )
238236
239237 # Private TrainingJob Parameters
240- _tensorboard_output_config : Optional [TensorBoardOutputConfig ] = PrivateAttr (default = None )
238+ _tensorboard_output_config : Optional [shapes . TensorBoardOutputConfig ] = PrivateAttr (default = None )
241239 _retry_strategy : Optional [RetryStrategy ] = PrivateAttr (default = None )
242240 _infra_check_config : Optional [InfraCheckConfig ] = PrivateAttr (default = None )
243241 _session_chaining_config : Optional [SessionChainingConfig ] = PrivateAttr (default = None )
@@ -268,8 +266,8 @@ class ModelTrainer(BaseModel):
268266 "networking" : Networking ,
269267 "stopping_condition" : StoppingCondition ,
270268 "training_image_config" : TrainingImageConfig ,
271- "output_data_config" : OutputDataConfig ,
272- "checkpoint_config" : CheckpointConfig ,
269+ "output_data_config" : configs . OutputDataConfig ,
270+ "checkpoint_config" : configs . CheckpointConfig ,
273271 }
274272
275273 def _populate_intelligent_defaults (self ):
@@ -321,7 +319,7 @@ def _populate_intelligent_defaults_from_training_job_space(self):
321319 config_path = TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH
322320 )
323321 if default_output_data_config :
324- self .output_data_config = OutputDataConfig (
322+ self .output_data_config = configs . OutputDataConfig (
325323 ** self ._convert_keys_to_snake (default_output_data_config )
326324 )
327325
@@ -537,7 +535,7 @@ def model_post_init(self, __context: Any):
537535 if self .output_data_config is None :
538536 session = self .sagemaker_session
539537 base_job_name = self .base_job_name
540- self .output_data_config = OutputDataConfig (
538+ self .output_data_config = configs . OutputDataConfig (
541539 s3_output_path = f"s3://{ self ._fetch_bucket_name_and_prefix (session )} "
542540 f"/{ base_job_name } " ,
543541 compression_type = "GZIP" ,
@@ -959,9 +957,9 @@ def from_recipe(
959957 requirements : Optional [str ] = None ,
960958 training_image : Optional [str ] = None ,
961959 training_image_config : Optional [TrainingImageConfig ] = None ,
962- output_data_config : Optional [OutputDataConfig ] = None ,
960+ output_data_config : Optional [shapes . OutputDataConfig ] = None ,
963961 input_data_config : Optional [List [Union [Channel , InputData ]]] = None ,
964- checkpoint_config : Optional [CheckpointConfig ] = None ,
962+ checkpoint_config : Optional [shapes . CheckpointConfig ] = None ,
965963 training_input_mode : Optional [str ] = "File" ,
966964 environment : Optional [Dict [str , str ]] = None ,
967965 tags : Optional [List [Tag ]] = None ,
@@ -1115,7 +1113,7 @@ def from_recipe(
11151113 return model_trainer
11161114
11171115 def with_tensorboard_output_config (
1118- self , tensorboard_output_config : Optional [TensorBoardOutputConfig ] = None
1116+ self , tensorboard_output_config : Optional [shapes . TensorBoardOutputConfig ] = None
11191117 ) -> "ModelTrainer" : # noqa: D412
11201118 """Set the TensorBoard output configuration.
11211119
@@ -1232,7 +1230,7 @@ def with_remote_debug_config(
12321230 return self
12331231
12341232 def with_checkpoint_config (
1235- self , checkpoint_config : Optional [CheckpointConfig ] = None
1233+ self , checkpoint_config : Optional [shapes . CheckpointConfig ] = None
12361234 ) -> "ModelTrainer" : # noqa: D412
12371235 """Set the checkpoint configuration for the training job.
12381236
0 commit comments