Skip to content

fix: get latest version was returning the alphabetically sorted max instead of the highest version #5014

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Feb 10, 2025
Merged
6 changes: 2 additions & 4 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def _model_id_retrieval_function(
return JumpStartVersionedModelId(model_id, sm_compatible_model_version)

versions_incompatible_with_sagemaker = [
Version(header.version)
header.version
for header in manifest.values() # type: ignore
if header.model_id == model_id
]
Expand Down Expand Up @@ -540,9 +540,7 @@ def _select_version(
"""

if version_str == "*":
if len(available_versions) == 0:
return None
return str(max(available_versions))
return utils.get_latest_version(available_versions)

if model_type == JumpStartModelType.PROPRIETARY:
if "*" in version_str:
Expand Down
10 changes: 9 additions & 1 deletion src/sagemaker/jumpstart/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from urllib.parse import urlparse
import boto3
from botocore.exceptions import ClientError
from packaging.version import Version
from packaging.version import Version, InvalidVersion
import botocore
from sagemaker_core.shapes import ModelAccessConfig
import sagemaker
Expand Down Expand Up @@ -1630,3 +1630,11 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
return get_jumpstart_gated_content_bucket(region=region)
return get_jumpstart_content_bucket(region=region)
return neo_bucket


def get_latest_version(versions: List[str]) -> Optional[str]:
    """Pick the newest entry from a collection of version strings.

    Entries are ranked by parsing each one with ``packaging.version.Version``.
    If any entry cannot be parsed, the plain string maximum of the collection
    is returned instead.

    Args:
        versions: Candidate version strings. An empty or ``None`` value is
            accepted.

    Returns:
        The highest-ranked version string, or ``None`` when there are no
        candidates.
    """
    if not versions:
        return None

    try:
        newest = max(versions, key=Version)
    except InvalidVersion:
        # At least one entry is not a parseable version, so fall back to
        # ordinary string comparison across the whole collection.
        newest = max(versions)
    return newest
125 changes: 125 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@
from mock.mock import MagicMock
import pytest
from mock import patch
from packaging.version import Version


from sagemaker.jumpstart import utils
from sagemaker.jumpstart.cache import (
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY,
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY,
Expand All @@ -33,6 +36,7 @@
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE,
)
from sagemaker.jumpstart.types import (
JumpStartCachedContentValue,
JumpStartModelHeader,
JumpStartModelSpecs,
JumpStartVersionedModelId,
Expand Down Expand Up @@ -1119,3 +1123,124 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
),
]
)


@patch.object(JumpStartModelsCache, "_retrieval_function")
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
    retrieval_function: Mock,
):
    """A wildcard lookup of an open-weight model resolves to version 2.16.0.

    The manifest holds versions 1.0.0, 2.9.1 and 2.16.0 with a min_version of
    2.49.0, plus a 3.0.0 entry whose min_version is one major release above the
    installed SDK. The expected key is 2.16.0, which is greater than 2.9.1
    only when versions are compared numerically rather than as strings.
    """
    sm_version = Version(utils.get_sagemaker_version())
    # NOTE(review): a min_version one major release ahead of the installed SDK
    # is presumably what keeps the 3.0.0 entry from being selected.
    new_sm_version = Version(str(sm_version.major + 1) + ".0.0")

    versions = ["1.0.0", "2.9.1", "2.16.0"]
    manifest = [
        {
            "model_id": "test-model",
            "version": version,
            "min_version": "2.49.0",
            "spec_key": "spec_key",
        }
        for version in versions
    ]
    manifest.append(
        {
            "model_id": "test-model",
            "version": "3.0.0",
            "min_version": str(new_sm_version),
            "spec_key": "spec_key",
        }
    )

    # Key every header by (model_id, version), the shape the cache content uses here.
    manifest_dict = {}
    for header in manifest:
        header_obj = JumpStartModelHeader(header)
        manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
            header_obj
        )
    retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
    key = JumpStartVersionedModelId("test-model", "*")

    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
    result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)

    assert_key = JumpStartVersionedModelId("test-model", "2.16.0")

    assert result == assert_key


@patch.object(JumpStartModelsCache, "_retrieval_function")
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
    retrieval_function: Mock,
):
    """A wildcard lookup of a proprietary model resolves to version 2.16.0.

    The manifest holds versions 1.0.0, 2.9.1 and 2.16.0 with a min_version of
    2.49.0, plus a 3.0.0 entry whose min_version is one major release above the
    installed SDK. The expected key is 2.16.0, which is greater than 2.9.1
    only when versions are compared numerically rather than as strings.
    """
    sm_version = Version(utils.get_sagemaker_version())
    # NOTE(review): a min_version one major release ahead of the installed SDK
    # is presumably what keeps the 3.0.0 entry from being selected.
    new_sm_version = Version(str(sm_version.major + 1) + ".0.0")

    versions = ["1.0.0", "2.9.1", "2.16.0"]
    manifest = [
        {
            "model_id": "test-model",
            "version": version,
            "min_version": "2.49.0",
            "spec_key": "spec_key",
        }
        for version in versions
    ]
    manifest.append(
        {
            "model_id": "test-model",
            "version": "3.0.0",
            "min_version": str(new_sm_version),
            "spec_key": "spec_key",
        }
    )

    # Key every header by (model_id, version), the shape the cache content uses here.
    manifest_dict = {}
    for header in manifest:
        header_obj = JumpStartModelHeader(header)
        manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
            header_obj
        )
    retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
    key = JumpStartVersionedModelId("test-model", "*")

    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
    result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None)

    assert_key = JumpStartVersionedModelId("test-model", "2.16.0")

    assert result == assert_key


@patch.object(JumpStartModelsCache, "_retrieval_function")
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock):
    """A wildcard lookup resolves to "abc" when one manifest version is unparseable.

    The manifest mixes the non-version string "abc" with 2.9.1 and 2.16.0.
    The expected key is "abc", which is the largest of the three under plain
    string comparison.
    """
    versions = ["abc", "2.9.1", "2.16.0"]
    manifest = [
        {
            "model_id": "test-model",
            "version": version,
            "min_version": "2.49.0",
            "spec_key": "spec_key",
        }
        for version in versions
    ]

    # Key every header by (model_id, version), the shape the cache content uses here.
    manifest_dict = {}
    for header in manifest:
        header_obj = JumpStartModelHeader(header)
        manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
            header_obj
        )
    retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
    key = JumpStartVersionedModelId("test-model", "*")

    cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
    result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)

    assert_key = JumpStartVersionedModelId("test-model", "abc")

    assert result == assert_key
16 changes: 16 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,6 +2144,22 @@ def test_has_instance_rate_stat(stats, expected):
assert utils.has_instance_rate_stat(stats) is expected


def test_get_latest_version():
    """2.16.0 is chosen over 2.9.1, which would win a plain string comparison."""
    candidates = ["2.9.1", "2.16.0", "1.0.0"]
    latest = utils.get_latest_version(candidates)
    assert latest == "2.16.0"


def test_get_latest_version_empty_list_is_none():
    """An empty list of versions yields None."""
    latest = utils.get_latest_version([])
    assert latest is None


def test_get_latest_version_none_is_none():
    """Passing None instead of a list yields None."""
    latest = utils.get_latest_version(None)
    assert latest is None


def test_get_latest_version_with_invalid_sem_ver():
    """With an unparseable entry present, the expected result is "abc"."""
    candidates = ["2.9.1", "2.16.0", "1.0.0", "abc"]
    latest = utils.get_latest_version(candidates)
    assert latest == "abc"


@pytest.mark.parametrize(
"data, expected",
[(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],
Expand Down