Skip to content

Commit 5f6dd1a

Browse files
committed
make config creation backwards compatible
1 parent 8027183 commit 5f6dd1a

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

src/sagemaker/modules/train/model_trainer.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -25,11 +25,9 @@
2525

2626
from sagemaker_core.main import resources
2727
from sagemaker_core.resources import TrainingJob
28+
from sagemaker_core import shapes
2829
from sagemaker_core.shapes import (
29-
AlgorithmSpecification,
30-
OutputDataConfig,
31-
CheckpointConfig,
32-
TensorBoardOutputConfig,
30+
AlgorithmSpecification
3331
)
3432

3533
from pydantic import BaseModel, ConfigDict, PrivateAttr, validate_call
@@ -224,9 +222,9 @@ class ModelTrainer(BaseModel):
224222
training_image: Optional[str] = None
225223
training_image_config: Optional[TrainingImageConfig] = None
226224
algorithm_name: Optional[str] = None
227-
output_data_config: Optional[OutputDataConfig] = None
225+
output_data_config: Optional[shapes.OutputDataConfig] = None
228226
input_data_config: Optional[List[Union[Channel, InputData]]] = None
229-
checkpoint_config: Optional[CheckpointConfig] = None
227+
checkpoint_config: Optional[shapes.CheckpointConfig] = None
230228
training_input_mode: Optional[str] = "File"
231229
environment: Optional[Dict[str, str]] = {}
232230
hyperparameters: Optional[Union[Dict[str, Any], str]] = {}
@@ -237,7 +235,7 @@ class ModelTrainer(BaseModel):
237235
_latest_training_job: Optional[resources.TrainingJob] = PrivateAttr(default=None)
238236

239237
# Private TrainingJob Parameters
240-
_tensorboard_output_config: Optional[TensorBoardOutputConfig] = PrivateAttr(default=None)
238+
_tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = PrivateAttr(default=None)
241239
_retry_strategy: Optional[RetryStrategy] = PrivateAttr(default=None)
242240
_infra_check_config: Optional[InfraCheckConfig] = PrivateAttr(default=None)
243241
_session_chaining_config: Optional[SessionChainingConfig] = PrivateAttr(default=None)
@@ -268,8 +266,8 @@ class ModelTrainer(BaseModel):
268266
"networking": Networking,
269267
"stopping_condition": StoppingCondition,
270268
"training_image_config": TrainingImageConfig,
271-
"output_data_config": OutputDataConfig,
272-
"checkpoint_config": CheckpointConfig,
269+
"output_data_config": configs.OutputDataConfig,
270+
"checkpoint_config": configs.CheckpointConfig,
273271
}
274272

275273
def _populate_intelligent_defaults(self):
@@ -321,7 +319,7 @@ def _populate_intelligent_defaults_from_training_job_space(self):
321319
config_path=TRAINING_JOB_OUTPUT_DATA_CONFIG_PATH
322320
)
323321
if default_output_data_config:
324-
self.output_data_config = OutputDataConfig(
322+
self.output_data_config = configs.OutputDataConfig(
325323
**self._convert_keys_to_snake(default_output_data_config)
326324
)
327325

@@ -537,7 +535,7 @@ def model_post_init(self, __context: Any):
537535
if self.output_data_config is None:
538536
session = self.sagemaker_session
539537
base_job_name = self.base_job_name
540-
self.output_data_config = OutputDataConfig(
538+
self.output_data_config = configs.OutputDataConfig(
541539
s3_output_path=f"s3://{self._fetch_bucket_name_and_prefix(session)}"
542540
f"/{base_job_name}",
543541
compression_type="GZIP",
@@ -959,9 +957,9 @@ def from_recipe(
959957
requirements: Optional[str] = None,
960958
training_image: Optional[str] = None,
961959
training_image_config: Optional[TrainingImageConfig] = None,
962-
output_data_config: Optional[OutputDataConfig] = None,
960+
output_data_config: Optional[shapes.OutputDataConfig] = None,
963961
input_data_config: Optional[List[Union[Channel, InputData]]] = None,
964-
checkpoint_config: Optional[CheckpointConfig] = None,
962+
checkpoint_config: Optional[shapes.CheckpointConfig] = None,
965963
training_input_mode: Optional[str] = "File",
966964
environment: Optional[Dict[str, str]] = None,
967965
tags: Optional[List[Tag]] = None,
@@ -1115,7 +1113,7 @@ def from_recipe(
11151113
return model_trainer
11161114

11171115
def with_tensorboard_output_config(
1118-
self, tensorboard_output_config: Optional[TensorBoardOutputConfig] = None
1116+
self, tensorboard_output_config: Optional[shapes.TensorBoardOutputConfig] = None
11191117
) -> "ModelTrainer": # noqa: D412
11201118
"""Set the TensorBoard output configuration.
11211119
@@ -1232,7 +1230,7 @@ def with_remote_debug_config(
12321230
return self
12331231

12341232
def with_checkpoint_config(
1235-
self, checkpoint_config: Optional[CheckpointConfig] = None
1233+
self, checkpoint_config: Optional[shapes.CheckpointConfig] = None
12361234
) -> "ModelTrainer": # noqa: D412
12371235
"""Set the checkpoint configuration for the training job.
12381236

0 commit comments

Comments
 (0)