Skip to content

Commit 89c856b

Browse files
committed
figured out bug in unit test where version is referring to pt, not smp
version
1 parent 0e0dad2 commit 89c856b

File tree

2 files changed

+3
-7
lines changed

2 files changed

+3
-7
lines changed

src/sagemaker/image_uri_config/pytorch-smp.json

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
"2.1": "2.1.2",
99
"2.2": "2.3.1",
1010
"2.2.0": "2.3.1",
11+
"2.3": "2.4.0",
1112
"2.3.1": "2.4.0"
1213
},
1314
"versions": {

tests/unit/sagemaker/image_uris/test_smp_v2.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,16 +28,13 @@ def test_smp_v2(load_config):
2828
"smdistributed": {"modelparallel": {"enabled": True}},
2929
}
3030

31-
print("load_config", load_config)
32-
print("VERSIONS", VERSIONS)
33-
3431
for processor in PROCESSORS:
3532
for version in VERSIONS:
3633
ACCOUNTS = load_config["training"]["versions"][version]["registries"]
3734
PY_VERSIONS = load_config["training"]["versions"][version]["py_versions"]
3835
for py_version in PY_VERSIONS:
39-
# py311 is only for smp version 2.4.0
40-
if py_version == "py311" and "2.4" not in version:
36+
# py311 is only for PT 2.3.1 and SMP 2.4.0
37+
if py_version == "py311" and "2.3" not in version:
4138
continue
4239

4340
for region in ACCOUNTS.keys():
@@ -46,8 +43,6 @@ def test_smp_v2(load_config):
4643
if "2.1" in version or "2.2" in version or "2.3" in version:
4744
cuda_vers = "cu121"
4845

49-
print("HERE", version, py_version)
50-
5146
uri = image_uris.get_training_image_uri(
5247
region,
5348
framework="pytorch",

0 commit comments

Comments
 (0)