Skip to content

Commit a2b424d

Browse files
committed
Reformatting
1 parent 3ad5961 commit a2b424d

File tree

5 files changed

+19
-8
lines changed

5 files changed

+19
-8
lines changed

src/sagemaker/estimator.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3727,13 +3727,17 @@ def training_image_uri(self, region=None):
37273727
Returns:
37283728
str: The URI of the Docker image.
37293729
"""
3730-
## TODO: Revert
3731-
print(f"""ABCD1234 self.override_fw_version: {getattr(self, "override_fw_version", None)} and self.framework_version: {self.framework_version}""")
3730+
# TODO: Revert
3731+
print(
3732+
f"""ABCD123 {getattr(self, "override_fw_version", None)}, self.framework_version:{self.framework_version}"""
3733+
)
37323734

37333735
return image_uris.get_training_image_uri(
37343736
region=region or self.sagemaker_session.boto_region_name,
37353737
framework=self._framework_name,
3736-
framework_version= getattr(self, "override_fw_version", self.framework_version), # pylint: disable=no-member
3738+
framework_version=getattr(
3739+
self, "override_fw_version", self.framework_version
3740+
), # pylint: disable=no-member
37373741
py_version=self.py_version, # pylint: disable=no-member
37383742
image_uri=self.image_uri,
37393743
distribution=getattr(self, "distribution", None),

src/sagemaker/tensorflow/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,8 +178,8 @@ def __init__(
178178
)
179179
self.framework_version = framework_version
180180
if self.framework_version and "2.16" in self.framework_version:
181-
self.override_fw_version="2.16"
182-
## TODO: Revert
181+
self.override_fw_version = "2.16"
182+
# TODO: Revert
183183
print(f"ABCD123 setting self.override_fw_version to {self.override_fw_version}")
184184

185185
# TF training and inference versions do not have a one-to-one connection. This mismatch

tests/integ/sagemaker/workflow/test_training_steps.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,7 @@ def test_tensorflow_training_step_with_parameterized_code_input(
243243
"This test is failing in TensorFlow 2.16 beacuse of an upstream bug: "
244244
"https://github.com/tensorflow/io/issues/2039"
245245
)
246-
246+
247247
base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
248248
entry_point1 = "mnist_v2.py"
249249
entry_point2 = "mnist_dummy.py"

tests/integ/test_transformer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -561,7 +561,7 @@ def test_transform_tf_kms_network_isolation(
561561
"This test is failing in TensorFlow 2.16 beacuse of an upstream bug: "
562562
"https://github.com/tensorflow/io/issues/2039"
563563
)
564-
564+
565565
data_path = os.path.join(DATA_DIR, "tensorflow_mnist")
566566

567567
tf = TensorFlow(

tests/unit/test_processing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,14 @@ def test_tensorflow_processor_with_required_parameters(
509509
else:
510510
tensorflow_image_uri = (
511511
"763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-training:{}-cpu-{}"
512-
).format("2.16" if Version(tensorflow_training_version) in SpecifierSet("==2.16.*") else tensorflow_training_version, tensorflow_training_py_version)
512+
).format(
513+
(
514+
"2.16"
515+
if Version(tensorflow_training_version) in SpecifierSet("==2.16.*")
516+
else tensorflow_training_version
517+
),
518+
tensorflow_training_py_version,
519+
)
513520

514521
expected_args["app_specification"]["ImageUri"] = tensorflow_image_uri
515522

0 commit comments

Comments
 (0)