Skip to content

Commit e359e21

Browse files
committed
Add override FW version
1 parent 2a2dfd0 commit e359e21

File tree

2 files changed

+3
-1
lines changed

2 files changed

+3
-1
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3547,6 +3547,7 @@ def __init__(
35473547
self.checkpoint_s3_uri = checkpoint_s3_uri
35483548
self.checkpoint_local_path = checkpoint_local_path
35493549
self.enable_sagemaker_metrics = enable_sagemaker_metrics
3550+
self.override_fw_version = None
35503551

35513552
def _prepare_for_training(self, job_name=None):
35523553
"""Set hyperparameters needed for training. This method will also validate ``source_dir``.
@@ -3731,7 +3732,7 @@ def training_image_uri(self, region=None):
37313732
return image_uris.get_training_image_uri(
37323733
region=region or self.sagemaker_session.boto_region_name,
37333734
framework=self._framework_name,
3734-
framework_version=self.framework_version, # pylint: disable=no-member
3735+
framework_version=self.override_fw_version or self.framework_version, # pylint: disable=no-member
37353736
py_version=self.py_version, # pylint: disable=no-member
37363737
image_uri=self.image_uri,
37373738
distribution=getattr(self, "distribution", None),

src/sagemaker/tensorflow/estimator.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,7 @@ def __init__(
177177
fw.python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
178178
)
179179
self.framework_version = framework_version
180+
self.override_fw_version = "2.16"
180181
# TF training and inference versions do not have a one-to-one connection. This mismatch
181182
# is accommodated by the underlying dictionary. The key of dictionary relates to the inference
182183
# version and the value relates to training version.

0 commit comments

Comments
 (0)