Skip to content

Commit d1a3475

Browse files
committed
Revert unit test changes
1 parent 48c2507 commit d1a3475

File tree

5 files changed

+10
-29
lines changed

5 files changed

+10
-29
lines changed

src/sagemaker/estimator.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3727,12 +3727,11 @@ def training_image_uri(self, region=None):
37273727
Returns:
37283728
str: The URI of the Docker image.
37293729
"""
3730+
37303731
return image_uris.get_training_image_uri(
37313732
region=region or self.sagemaker_session.boto_region_name,
37323733
framework=self._framework_name,
3733-
framework_version=getattr(
3734-
self, "override_fw_version", self.framework_version # pylint: disable=no-member
3735-
),
3734+
framework_version=self.framework_version, # pylint: disable=no-member
37363735
py_version=self.py_version, # pylint: disable=no-member
37373736
image_uri=self.image_uri,
37383737
distribution=getattr(self, "distribution", None),

src/sagemaker/tensorflow/estimator.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -177,11 +177,6 @@ def __init__(
177177
fw.python_deprecation_warning(self._framework_name, defaults.LATEST_PY2_VERSION)
178178
)
179179
self.framework_version = framework_version
180-
# if self.framework_version and "2.16" in self.framework_version:
181-
# self.override_fw_version = "2.16"
182-
# # TODO: Revert
183-
# print(f"ABCD123 setting self.override_fw_version to {self.override_fw_version}")
184-
185180
self.py_version = py_version
186181
self.instance_type = instance_type
187182

tests/conftest.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,7 +557,10 @@ def tf_full_version(tensorflow_training_latest_version, tensorflow_inference_lat
557557
Otherwise, this would simply be a single latest version.
558558
"""
559559
if Version(tensorflow_training_latest_version) in SpecifierSet(">=2.16"):
560-
return f"{Version(tensorflow_training_latest_version).major}.{Version(tensorflow_training_latest_version).minor}"
560+
return (
561+
f"{Version(tensorflow_training_latest_version).major}"
562+
f".{Version(tensorflow_training_latest_version).minor}"
563+
)
561564
return str(
562565
min(
563566
Version(tensorflow_training_latest_version),

tests/unit/sagemaker/tensorflow/test_estimator.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
from mock import patch, Mock, MagicMock
2020
from packaging import version
21-
from packaging.version import Version
22-
from packaging.specifiers import SpecifierSet
2321
import pytest
2422

2523
from sagemaker.estimator import _TrainingJob
@@ -556,14 +554,9 @@ def test_fit_mwms(
556554

557555
expected_train_args = _create_train_job(framework_version, py_version=py_version)
558556
expected_train_args["input_config"][0]["DataSource"]["S3DataSource"]["S3Uri"] = inputs
559-
if Version(framework_version) in SpecifierSet("==2.16.*"):
560-
expected_train_args["image_uri"] = (
561-
f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:2.16-cpu-{py_version}"
562-
)
563-
else:
564-
expected_train_args["image_uri"] = (
565-
f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}"
566-
)
557+
expected_train_args["image_uri"] = (
558+
f"763104351884.dkr.ecr.{REGION}.amazonaws.com/tensorflow-training:{framework_version}-cpu-{py_version}"
559+
)
567560
expected_train_args["job_name"] = f"tensorflow-training-{TIMESTAMP}"
568561
expected_train_args["hyperparameters"][TensorFlow.LAUNCH_MWMS_ENV_NAME] = json.dumps(True)
569562
expected_train_args["hyperparameters"]["sagemaker_job_name"] = json.dumps(

tests/unit/test_processing.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818
import pytest
1919
from mock import Mock, patch, MagicMock
2020
from packaging import version
21-
from packaging.version import Version
22-
from packaging.specifiers import SpecifierSet
2321

2422
from sagemaker import LocalSession
2523
from sagemaker.dataset_definition.inputs import (
@@ -509,14 +507,7 @@ def test_tensorflow_processor_with_required_parameters(
509507
else:
510508
tensorflow_image_uri = (
511509
"763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:{}-cpu-{}"
512-
).format(
513-
(
514-
"2.16"
515-
if Version(tensorflow_training_version) in SpecifierSet("==2.16.*")
516-
else tensorflow_training_version
517-
),
518-
tensorflow_training_py_version,
519-
)
510+
).format(tensorflow_training_version, tensorflow_training_py_version)
520511

521512
expected_args["app_specification"]["ImageUri"] = tensorflow_image_uri
522513

0 commit comments

Comments
 (0)