Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions src/sagemaker/jumpstart/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,6 +552,10 @@ def _select_version(
)
return version_str if version_str in available_versions else None

if version_str[-1] == "*":
# major or minor version is pinned, e.g 1.* or 1.0.*
return utils.get_latest_version([version for version in available_versions if version.startswith(version_str[:-1])])

try:
spec = SpecifierSet(f"=={version_str}")
except InvalidSpecifier:
Expand Down
12 changes: 12 additions & 0 deletions tests/unit/sagemaker/jumpstart/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15990,6 +15990,18 @@
"spec_key": "community_models_specs/tensorflow-ic-"
"imagenet-inception-v3-classification-4/specs_v3.0.0.json",
},
{
"model_id": "meta-textgeneration-llama-2-7b",
"version": "4.9.0",
"min_version": "2.49.0",
"spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.9.0.json",
},
{
"model_id": "meta-textgeneration-llama-2-7b",
"version": "4.13.0",
"min_version": "2.49.0",
"spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json",
},
]

BASE_PROPRIETARY_HEADER = {
Expand Down
24 changes: 24 additions & 0 deletions tests/unit/sagemaker/jumpstart/test_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,30 @@ def test_jumpstart_cache_get_header():
semantic_version_str="1.0.*",
)

assert JumpStartModelHeader(
{
"model_id": "meta-textgeneration-llama-2-7b",
"version": "4.13.0",
"min_version": "2.49.0",
"spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json",
}
) == cache.get_header(
model_id="meta-textgeneration-llama-2-7b",
semantic_version_str="*",
)

assert JumpStartModelHeader(
{
"model_id": "meta-textgeneration-llama-2-7b",
"version": "4.13.0",
"min_version": "2.49.0",
"spec_key": "community_models/meta-textgeneration-llama-2-7b/specs_v4.13.0.json",
}
) == cache.get_header(
model_id="meta-textgeneration-llama-2-7b",
semantic_version_str="4.*",
)

assert JumpStartModelHeader(
{
"model_id": "ai21-summarization",
Expand Down
Loading