22
22
from mock .mock import MagicMock
23
23
import pytest
24
24
from mock import patch
25
+ from packaging .version import Version
25
26
27
+
28
+ from sagemaker .jumpstart import utils
26
29
from sagemaker .jumpstart .cache import (
27
30
JUMPSTART_DEFAULT_MANIFEST_FILE_S3_KEY ,
28
31
JUMPSTART_DEFAULT_PROPRIETARY_MANIFEST_KEY ,
33
36
ENV_VARIABLE_JUMPSTART_SPECS_LOCAL_ROOT_DIR_OVERRIDE ,
34
37
)
35
38
from sagemaker .jumpstart .types import (
39
+ JumpStartCachedContentValue ,
36
40
JumpStartModelHeader ,
37
41
JumpStartModelSpecs ,
38
42
JumpStartVersionedModelId ,
50
54
BASE_PROPRIETARY_SPEC ,
51
55
BASE_PROPRIETARY_MANIFEST ,
52
56
)
53
- from sagemaker .jumpstart .utils import get_jumpstart_content_bucket
57
+ from sagemaker .jumpstart .utils import get_formatted_manifest , get_jumpstart_content_bucket
54
58
55
59
56
60
@patch .object (JumpStartModelsCache , "_retrieval_function" , patched_retrieval_function )
@@ -1119,3 +1123,130 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
1119
1123
),
1120
1124
]
1121
1125
)
1126
+
1127
+ @patch .object (JumpStartModelsCache , "_retrieval_function" )
1128
+ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights (
1129
+ retrieval_function : Mock
1130
+ ):
1131
+ sm_version = Version (utils .get_sagemaker_version ())
1132
+ new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1133
+ print (str (new_sm_version ))
1134
+ versions = ["1.0.0" , "2.9.1" , "2.16.0" ]
1135
+ manifest = [
1136
+ {
1137
+ "model_id" : "test-model" ,
1138
+ "version" : version ,
1139
+ "min_version" : "2.49.0" ,
1140
+ "spec_key" : "spec_key"
1141
+ }
1142
+ for version in versions
1143
+ ]
1144
+
1145
+ manifest .append (
1146
+ {
1147
+ "model_id" : "test-model" ,
1148
+ "version" : "3.0.0" ,
1149
+ "min_version" : str (new_sm_version ),
1150
+ "spec_key" : "spec_key"
1151
+ }
1152
+ )
1153
+
1154
+ manifest_dict = {}
1155
+ for header in manifest :
1156
+ header_obj = JumpStartModelHeader (header )
1157
+ manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1158
+ header_obj
1159
+ )
1160
+ retrieval_function .return_value = JumpStartCachedContentValue (
1161
+ formatted_content = manifest_dict
1162
+ )
1163
+ key = JumpStartVersionedModelId ("test-model" , "*" )
1164
+
1165
+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1166
+ result = cache ._get_open_weight_manifest_key_from_model_id ( key = key , value = None )
1167
+
1168
+ assert_key = JumpStartVersionedModelId ("test-model" , "2.16.0" )
1169
+
1170
+ assert result == assert_key
1171
+
1172
+ @patch .object (JumpStartModelsCache , "_retrieval_function" )
1173
+ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights (
1174
+ retrieval_function : Mock
1175
+ ):
1176
+ sm_version = Version (utils .get_sagemaker_version ())
1177
+ new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1178
+ print (str (new_sm_version ))
1179
+ versions = ["1.0.0" , "2.9.1" , "2.16.0" ]
1180
+ manifest = [
1181
+ {
1182
+ "model_id" : "test-model" ,
1183
+ "version" : version ,
1184
+ "min_version" : "2.49.0" ,
1185
+ "spec_key" : "spec_key"
1186
+ }
1187
+ for version in versions
1188
+ ]
1189
+
1190
+ manifest .append (
1191
+ {
1192
+ "model_id" : "test-model" ,
1193
+ "version" : "3.0.0" ,
1194
+ "min_version" : str (new_sm_version ),
1195
+ "spec_key" : "spec_key"
1196
+ }
1197
+ )
1198
+
1199
+ manifest_dict = {}
1200
+ for header in manifest :
1201
+ header_obj = JumpStartModelHeader (header )
1202
+ manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1203
+ header_obj
1204
+ )
1205
+ retrieval_function .return_value = JumpStartCachedContentValue (
1206
+ formatted_content = manifest_dict
1207
+ )
1208
+ key = JumpStartVersionedModelId ("test-model" , "*" )
1209
+
1210
+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1211
+ result = cache ._get_proprietary_manifest_key_from_model_id ( key = key , value = None )
1212
+
1213
+ assert_key = JumpStartVersionedModelId ("test-model" , "2.16.0" )
1214
+
1215
+ assert result == assert_key
1216
+
1217
+
1218
+ @patch .object (JumpStartModelsCache , "_retrieval_function" )
1219
+ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (
1220
+ retrieval_function : Mock
1221
+ ):
1222
+ sm_version = Version (utils .get_sagemaker_version ())
1223
+ new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1224
+ print (str (new_sm_version ))
1225
+ versions = ["abc" , "2.9.1" , "2.16.0" ]
1226
+ manifest = [
1227
+ {
1228
+ "model_id" : "test-model" ,
1229
+ "version" : version ,
1230
+ "min_version" : "2.49.0" ,
1231
+ "spec_key" : "spec_key"
1232
+ }
1233
+ for version in versions
1234
+ ]
1235
+
1236
+ manifest_dict = {}
1237
+ for header in manifest :
1238
+ header_obj = JumpStartModelHeader (header )
1239
+ manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1240
+ header_obj
1241
+ )
1242
+ retrieval_function .return_value = JumpStartCachedContentValue (
1243
+ formatted_content = manifest_dict
1244
+ )
1245
+ key = JumpStartVersionedModelId ("test-model" , "*" )
1246
+
1247
+ cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1248
+ result = cache ._get_open_weight_manifest_key_from_model_id ( key = key , value = None )
1249
+
1250
+ assert_key = JumpStartVersionedModelId ("test-model" , "abc" )
1251
+
1252
+ assert result == assert_key
0 commit comments