@@ -1127,7 +1127,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
1127
1127
1128
1128
@patch .object (JumpStartModelsCache , "_retrieval_function" )
1129
1129
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights (
1130
- retrieval_function : Mock
1130
+ retrieval_function : Mock ,
1131
1131
):
1132
1132
sm_version = Version (utils .get_sagemaker_version ())
1133
1133
new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1138,7 +1138,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1138
1138
"model_id" : "test-model" ,
1139
1139
"version" : version ,
1140
1140
"min_version" : "2.49.0" ,
1141
- "spec_key" : "spec_key"
1141
+ "spec_key" : "spec_key" ,
1142
1142
}
1143
1143
for version in versions
1144
1144
]
@@ -1148,7 +1148,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1148
1148
"model_id" : "test-model" ,
1149
1149
"version" : "3.0.0" ,
1150
1150
"min_version" : str (new_sm_version ),
1151
- "spec_key" : "spec_key"
1151
+ "spec_key" : "spec_key" ,
1152
1152
}
1153
1153
)
1154
1154
@@ -1158,9 +1158,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1158
1158
manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1159
1159
header_obj
1160
1160
)
1161
- retrieval_function .return_value = JumpStartCachedContentValue (
1162
- formatted_content = manifest_dict
1163
- )
1161
+ retrieval_function .return_value = JumpStartCachedContentValue (formatted_content = manifest_dict )
1164
1162
key = JumpStartVersionedModelId ("test-model" , "*" )
1165
1163
1166
1164
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
@@ -1173,7 +1171,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1173
1171
1174
1172
@patch .object (JumpStartModelsCache , "_retrieval_function" )
1175
1173
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights (
1176
- retrieval_function : Mock
1174
+ retrieval_function : Mock ,
1177
1175
):
1178
1176
sm_version = Version (utils .get_sagemaker_version ())
1179
1177
new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1184,7 +1182,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1184
1182
"model_id" : "test-model" ,
1185
1183
"version" : version ,
1186
1184
"min_version" : "2.49.0" ,
1187
- "spec_key" : "spec_key"
1185
+ "spec_key" : "spec_key" ,
1188
1186
}
1189
1187
for version in versions
1190
1188
]
@@ -1194,7 +1192,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1194
1192
"model_id" : "test-model" ,
1195
1193
"version" : "3.0.0" ,
1196
1194
"min_version" : str (new_sm_version ),
1197
- "spec_key" : "spec_key"
1195
+ "spec_key" : "spec_key" ,
1198
1196
}
1199
1197
)
1200
1198
@@ -1204,9 +1202,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1204
1202
manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1205
1203
header_obj
1206
1204
)
1207
- retrieval_function .return_value = JumpStartCachedContentValue (
1208
- formatted_content = manifest_dict
1209
- )
1205
+ retrieval_function .return_value = JumpStartCachedContentValue (formatted_content = manifest_dict )
1210
1206
key = JumpStartVersionedModelId ("test-model" , "*" )
1211
1207
1212
1208
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
@@ -1218,9 +1214,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1218
1214
1219
1215
1220
1216
@patch .object (JumpStartModelsCache , "_retrieval_function" )
1221
- def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (
1222
- retrieval_function : Mock
1223
- ):
1217
+ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (retrieval_function : Mock ):
1224
1218
sm_version = Version (utils .get_sagemaker_version ())
1225
1219
new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
1226
1220
print (str (new_sm_version ))
@@ -1230,7 +1224,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1230
1224
"model_id" : "test-model" ,
1231
1225
"version" : version ,
1232
1226
"min_version" : "2.49.0" ,
1233
- "spec_key" : "spec_key"
1227
+ "spec_key" : "spec_key" ,
1234
1228
}
1235
1229
for version in versions
1236
1230
]
@@ -1241,9 +1235,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1241
1235
manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
1242
1236
header_obj
1243
1237
)
1244
- retrieval_function .return_value = JumpStartCachedContentValue (
1245
- formatted_content = manifest_dict
1246
- )
1238
+ retrieval_function .return_value = JumpStartCachedContentValue (formatted_content = manifest_dict )
1247
1239
key = JumpStartVersionedModelId ("test-model" , "*" )
1248
1240
1249
1241
cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
0 commit comments