25
25
26
26
from sagemaker_core .main import resources
27
27
from sagemaker_core .resources import TrainingJob
28
+ from sagemaker_core import shapes
28
29
from sagemaker_core .shapes import (
29
- AlgorithmSpecification ,
30
- OutputDataConfig ,
31
- CheckpointConfig ,
32
- TensorBoardOutputConfig ,
30
+ AlgorithmSpecification
33
31
)
34
32
35
33
from pydantic import BaseModel , ConfigDict , PrivateAttr , validate_call
@@ -224,9 +222,9 @@ class ModelTrainer(BaseModel):
224
222
training_image : Optional [str ] = None
225
223
training_image_config : Optional [TrainingImageConfig ] = None
226
224
algorithm_name : Optional [str ] = None
227
- output_data_config : Optional [OutputDataConfig ] = None
225
+ output_data_config : Optional [shapes . OutputDataConfig ] = None
228
226
input_data_config : Optional [List [Union [Channel , InputData ]]] = None
229
- checkpoint_config : Optional [CheckpointConfig ] = None
227
+ checkpoint_config : Optional [shapes . CheckpointConfig ] = None
230
228
training_input_mode : Optional [str ] = "File"
231
229
environment : Optional [Dict [str , str ]] = {}
232
230
hyperparameters : Optional [Union [Dict [str , Any ], str ]] = {}
@@ -237,7 +235,7 @@ class ModelTrainer(BaseModel):
237
235
_latest_training_job : Optional [resources .TrainingJob ] = PrivateAttr (default = None )
238
236
239
237
# Private TrainingJob Parameters
240
- _tensorboard_output_config : Optional [TensorBoardOutputConfig ] = PrivateAttr (default = None )
238
+ _tensorboard_output_config : Optional [shapes . TensorBoardOutputConfig ] = PrivateAttr (default = None )
241
239
_retry_strategy : Optional [RetryStrategy ] = PrivateAttr (default = None )
242
240
_infra_check_config : Optional [InfraCheckConfig ] = PrivateAttr (default = None )
243
241
_session_chaining_config : Optional [SessionChainingConfig ] = PrivateAttr (default = None )
@@ -268,8 +266,8 @@ class ModelTrainer(BaseModel):
268
266
"networking" : Networking ,
269
267
"stopping_condition" : StoppingCondition ,
270
268
"training_image_config" : TrainingImageConfig ,
271
- "output_data_config" : OutputDataConfig ,
272
- "checkpoint_config" : CheckpointConfig ,
269
+ "output_data_config" : configs . OutputDataConfig ,
270
+ "checkpoint_config" : configs . CheckpointConfig ,
273
271
}
274
272
275
273
def _populate_intelligent_defaults (self ):
@@ -321,7 +319,7 @@ def _populate_intelligent_defaults_from_training_job_space(self):
321
319
config_path = TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH
322
320
)
323
321
if default_output_data_config :
324
- self .output_data_config = OutputDataConfig (
322
+ self .output_data_config = configs . OutputDataConfig (
325
323
** self ._convert_keys_to_snake (default_output_data_config )
326
324
)
327
325
@@ -537,7 +535,7 @@ def model_post_init(self, __context: Any):
537
535
if self .output_data_config is None :
538
536
session = self .sagemaker_session
539
537
base_job_name = self .base_job_name
540
- self .output_data_config = OutputDataConfig (
538
+ self .output_data_config = configs . OutputDataConfig (
541
539
s3_output_path = f"s3://{ self ._fetch_bucket_name_and_prefix (session )} "
542
540
f"/{ base_job_name } " ,
543
541
compression_type = "GZIP" ,
@@ -959,9 +957,9 @@ def from_recipe(
959
957
requirements : Optional [str ] = None ,
960
958
training_image : Optional [str ] = None ,
961
959
training_image_config : Optional [TrainingImageConfig ] = None ,
962
- output_data_config : Optional [OutputDataConfig ] = None ,
960
+ output_data_config : Optional [shapes . OutputDataConfig ] = None ,
963
961
input_data_config : Optional [List [Union [Channel , InputData ]]] = None ,
964
- checkpoint_config : Optional [CheckpointConfig ] = None ,
962
+ checkpoint_config : Optional [shapes . CheckpointConfig ] = None ,
965
963
training_input_mode : Optional [str ] = "File" ,
966
964
environment : Optional [Dict [str , str ]] = None ,
967
965
tags : Optional [List [Tag ]] = None ,
@@ -1115,7 +1113,7 @@ def from_recipe(
1115
1113
return model_trainer
1116
1114
1117
1115
def with_tensorboard_output_config (
1118
- self , tensorboard_output_config : Optional [TensorBoardOutputConfig ] = None
1116
+ self , tensorboard_output_config : Optional [shapes . TensorBoardOutputConfig ] = None
1119
1117
) -> "ModelTrainer" : # noqa: D412
1120
1118
"""Set the TensorBoard output configuration.
1121
1119
@@ -1232,7 +1230,7 @@ def with_remote_debug_config(
1232
1230
return self
1233
1231
1234
1232
def with_checkpoint_config (
1235
- self , checkpoint_config : Optional [CheckpointConfig ] = None
1233
+ self , checkpoint_config : Optional [shapes . CheckpointConfig ] = None
1236
1234
) -> "ModelTrainer" : # noqa: D412
1237
1235
"""Set the checkpoint configuration for the training job.
1238
1236
0 commit comments