Skip to content

Commit 6b49e42

Browse files
author
Eli Davidson
committed
handle invalid sev ver and incompatible sagemaker versions
1 parent d064b89 commit 6b49e42

File tree

4 files changed

+140
-4
lines changed

4 files changed

+140
-4
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _model_id_retrieval_function(
262262
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
263263

264264
versions_incompatible_with_sagemaker = [
265-
Version(header.version)
265+
header.version
266266
for header in manifest.values() # type: ignore
267267
if header.model_id == model_id
268268
]

src/sagemaker/jumpstart/utils.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from urllib.parse import urlparse
2121
import boto3
2222
from botocore.exceptions import ClientError
23-
from packaging.version import Version
23+
from packaging.version import Version, InvalidVersion
2424
import botocore
2525
import sagemaker
2626
from sagemaker.config.config_schema import (
@@ -1527,4 +1527,6 @@ def wrapped_f(*args, **kwargs):
15271527
return wrapper_cache(_func)
15281528

15291529
def get_latest_version(versions: List[str]) -> Optional[str]:
1530-
return None if not versions else max(versions, key=Version)
1530+
try: return None if not versions else max(versions, key=Version)
1531+
except InvalidVersion as e:
1532+
return max(versions)

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from mock.mock import MagicMock
2323
import pytest
2424
from mock import patch
25+
from packaging.version import Version
2526

27+
28+
from sagemaker.jumpstart import utils
2629
from sagemaker.jumpstart.cache import (
2730
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2831
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
@@ -33,6 +36,7 @@
3336
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
3437
)
3538
from sagemaker.jumpstart.types import (
39+
JumpStartCachedContentValue,
3640
JumpStartModelHeader,
3741
JumpStartModelSpecs,
3842
JumpStartVersionedModelId,
@@ -50,7 +54,7 @@
5054
BASE_PROPRIETARY_SPEC,
5155
BASE_PROPRIETARY_MANIFEST,
5256
)
53-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
57+
from sagemaker.jumpstart.utils import get_formatted_manifest, get_jumpstart_content_bucket
5458

5559

5660
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@@ -1119,3 +1123,130 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11191123
),
11201124
]
11211125
)
1126+
1127+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1128+
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1129+
retrieval_function: Mock
1130+
):
1131+
sm_version = Version(utils.get_sagemaker_version())
1132+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1133+
print(str(new_sm_version))
1134+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1135+
manifest = [
1136+
{
1137+
"model_id": "test-model",
1138+
"version": version,
1139+
"min_version": "2.49.0",
1140+
"spec_key": "spec_key"
1141+
}
1142+
for version in versions
1143+
]
1144+
1145+
manifest.append(
1146+
{
1147+
"model_id": "test-model",
1148+
"version": "3.0.0",
1149+
"min_version": str(new_sm_version),
1150+
"spec_key": "spec_key"
1151+
}
1152+
)
1153+
1154+
manifest_dict = {}
1155+
for header in manifest:
1156+
header_obj = JumpStartModelHeader(header)
1157+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1158+
header_obj
1159+
)
1160+
retrieval_function.return_value = JumpStartCachedContentValue(
1161+
formatted_content=manifest_dict
1162+
)
1163+
key = JumpStartVersionedModelId("test-model", "*")
1164+
1165+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1166+
result = cache._get_open_weight_manifest_key_from_model_id( key = key, value = None )
1167+
1168+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1169+
1170+
assert result == assert_key
1171+
1172+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1173+
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1174+
retrieval_function: Mock
1175+
):
1176+
sm_version = Version(utils.get_sagemaker_version())
1177+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1178+
print(str(new_sm_version))
1179+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1180+
manifest = [
1181+
{
1182+
"model_id": "test-model",
1183+
"version": version,
1184+
"min_version": "2.49.0",
1185+
"spec_key": "spec_key"
1186+
}
1187+
for version in versions
1188+
]
1189+
1190+
manifest.append(
1191+
{
1192+
"model_id": "test-model",
1193+
"version": "3.0.0",
1194+
"min_version": str(new_sm_version),
1195+
"spec_key": "spec_key"
1196+
}
1197+
)
1198+
1199+
manifest_dict = {}
1200+
for header in manifest:
1201+
header_obj = JumpStartModelHeader(header)
1202+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1203+
header_obj
1204+
)
1205+
retrieval_function.return_value = JumpStartCachedContentValue(
1206+
formatted_content=manifest_dict
1207+
)
1208+
key = JumpStartVersionedModelId("test-model", "*")
1209+
1210+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1211+
result = cache._get_proprietary_manifest_key_from_model_id( key = key, value = None )
1212+
1213+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1214+
1215+
assert result == assert_key
1216+
1217+
1218+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1219+
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1220+
retrieval_function: Mock
1221+
):
1222+
sm_version = Version(utils.get_sagemaker_version())
1223+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1224+
print(str(new_sm_version))
1225+
versions = ["abc", "2.9.1", "2.16.0"]
1226+
manifest = [
1227+
{
1228+
"model_id": "test-model",
1229+
"version": version,
1230+
"min_version": "2.49.0",
1231+
"spec_key": "spec_key"
1232+
}
1233+
for version in versions
1234+
]
1235+
1236+
manifest_dict = {}
1237+
for header in manifest:
1238+
header_obj = JumpStartModelHeader(header)
1239+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1240+
header_obj
1241+
)
1242+
retrieval_function.return_value = JumpStartCachedContentValue(
1243+
formatted_content=manifest_dict
1244+
)
1245+
key = JumpStartVersionedModelId("test-model", "*")
1246+
1247+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1248+
result = cache._get_open_weight_manifest_key_from_model_id( key = key, value = None )
1249+
1250+
assert_key = JumpStartVersionedModelId("test-model", "abc")
1251+
1252+
assert result == assert_key

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,6 +2151,9 @@ def test_get_latest_version_empty_list_is_none():
21512151
def test_get_latest_version_none_is_none():
21522152
assert utils.get_latest_version(None) == None
21532153

2154+
def test_get_latest_version_with_invalid_sem_ver():
2155+
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc"
2156+
21542157
@pytest.mark.parametrize(
21552158
"data, expected",
21562159
[(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],

0 commit comments

Comments
 (0)