@@ -597,14 +597,16 @@ def train(
597
597
current_training_job_name = _get_unique_name (self .base_job_name )
598
598
input_data_key_prefix = f"{ self .base_job_name } /{ current_training_job_name } /input"
599
599
if input_data_config and self .input_data_config :
600
- self .input_data_config = input_data_config
601
- # Add missing input data channels to the existing input_data_config
602
- final_input_channel_names = {i .channel_name for i in input_data_config }
603
- for input_data in self .input_data_config :
604
- if input_data .channel_name not in final_input_channel_names :
605
- input_data_config .append (input_data )
606
-
607
- self .input_data_config = input_data_config or self .input_data_config or []
600
+ final_channels = {
601
+ input_data .channel_name : input_data for input_data in self .input_data_config
602
+ }
603
+ # Update with precedence on the input_data_config passed into the train method
604
+ final_channels .update (
605
+ {input_data .channel_name : input_data for input_data in input_data_config }
606
+ )
607
+ self .input_data_config = list (final_channels .values ())
608
+ else :
609
+ self .input_data_config = input_data_config or self .input_data_config or []
608
610
609
611
if self .input_data_config :
610
612
input_data_config = self ._get_input_data_config (
@@ -699,7 +701,7 @@ def train(
699
701
training_job_name = current_training_job_name ,
700
702
algorithm_specification = algorithm_specification ,
701
703
hyper_parameters = string_hyper_parameters ,
702
- input_data_config = input_data_config ,
704
+ input_data_config = self . input_data_config ,
703
705
resource_config = resource_config ,
704
706
vpc_config = vpc_config ,
705
707
# Public Instance Attributes
0 commit comments