@@ -1124,9 +1124,10 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
1124
1124
]
1125
1125
)
1126
1126
1127
+
1127
1128
@patch .object (JumpStartModelsCache , "_retrieval_function" )
1128
1129
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights (
1129
- retrieval_function : Mock
1130
+ retrieval_function : Mock
1130
1131
):
1131
1132
sm_version = Version (utils .get_sagemaker_version ())
1132
1133
new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1150,7 +1151,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1150
1151
"spec_key" : "spec_key"
1151
1152
}
1152
1153
)
1153
-
1154
+
1154
1155
manifest_dict = {}
1155
1156
for header in manifest :
1156
1157
header_obj = JumpStartModelHeader (header )
@@ -1163,15 +1164,16 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1163
1164
key = JumpStartVersionedModelId ("test-model" , "*" )
1164
1165
1165
1166
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1166
- result = cache ._get_open_weight_manifest_key_from_model_id ( key = key , value = None )
1167
+ result = cache ._get_open_weight_manifest_key_from_model_id (key = key , value = None )
1167
1168
1168
1169
assert_key = JumpStartVersionedModelId ("test-model" , "2.16.0" )
1169
1170
1170
1171
assert result == assert_key
1171
1172
1173
+
1172
1174
@patch .object (JumpStartModelsCache , "_retrieval_function" )
1173
1175
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights (
1174
- retrieval_function : Mock
1176
+ retrieval_function : Mock
1175
1177
):
1176
1178
sm_version = Version (utils .get_sagemaker_version ())
1177
1179
new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1195,7 +1197,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1195
1197
"spec_key" : "spec_key"
1196
1198
}
1197
1199
)
1198
-
1200
+
1199
1201
manifest_dict = {}
1200
1202
for header in manifest :
1201
1203
header_obj = JumpStartModelHeader (header )
@@ -1208,7 +1210,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1208
1210
key = JumpStartVersionedModelId ("test-model" , "*" )
1209
1211
1210
1212
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1211
- result = cache ._get_proprietary_manifest_key_from_model_id ( key = key , value = None )
1213
+ result = cache ._get_proprietary_manifest_key_from_model_id (key = key , value = None )
1212
1214
1213
1215
assert_key = JumpStartVersionedModelId ("test-model" , "2.16.0" )
1214
1216
@@ -1217,7 +1219,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1217
1219
1218
1220
@patch .object (JumpStartModelsCache , "_retrieval_function" )
1219
1221
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (
1220
- retrieval_function : Mock
1222
+ retrieval_function : Mock
1221
1223
):
1222
1224
sm_version = Version (utils .get_sagemaker_version ())
1223
1225
new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1232,7 +1234,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1232
1234
}
1233
1235
for version in versions
1234
1236
]
1235
-
1237
+
1236
1238
manifest_dict = {}
1237
1239
for header in manifest :
1238
1240
header_obj = JumpStartModelHeader (header )
@@ -1245,8 +1247,8 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1245
1247
key = JumpStartVersionedModelId ("test-model" , "*" )
1246
1248
1247
1249
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
1248
- result = cache ._get_open_weight_manifest_key_from_model_id ( key = key , value = None )
1250
+ result = cache ._get_open_weight_manifest_key_from_model_id (key = key , value = None )
1249
1251
1250
1252
assert_key = JumpStartVersionedModelId ("test-model" , "abc" )
1251
1253
1252
- assert result == assert_key
1254
+ assert result == assert_key
0 commit comments