diff --git a/src/sagemaker/jumpstart/cache.py b/src/sagemaker/jumpstart/cache.py
index f862d4702a..29a903e00b 100644
--- a/src/sagemaker/jumpstart/cache.py
+++ b/src/sagemaker/jumpstart/cache.py
@@ -552,6 +552,12 @@ def _select_version(
                 )
             return version_str if version_str in available_versions else None
 
+        if version_str[-1] == "*":
+            # major or minor version is pinned, e.g. 1.* or 1.0.*
+            return utils.get_latest_version(
+                [version for version in available_versions if version.startswith(version_str[:-1])]
+            )
+
         try:
             spec = SpecifierSet(f"=={version_str}")
         except InvalidSpecifier:
diff --git a/tests/unit/sagemaker/jumpstart/constants.py b/tests/unit/sagemaker/jumpstart/constants.py
index 83e8a44a32..2eb7469e21 100644
--- a/tests/unit/sagemaker/jumpstart/constants.py
+++ b/tests/unit/sagemaker/jumpstart/constants.py
@@ -15990,6 +15990,18 @@
         "spec_key": "community_models_specs/tensorflow-ic-"
         "imagenet-inception-v3-classification-4/specs_v3.0.0.json",
     },
+    {
+        "model_id": "meta-textgeneration-llama-2-7b",
+        "version": "4.9.0",
+        "min_version": "2.49.0",
+        "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.9.0.json",
+    },
+    {
+        "model_id": "meta-textgeneration-llama-2-7b",
+        "version": "4.13.0",
+        "min_version": "2.49.0",
+        "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json",
+    },
 ]
 
 BASE_PROPRIETARY_HEADER = {
diff --git a/tests/unit/sagemaker/jumpstart/test_cache.py b/tests/unit/sagemaker/jumpstart/test_cache.py
index b7edc124d3..17996f4f15 100644
--- a/tests/unit/sagemaker/jumpstart/test_cache.py
+++ b/tests/unit/sagemaker/jumpstart/test_cache.py
@@ -184,6 +184,30 @@ def test_jumpstart_cache_get_header():
         semantic_version_str="1.0.*",
     )
 
+    assert JumpStartModelHeader(
+        {
+            "model_id": "meta-textgeneration-llama-2-7b",
+            "version": "4.13.0",
+            "min_version": "2.49.0",
+            "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json",
+        }
+    ) == cache.get_header(
+        model_id="meta-textgeneration-llama-2-7b",
+        semantic_version_str="*",
+    )
+
+    assert JumpStartModelHeader(
+        {
+            "model_id": "meta-textgeneration-llama-2-7b",
+            "version": "4.13.0",
+            "min_version": "2.49.0",
+            "spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json",
+        }
+    ) == cache.get_header(
+        model_id="meta-textgeneration-llama-2-7b",
+        semantic_version_str="4.*",
+    )
+
     assert JumpStartModelHeader(
         {
             "model_id": "ai21-summarization",