Skip to content

Commit 6c750be

Browse files
committed
fix input_data_config
1 parent a50e868 commit 6c750be

File tree

2 files changed: 13 additions and 11 deletions.

2 files changed: 13 additions and 11 deletions.

src/sagemaker/modules/configs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,8 @@ class OutputDataConfig(shapes.OutputDataConfig):
240240
SageMaker uses to encrypt the model artifacts at rest using Amazon S3 server-side
241241
encryption.
242242
compression_type (Optional[str]):
243-
The model output compression type. Select None to output an uncompressed model,
244-
recommended for large model outputs. Defaults to gzip.
243+
The model output compression type. Select `NONE` to output an uncompressed model,
244+
recommended for large model outputs. Defaults to `GZIP`.
245245
"""
246246

247247
s3_output_path: Optional[str] = None

src/sagemaker/modules/train/model_trainer.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -597,14 +597,16 @@ def train(
597597
current_training_job_name = _get_unique_name(self.base_job_name)
598598
input_data_key_prefix = f"{self.base_job_name}/{current_training_job_name}/input"
599599
if input_data_config and self.input_data_config:
600-
self.input_data_config = input_data_config
601-
# Add missing input data channels to the existing input_data_config
602-
final_input_channel_names = {i.channel_name for i in input_data_config}
603-
for input_data in self.input_data_config:
604-
if input_data.channel_name not in final_input_channel_names:
605-
input_data_config.append(input_data)
606-
607-
self.input_data_config = input_data_config or self.input_data_config or []
600+
final_channels = {
601+
input_data.channel_name: input_data for input_data in self.input_data_config
602+
}
603+
# Update with precedence on the input_data_config passed into the train method
604+
final_channels.update(
605+
{input_data.channel_name: input_data for input_data in input_data_config}
606+
)
607+
self.input_data_config = list(final_channels.values())
608+
else:
609+
self.input_data_config = input_data_config or self.input_data_config or []
608610

609611
if self.input_data_config:
610612
input_data_config = self._get_input_data_config(
@@ -699,7 +701,7 @@ def train(
699701
training_job_name=current_training_job_name,
700702
algorithm_specification=algorithm_specification,
701703
hyper_parameters=string_hyper_parameters,
702-
input_data_config=input_data_config,
704+
input_data_config=self.input_data_config,
703705
resource_config=resource_config,
704706
vpc_config=vpc_config,
705707
# Public Instance Attributes

0 commit comments

Comments
 (0)