@@ -580,7 +580,7 @@ def train(
580580 """Train a model using AWS SageMaker.
581581
582582 Args:
583- input_data_config (Optional[Union[ List[Channel], Dict[str, DataSourceType ]]]):
583+ input_data_config (Optional[List[Union[ Channel, InputData ]]]):
584584 The input data config for the training job.
585585 Takes a list of Channel objects or InputData objects (a channel name paired with a DataSourceType).
586586 DataSourceType can be an S3 URI string, local file path string,
@@ -596,11 +596,23 @@ def train(
596596 current_training_job_name = _get_unique_name (self .base_job_name )
597597 input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
598598
599- self .input_data_config = input_data_config or self .input_data_config or []
599+ final_input_data_config = self .input_data_config .copy () if self .input_data_config else []
600+
601+ if input_data_config :
602+ # merge the inputs with method parameter taking precedence
603+ existing_channels = {input .channel_name : input for input in final_input_data_config }
604+ new_channels = []
605+ for new_input in input_data_config :
606+ if new_input .channel_name in existing_channels :
607+ existing_channels [new_input .channel_name ] = new_input
608+ else :
609+ new_channels .append (new_input )
610+
611+ final_input_data_config = list (existing_channels .values ()) + new_channels
600612
601- if self . input_data_config :
602- self . input_data_config = self ._get_input_data_config (
603- self . input_data_config , input_data_key_prefix
613+ if final_input_data_config :
614+ final_input_data_config = self ._get_input_data_config (
615+ final_input_data_config , input_data_key_prefix
604616 )
605617
606618 if self .checkpoint_config and not self .checkpoint_config .s3_uri :
@@ -643,7 +655,7 @@ def train(
643655 data_source = self .source_code .source_dir ,
644656 key_prefix = input_data_key_prefix ,
645657 )
646- self . input_data_config .append (source_code_channel )
658+ final_input_data_config .append (source_code_channel )
647659
648660 self ._prepare_train_script (
649661 tmp_dir = tmp_dir ,
@@ -664,7 +676,7 @@ def train(
664676 data_source = tmp_dir .name ,
665677 key_prefix = input_data_key_prefix ,
666678 )
667- self . input_data_config .append (sm_drivers_channel )
679+ final_input_data_config .append (sm_drivers_channel )
668680
669681 # If source_code is provided, we will always use
670682 # the default container entrypoint and arguments
@@ -691,7 +703,7 @@ def train(
691703 training_job_name = current_training_job_name ,
692704 algorithm_specification = algorithm_specification ,
693705 hyper_parameters = string_hyper_parameters ,
694- input_data_config = self . input_data_config ,
706+ input_data_config = final_input_data_config ,
695707 resource_config = resource_config ,
696708 vpc_config = vpc_config ,
697709 # Public Instance Attributes
@@ -736,7 +748,7 @@ def train(
736748 sagemaker_session = self .sagemaker_session ,
737749 container_entrypoint = algorithm_specification .container_entrypoint ,
738750 container_arguments = algorithm_specification .container_arguments ,
739- input_data_config = self . input_data_config ,
751+ input_data_config = final_input_data_config ,
740752 hyper_parameters = string_hyper_parameters ,
741753 environment = self .environment ,
742754 )
0 commit comments