Skip to content
Merged
6 changes: 2 additions & 4 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _model_id_retrieval_function(
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)

versions_incompatible_with_sagemaker = [
Version(header.version)
header.version
for header in manifest.values() # type: ignore
if header.model_id == model_id
]
Expand Down Expand Up @@ -540,9 +540,7 @@ def _select_version(
"""

if version_str == "*":
if len(available_versions) == 0:
return None
return str(max(available_versions))
return utils.get_latest_version(available_versions)

if model_type == JumpStartModelType.PROPRIETARY:
if "*" in version_str:
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from urllib.parse import urlparse
import boto3
from botocore.exceptions import ClientError
from packaging.version import Version
from packaging.version import Version, InvalidVersion
import botocore
from sagemaker_core.shapes import ModelAccessConfig
import sagemaker
Expand Down Expand Up @@ -1630,3 +1630,11 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
return get_jumpstart_gated_content_bucket(region=region)
return get_jumpstart_content_bucket(region=region)
return neo_bucket


def get_latest_version(versions: List[str]) -> Optional[str]:
"""Returns the latest version using sem-ver when possible."""
try:
return None if not versions else max(versions, key=Version)
except InvalidVersion:
return max(versions)
125 changes: 125 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from mock.mock import MagicMock
import pytest
from mock import patch
from packaging.version import Version


from sagemaker.jumpstart import utils
from sagemaker.jumpstart.cache import (
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
Expand All @@ -33,6 +36,7 @@
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
)
from sagemaker.jumpstart.types import (
JumpStartCachedContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartVersionedModelId,
Expand Down Expand Up @@ -1119,3 +1123,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
),
]
)


@patch.object(JumpStartModelsCache, "_retrieval_function")
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
retrieval_function: Mock,
):
sm_version = Version(utils.get_sagemaker_version())
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
print(str(new_sm_version))
versions = ["1.0.0", "2.9.1", "2.16.0"]
manifest = [
{
"model_id": "test-model",
"version": version,
"min_version": "2.49.0",
"spec_key": "spec_key",
}
for version in versions
]

manifest.append(
{
"model_id": "test-model",
"version": "3.0.0",
"min_version": str(new_sm_version),
"spec_key": "spec_key",
}
)

manifest_dict = {}
for header in manifest:
header_obj = JumpStartModelHeader(header)
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
header_obj
)
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
key = JumpStartVersionedModelId("test-model", "*")

cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)

assert_key = JumpStartVersionedModelId("test-model", "2.16.0")

assert result == assert_key


@patch.object(JumpStartModelsCache, "_retrieval_function")
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
retrieval_function: Mock,
):
sm_version = Version(utils.get_sagemaker_version())
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
print(str(new_sm_version))
versions = ["1.0.0", "2.9.1", "2.16.0"]
manifest = [
{
"model_id": "test-model",
"version": version,
"min_version": "2.49.0",
"spec_key": "spec_key",
}
for version in versions
]

manifest.append(
{
"model_id": "test-model",
"version": "3.0.0",
"min_version": str(new_sm_version),
"spec_key": "spec_key",
}
)

manifest_dict = {}
for header in manifest:
header_obj = JumpStartModelHeader(header)
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
header_obj
)
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
key = JumpStartVersionedModelId("test-model", "*")

cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None)

assert_key = JumpStartVersionedModelId("test-model", "2.16.0")

assert result == assert_key


@patch.object(JumpStartModelsCache, "_retrieval_function")
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock):
sm_version = Version(utils.get_sagemaker_version())
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
print(str(new_sm_version))
versions = ["abc", "2.9.1", "2.16.0"]
manifest = [
{
"model_id": "test-model",
"version": version,
"min_version": "2.49.0",
"spec_key": "spec_key",
}
for version in versions
]

manifest_dict = {}
for header in manifest:
header_obj = JumpStartModelHeader(header)
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
header_obj
)
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
key = JumpStartVersionedModelId("test-model", "*")

cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)

assert_key = JumpStartVersionedModelId("test-model", "abc")

assert result == assert_key
16 changes: 16 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,22 @@ def test_has_instance_rate_stat(stats, expected):
assert utils.has_instance_rate_stat(stats) is expected


def test_get_latest_version():
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0"


def test_get_latest_version_empty_list_is_none():
assert utils.get_latest_version([]) is None


def test_get_latest_version_none_is_none():
assert utils.get_latest_version(None) is None


def test_get_latest_version_with_invalid_sem_ver():
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc"


@pytest.mark.parametrize(
"data, expected",
[(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],
Expand Down