@@ -181,7 +181,7 @@ class ModelTrainer(BaseModel):
181181 The output data configuration. This is used to specify the output data location
182182 for the training job.
183183 If not specified in the session, will default to
184- `` s3://<default_bucket>/<default_prefix>/<base_job_name>/`` .
184+ s3://<default_bucket>/<default_prefix>/<base_job_name>/.
185185 input_data_config (Optional[List[Union[Channel, InputData]]]):
186186 The input data config for the training job.
187187 Takes a list of Channel or InputData objects. An InputDataSource can be an S3 URI
@@ -477,6 +477,20 @@ def model_post_init(self, __context: Any):
477477 )
478478 logger .warning (f"Compute not provided. Using default:\n { self .compute } " )
479479
480+ if self .compute .instance_type is None :
481+ self .compute .instance_type = DEFAULT_INSTANCE_TYPE
482+ logger .warning (f"Instance type not provided. Using default:\n { DEFAULT_INSTANCE_TYPE } " )
483+ if self .compute .instance_count is None :
484+ self .compute .instance_count = 1
485+ logger .warning (
486+ f"Instance count not provided. Using default:\n { self .compute .instance_count } "
487+ )
488+ if self .compute .volume_size_in_gb is None :
489+ self .compute .volume_size_in_gb = 30
490+ logger .warning (
491+ f"Volume size not provided. Using default:\n { self .compute .volume_size_in_gb } "
492+ )
493+
480494 if self .stopping_condition is None :
481495 self .stopping_condition = StoppingCondition (
482496 max_runtime_in_seconds = 3600 ,
@@ -486,6 +500,12 @@ def model_post_init(self, __context: Any):
486500 logger .warning (
487501 f"StoppingCondition not provided. Using default:\n { self .stopping_condition } "
488502 )
503+ if self .stopping_condition .max_runtime_in_seconds is None :
504+ self .stopping_condition .max_runtime_in_seconds = 3600
505+ logger .info (
506+ "Max runtime not provided. Using default:\n "
507+ f"{ self .stopping_condition .max_runtime_in_seconds } "
508+ )
489509
490510 if self .hyperparameters and isinstance (self .hyperparameters , str ):
491511 if not os .path .exists (self .hyperparameters ):
@@ -511,23 +531,40 @@ def model_post_init(self, __context: Any):
511531 )
512532
513533 if self .training_mode == Mode .SAGEMAKER_TRAINING_JOB and self .output_data_config is None :
514- session = self .sagemaker_session
515- base_job_name = self .base_job_name
516- self .output_data_config = OutputDataConfig (
517- s3_output_path = f"s3://{ self ._fetch_bucket_name_and_prefix (session )} "
518- f"/{ base_job_name } " ,
519- compression_type = "GZIP" ,
520- kms_key_id = None ,
521- )
522- logger .warning (
523- f"OutputDataConfig not provided. Using default:\n { self .output_data_config } "
524- )
534+ if self .output_data_config is None :
535+ session = self .sagemaker_session
536+ base_job_name = self .base_job_name
537+ self .output_data_config = OutputDataConfig (
538+ s3_output_path = f"s3://{ self ._fetch_bucket_name_and_prefix (session )} "
539+ f"/{ base_job_name } " ,
540+ compression_type = "GZIP" ,
541+ kms_key_id = None ,
542+ )
543+ logger .warning (
544+ f"OutputDataConfig not provided. Using default:\n { self .output_data_config } "
545+ )
546+ if self .output_data_config .s3_output_path is None :
547+ session = self .sagemaker_session
548+ base_job_name = self .base_job_name
549+ self .output_data_config .s3_output_path = (
550+ f"s3://{ self ._fetch_bucket_name_and_prefix (session )} /{ base_job_name } "
551+ )
552+ logger .warning (
553+ f"OutputDataConfig s3_output_path not provided. Using default:\n "
554+ f"{ self .output_data_config .s3_output_path } "
555+ )
556+ if self .output_data_config .compression_type is None :
557+ self .output_data_config .compression_type = "GZIP"
558+ logger .warning (
559+ f"OutputDataConfig compression type not provided. Using default:\n "
560+ f"{ self .output_data_config .compression_type } "
561+ )
525562
526- # TODO: Autodetect which image to use if source_code is provided
527563 if self .training_image :
528564 logger .info (f"Training image URI: { self .training_image } " )
529565
530- def _fetch_bucket_name_and_prefix (self , session : Session ) -> str :
566+ @staticmethod
567+ def _fetch_bucket_name_and_prefix (session : Session ) -> str :
531568 """Helper function to get the bucket name with the corresponding prefix if applicable"""
532569 if session .default_bucket_prefix is not None :
533570 return f"{ session .default_bucket ()} /{ session .default_bucket_prefix } "
@@ -558,16 +595,28 @@ def train(
558595 """
559596 self ._populate_intelligent_defaults ()
560597 current_training_job_name = _get_unique_name (self .base_job_name )
561- input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
562- if input_data_config :
598+ default_artifact_path = f"{ self .base_job_name } /{ current_training_job_name } "
599+ input_data_key_prefix = f"{ default_artifact_path } /input"
600+ if input_data_config and self .input_data_config :
563601 self .input_data_config = input_data_config
602+ # Add missing input data channels to the existing input_data_config
603+ final_input_channel_names = {i .channel_name for i in input_data_config }
604+ for input_data in self .input_data_config :
605+ if input_data .channel_name not in final_input_channel_names :
606+ input_data_config .append (input_data )
607+
608+ self .input_data_config = input_data_config or self .input_data_config or []
564609
565- input_data_config = []
566610 if self .input_data_config :
567611 input_data_config = self ._get_input_data_config (
568612 self .input_data_config , input_data_key_prefix
569613 )
570614
615+ if self .checkpoint_config and not self .checkpoint_config .s3_uri :
616+ self .checkpoint_config .s3_uri = f"s3://{ self ._fetch_bucket_name_and_prefix (self .sagemaker_session )} /{ default_artifact_path } "
617+ if self ._tensorboard_output_config and not self ._tensorboard_output_config .s3_uri :
618+ self ._tensorboard_output_config .s3_uri = f"s3://{ self ._fetch_bucket_name_and_prefix (self .sagemaker_session )} /{ default_artifact_path } "
619+
571620 string_hyper_parameters = {}
572621 if self .hyperparameters :
573622 for hyper_parameter , value in self .hyperparameters .items ():
0 commit comments