2222from mock .mock import MagicMock
2323import pytest
2424from mock import patch
25+ from packaging .version import Version
2526
27+
28+ from sagemaker .jumpstart import utils
2629from sagemaker .jumpstart .cache import (
2730 JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
2831 JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY ,
3336 ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ,
3437)
3538from sagemaker .jumpstart .types import (
39+ JumpStartCachedContentValue ,
3640 JumpStartModelHeader ,
3741 JumpStartModelSpecs ,
3842 JumpStartVersionedModelId ,
5054 BASE_PROPRIETARY_SPEC ,
5155 BASE_PROPRIETARY_MANIFEST ,
5256)
53- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
57+ from sagemaker .jumpstart .utils import get_formatted_manifest , get_jumpstart_content_bucket
5458
5559
5660@patch .object (JumpStartModelsCache , "_retrieval_function" , patched_retrieval_function )
@@ -1119,3 +1123,130 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11191123 ),
11201124 ]
11211125 )
1126+
1127+ @patch .object (JumpStartModelsCache , "_retrieval_function" )
1128+ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights (
1129+ retrieval_function : Mock
1130+ ):
1131+ sm_version = Version (utils .get_sagemaker_version ())
1132+ new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1133+ print (str (new_sm_version ))
1134+ versions = ["1.0.0" , "2.9.1" , "2.16.0" ]
1135+ manifest = [
1136+ {
1137+ "model_id" : "test-model" ,
1138+ "version" : version ,
1139+ "min_version" : "2.49.0" ,
1140+ "spec_key" : "spec_key"
1141+ }
1142+ for version in versions
1143+ ]
1144+
1145+ manifest .append (
1146+ {
1147+ "model_id" : "test-model" ,
1148+ "version" : "3.0.0" ,
1149+ "min_version" : str (new_sm_version ),
1150+ "spec_key" : "spec_key"
1151+ }
1152+ )
1153+
1154+ manifest_dict = {}
1155+ for header in manifest :
1156+ header_obj = JumpStartModelHeader (header )
1157+ manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1158+ header_obj
1159+ )
1160+ retrieval_function .return_value = JumpStartCachedContentValue (
1161+ formatted_content = manifest_dict
1162+ )
1163+ key = JumpStartVersionedModelId ("test-model" , "*" )
1164+
1165+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1166+ result = cache ._get_open_weight_manifest_key_from_model_id ( key = key , value = None )
1167+
1168+ assert_key = JumpStartVersionedModelId ("test-model" , "2.16.0" )
1169+
1170+ assert result == assert_key
1171+
1172+ @patch .object (JumpStartModelsCache , "_retrieval_function" )
1173+ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights (
1174+ retrieval_function : Mock
1175+ ):
1176+ sm_version = Version (utils .get_sagemaker_version ())
1177+ new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1178+ print (str (new_sm_version ))
1179+ versions = ["1.0.0" , "2.9.1" , "2.16.0" ]
1180+ manifest = [
1181+ {
1182+ "model_id" : "test-model" ,
1183+ "version" : version ,
1184+ "min_version" : "2.49.0" ,
1185+ "spec_key" : "spec_key"
1186+ }
1187+ for version in versions
1188+ ]
1189+
1190+ manifest .append (
1191+ {
1192+ "model_id" : "test-model" ,
1193+ "version" : "3.0.0" ,
1194+ "min_version" : str (new_sm_version ),
1195+ "spec_key" : "spec_key"
1196+ }
1197+ )
1198+
1199+ manifest_dict = {}
1200+ for header in manifest :
1201+ header_obj = JumpStartModelHeader (header )
1202+ manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1203+ header_obj
1204+ )
1205+ retrieval_function .return_value = JumpStartCachedContentValue (
1206+ formatted_content = manifest_dict
1207+ )
1208+ key = JumpStartVersionedModelId ("test-model" , "*" )
1209+
1210+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1211+ result = cache ._get_proprietary_manifest_key_from_model_id ( key = key , value = None )
1212+
1213+ assert_key = JumpStartVersionedModelId ("test-model" , "2.16.0" )
1214+
1215+ assert result == assert_key
1216+
1217+
1218+ @patch .object (JumpStartModelsCache , "_retrieval_function" )
1219+ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (
1220+ retrieval_function : Mock
1221+ ):
1222+ sm_version = Version (utils .get_sagemaker_version ())
1223+ new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1224+ print (str (new_sm_version ))
1225+ versions = ["abc" , "2.9.1" , "2.16.0" ]
1226+ manifest = [
1227+ {
1228+ "model_id" : "test-model" ,
1229+ "version" : version ,
1230+ "min_version" : "2.49.0" ,
1231+ "spec_key" : "spec_key"
1232+ }
1233+ for version in versions
1234+ ]
1235+
1236+ manifest_dict = {}
1237+ for header in manifest :
1238+ header_obj = JumpStartModelHeader (header )
1239+ manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1240+ header_obj
1241+ )
1242+ retrieval_function .return_value = JumpStartCachedContentValue (
1243+ formatted_content = manifest_dict
1244+ )
1245+ key = JumpStartVersionedModelId ("test-model" , "*" )
1246+
1247+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1248+ result = cache ._get_open_weight_manifest_key_from_model_id ( key = key , value = None )
1249+
1250+ assert_key = JumpStartVersionedModelId ("test-model" , "abc" )
1251+
1252+ assert result == assert_key
0 commit comments