Skip to content

Commit d064b89

Browse files
author
Eli Davidson
committed
fix: bug in get latest version was getting the max sorted alphabetically
instead of sem-ver
1 parent 3d8ffb8 commit d064b89

File tree

3 files changed

+12
-3
lines changed

3 files changed

+12
-3
lines changed

src/sagemaker/jumpstart/cache.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -540,9 +540,7 @@ def _select_version(
540540
"""
541541

542542
if version_str == "*":
543-
if len(available_versions) == 0:
544-
return None
545-
return str(max(available_versions))
543+
return utils.get_latest_version(available_versions)
546544

547545
if model_type == JumpStartModelType.PROPRIETARY:
548546
if "*" in version_str:

src/sagemaker/jumpstart/utils.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,3 +1525,6 @@ def wrapped_f(*args, **kwargs):
15251525
if _func is None:
15261526
return wrapper_cache
15271527
return wrapper_cache(_func)
1528+
1529+
def get_latest_version(versions: List[str]) -> Optional[str]:
1530+
return None if not versions else max(versions, key=Version)

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2142,6 +2142,14 @@ def test_add_instance_rate_stats_to_benchmark_metrics_client_ex(
21422142
def test_has_instance_rate_stat(stats, expected):
21432143
assert utils.has_instance_rate_stat(stats) is expected
21442144

2145+
def test_get_latest_version():
2146+
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0"
2147+
2148+
def test_get_latest_version_empty_list_is_none():
2149+
assert utils.get_latest_version([]) == None
2150+
2151+
def test_get_latest_version_none_is_none():
2152+
assert utils.get_latest_version(None) == None
21452153

21462154
@pytest.mark.parametrize(
21472155
"data, expected",

0 commit comments

Comments
 (0)