@@ -597,14 +597,16 @@ def train(
597597 current_training_job_name = _get_unique_name (self .base_job_name )
598598 input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
599599 if input_data_config and self .input_data_config :
600- self .input_data_config = input_data_config
601- # Add missing input data channels to the existing input_data_config
602- final_input_channel_names = {i .channel_name for i in input_data_config }
603- for input_data in self .input_data_config :
604- if input_data .channel_name not in final_input_channel_names :
605- input_data_config .append (input_data )
606-
607- self .input_data_config = input_data_config or self .input_data_config or []
600+ final_channels = {
601+ input_data .channel_name : input_data for input_data in self .input_data_config
602+ }
603+ # Update with precedence on the input_data_config passed into the train method
604+ final_channels .update (
605+ {input_data .channel_name : input_data for input_data in input_data_config }
606+ )
607+ self .input_data_config = list (final_channels .values ())
608+ else :
609+ self .input_data_config = input_data_config or self .input_data_config or []
608610
609611 if self .input_data_config :
610612 input_data_config = self ._get_input_data_config (
@@ -699,7 +701,7 @@ def train(
699701 training_job_name = current_training_job_name ,
700702 algorithm_specification = algorithm_specification ,
701703 hyper_parameters = string_hyper_parameters ,
702- input_data_config = input_data_config ,
704+ input_data_config = self . input_data_config ,
703705 resource_config = resource_config ,
704706 vpc_config = vpc_config ,
705707 # Public Instance Attributes
0 commit comments