Skip to content

Commit 8219160

Browse files
author
Eli Davidson
committed
handle invalid sev ver and incompatible sagemaker versions
1 parent e4f4816 commit 8219160

File tree

4 files changed

+140
-5
lines changed

4 files changed

+140
-5
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,7 @@ def _model_id_retrieval_function(
262262
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)
263263

264264
versions_incompatible_with_sagemaker = [
265-
Version(header.version)
265+
header.version
266266
for header in manifest.values() # type: ignore
267267
if header.model_id == model_id
268268
]

src/sagemaker/jumpstart/utils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
from urllib.parse import urlparse
2222
import boto3
2323
from botocore.exceptions import ClientError
24-
from packaging.version import Version
24+
from packaging.version import Version, InvalidVersion
2525
import botocore
2626
from sagemaker_core.shapes import ModelAccessConfig
2727
import sagemaker
@@ -1632,5 +1632,6 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16321632
return neo_bucket
16331633

16341634
def get_latest_version(versions: List[str]) -> Optional[str]:
1635-
return None if not versions else max(versions, key=Version)
1636-
1635+
try: return None if not versions else max(versions, key=Version)
1636+
except InvalidVersion as e:
1637+
return max(versions)

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 132 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,10 @@
2222
from mock.mock import MagicMock
2323
import pytest
2424
from mock import patch
25+
from packaging.version import Version
2526

27+
28+
from sagemaker.jumpstart import utils
2629
from sagemaker.jumpstart.cache import (
2730
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
2831
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
@@ -33,6 +36,7 @@
3336
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
3437
)
3538
from sagemaker.jumpstart.types import (
39+
JumpStartCachedContentValue,
3640
JumpStartModelHeader,
3741
JumpStartModelSpecs,
3842
JumpStartVersionedModelId,
@@ -50,7 +54,7 @@
5054
BASE_PROPRIETARY_SPEC,
5155
BASE_PROPRIETARY_MANIFEST,
5256
)
53-
from sagemaker.jumpstart.utils import get_jumpstart_content_bucket
57+
from sagemaker.jumpstart.utils import get_formatted_manifest, get_jumpstart_content_bucket
5458

5559

5660
@patch.object(JumpStartModelsCache, "_retrieval_function", patched_retrieval_function)
@@ -1119,3 +1123,130 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11191123
),
11201124
]
11211125
)
1126+
1127+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1128+
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1129+
retrieval_function: Mock
1130+
):
1131+
sm_version = Version(utils.get_sagemaker_version())
1132+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1133+
print(str(new_sm_version))
1134+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1135+
manifest = [
1136+
{
1137+
"model_id": "test-model",
1138+
"version": version,
1139+
"min_version": "2.49.0",
1140+
"spec_key": "spec_key"
1141+
}
1142+
for version in versions
1143+
]
1144+
1145+
manifest.append(
1146+
{
1147+
"model_id": "test-model",
1148+
"version": "3.0.0",
1149+
"min_version": str(new_sm_version),
1150+
"spec_key": "spec_key"
1151+
}
1152+
)
1153+
1154+
manifest_dict = {}
1155+
for header in manifest:
1156+
header_obj = JumpStartModelHeader(header)
1157+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1158+
header_obj
1159+
)
1160+
retrieval_function.return_value = JumpStartCachedContentValue(
1161+
formatted_content=manifest_dict
1162+
)
1163+
key = JumpStartVersionedModelId("test-model", "*")
1164+
1165+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1166+
result = cache._get_open_weight_manifest_key_from_model_id( key = key, value = None )
1167+
1168+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1169+
1170+
assert result == assert_key
1171+
1172+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1173+
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1174+
retrieval_function: Mock
1175+
):
1176+
sm_version = Version(utils.get_sagemaker_version())
1177+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1178+
print(str(new_sm_version))
1179+
versions = ["1.0.0", "2.9.1", "2.16.0"]
1180+
manifest = [
1181+
{
1182+
"model_id": "test-model",
1183+
"version": version,
1184+
"min_version": "2.49.0",
1185+
"spec_key": "spec_key"
1186+
}
1187+
for version in versions
1188+
]
1189+
1190+
manifest.append(
1191+
{
1192+
"model_id": "test-model",
1193+
"version": "3.0.0",
1194+
"min_version": str(new_sm_version),
1195+
"spec_key": "spec_key"
1196+
}
1197+
)
1198+
1199+
manifest_dict = {}
1200+
for header in manifest:
1201+
header_obj = JumpStartModelHeader(header)
1202+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1203+
header_obj
1204+
)
1205+
retrieval_function.return_value = JumpStartCachedContentValue(
1206+
formatted_content=manifest_dict
1207+
)
1208+
key = JumpStartVersionedModelId("test-model", "*")
1209+
1210+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1211+
result = cache._get_proprietary_manifest_key_from_model_id( key = key, value = None )
1212+
1213+
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
1214+
1215+
assert result == assert_key
1216+
1217+
1218+
@patch.object(JumpStartModelsCache, "_retrieval_function")
1219+
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1220+
retrieval_function: Mock
1221+
):
1222+
sm_version = Version(utils.get_sagemaker_version())
1223+
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
1224+
print(str(new_sm_version))
1225+
versions = ["abc", "2.9.1", "2.16.0"]
1226+
manifest = [
1227+
{
1228+
"model_id": "test-model",
1229+
"version": version,
1230+
"min_version": "2.49.0",
1231+
"spec_key": "spec_key"
1232+
}
1233+
for version in versions
1234+
]
1235+
1236+
manifest_dict = {}
1237+
for header in manifest:
1238+
header_obj = JumpStartModelHeader(header)
1239+
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
1240+
header_obj
1241+
)
1242+
retrieval_function.return_value = JumpStartCachedContentValue(
1243+
formatted_content=manifest_dict
1244+
)
1245+
key = JumpStartVersionedModelId("test-model", "*")
1246+
1247+
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1248+
result = cache._get_open_weight_manifest_key_from_model_id( key = key, value = None )
1249+
1250+
assert_key = JumpStartVersionedModelId("test-model", "abc")
1251+
1252+
assert result == assert_key

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2152,6 +2152,9 @@ def test_get_latest_version_empty_list_is_none():
21522152
def test_get_latest_version_none_is_none():
21532153
assert utils.get_latest_version(None) == None
21542154

2155+
def test_get_latest_version_with_invalid_sem_ver():
2156+
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc"
2157+
21552158
@pytest.mark.parametrize(
21562159
"data, expected",
21572160
[(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],

0 commit comments

Comments
 (0)