3030from sagemaker_core .shapes import (
3131 StoppingCondition ,
3232 RetryStrategy ,
33- OutputDataConfig ,
3433 Channel ,
3534 ShuffleConfig ,
3635 DataSource ,
4342 RemoteDebugConfig ,
4443 SessionChainingConfig ,
4544 InstanceGroup ,
46- TensorBoardOutputConfig ,
47- CheckpointConfig ,
4845)
4946
5047from sagemaker .modules .utils import convert_unassigned_to_none
@@ -131,6 +128,8 @@ class Compute(shapes.ResourceConfig):
131128 subsequent training jobs.
132129 instance_groups (Optional[List[InstanceGroup]]):
133130 A list of instance groups for heterogeneous clusters to be used in the training job.
131+ training_plan_arn (Optional[str]):
132+ The Amazon Resource Name (ARN) of the training plan to use for this resource configuration.
134133 enable_managed_spot_training (Optional[bool]):
135134 To train models using managed spot training, choose True. Managed spot training
136135 provides a fully managed and scalable infrastructure for training machine learning
@@ -151,8 +150,12 @@ def _to_resource_config(self) -> shapes.ResourceConfig:
151150 compute_config_dict = self .model_dump ()
152151 resource_config_fields = set (shapes .ResourceConfig .__annotations__ .keys ())
153152 filtered_dict = {
154- k : v for k , v in compute_config_dict .items () if k in resource_config_fields
153+ k : v
154+ for k , v in compute_config_dict .items ()
155+ if k in resource_config_fields and v is not None
155156 }
157+ if not filtered_dict :
158+ return None
156159 return shapes .ResourceConfig (** filtered_dict )
157160
158161
@@ -194,10 +197,12 @@ def _model_validator(self) -> "Networking":
194197 def _to_vpc_config (self ) -> shapes .VpcConfig :
195198 """Convert to a sagemaker_core.shapes.VpcConfig object."""
196199 compute_config_dict = self .model_dump ()
197- resource_config_fields = set (shapes .VpcConfig .__annotations__ .keys ())
200+ vpc_config_fields = set (shapes .VpcConfig .__annotations__ .keys ())
198201 filtered_dict = {
199- k : v for k , v in compute_config_dict .items () if k in resource_config_fields
202+ k : v for k , v in compute_config_dict .items () if k in vpc_config_fields and v is not None
200203 }
204+ if not filtered_dict :
205+ return None
201206 return shapes .VpcConfig (** filtered_dict )
202207
203208
@@ -224,3 +229,66 @@ class InputData(BaseConfig):
224229
225230 channel_name : str = None
226231 data_source : Union [str , FileSystemDataSource , S3DataSource ] = None
232+
233+
234+ class OutputDataConfig (shapes .OutputDataConfig ):
235+ """OutputDataConfig.
236+
237+ The OutputDataConfig class is a subclass of ``sagemaker_core.shapes.OutputDataConfig``
238+ and allows the user to specify the output data configuration for the training job.
239+
240+ Parameters:
241+ s3_output_path (Optional[str]):
242+ The S3 URI where the output data will be stored. This is the location where the
243+ training job will save its output data, such as model artifacts and logs.
244+ kms_key_id (Optional[str]):
245+ The Amazon Web Services Key Management Service (Amazon Web Services KMS) key that
246+ SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
247+ encryption.
248+ compression_type (Optional[str]):
249+ The model output compression type. Select `NONE` to output an uncompressed model,
250+ recommended for large model outputs. Defaults to `GZIP`.
251+ """
252+
253+ s3_output_path : Optional [str ] = None
254+ kms_key_id : Optional [str ] = None
255+ compression_type : Optional [str ] = None
256+
257+
258+ class TensorBoardOutputConfig (shapes .TensorBoardOutputConfig ):
259+ """TensorBoardOutputConfig.
260+
261+ The TensorBoardOutputConfig class is a subclass of ``sagemaker_core.shapes.TensorBoardOutputConfig``
262+ and allows the user to specify the storage locations for the Amazon SageMaker
263+ Debugger TensorBoard.
264+
265+ Parameters:
266+ s3_output_path (Optional[str]):
267+ Path to Amazon S3 storage location for TensorBoard output. If not specified, will
268+ default to
269+ ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/tensorboard-output``
270+ local_path (Optional[str]):
271+ Path to local storage location for tensorBoard output. Defaults to /opt/ml/output/tensorboard.
272+ """
273+
274+ s3_output_path : Optional [str ] = None
275+ local_path : Optional [str ] = "/opt/ml/output/tensorboard"
276+
277+
278+ class CheckpointConfig (shapes .CheckpointConfig ):
279+ """CheckpointConfig.
280+
281+ The CheckpointConfig class is a subclass of ``sagemaker_core.shapes.CheckpointConfig``
282+ and allows the user to specify the checkpoint configuration for the training job.
283+
284+ Parameters:
285+ s3_uri (Optional[str]):
286+ Path to Amazon S3 storage location for the Checkpoint data. If not specified, will
287+ default to
288+ ``s3://<default_bucket>/<default_prefix>/<base_job_name>/<job_name>/checkpoints``
289+ local_path (Optional[str]):
290+ The local directory where checkpoints are written. The default directory is /opt/ml/checkpoints.
291+ """
292+
293+ s3_uri : Optional [str ] = None
294+ local_path : Optional [str ] = "/opt/ml/checkpoints"
0 commit comments