@@ -1127,7 +1127,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11271127
11281128@patch .object (JumpStartModelsCache , "_retrieval_function" )
11291129def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights (
1130- retrieval_function : Mock
1130+ retrieval_function : Mock ,
11311131):
11321132 sm_version = Version (utils .get_sagemaker_version ())
11331133 new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1138,7 +1138,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11381138 "model_id" : "test-model" ,
11391139 "version" : version ,
11401140 "min_version" : "2.49.0" ,
1141- "spec_key" : "spec_key"
1141+ "spec_key" : "spec_key" ,
11421142 }
11431143 for version in versions
11441144 ]
@@ -1148,7 +1148,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11481148 "model_id" : "test-model" ,
11491149 "version" : "3.0.0" ,
11501150 "min_version" : str (new_sm_version ),
1151- "spec_key" : "spec_key"
1151+ "spec_key" : "spec_key" ,
11521152 }
11531153 )
11541154
@@ -1158,9 +1158,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11581158 manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
11591159 header_obj
11601160 )
1161- retrieval_function .return_value = JumpStartCachedContentValue (
1162- formatted_content = manifest_dict
1163- )
1161+ retrieval_function .return_value = JumpStartCachedContentValue (formatted_content = manifest_dict )
11641162 key = JumpStartVersionedModelId ("test-model" , "*" )
11651163
11661164 cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
@@ -1173,7 +1171,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11731171
11741172@patch .object (JumpStartModelsCache , "_retrieval_function" )
11751173def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights (
1176- retrieval_function : Mock
1174+ retrieval_function : Mock ,
11771175):
11781176 sm_version = Version (utils .get_sagemaker_version ())
11791177 new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
@@ -1184,7 +1182,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
11841182 "model_id" : "test-model" ,
11851183 "version" : version ,
11861184 "min_version" : "2.49.0" ,
1187- "spec_key" : "spec_key"
1185+ "spec_key" : "spec_key" ,
11881186 }
11891187 for version in versions
11901188 ]
@@ -1194,7 +1192,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
11941192 "model_id" : "test-model" ,
11951193 "version" : "3.0.0" ,
11961194 "min_version" : str (new_sm_version ),
1197- "spec_key" : "spec_key"
1195+ "spec_key" : "spec_key" ,
11981196 }
11991197 )
12001198
@@ -1204,9 +1202,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
12041202 manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
12051203 header_obj
12061204 )
1207- retrieval_function .return_value = JumpStartCachedContentValue (
1208- formatted_content = manifest_dict
1209- )
1205+ retrieval_function .return_value = JumpStartCachedContentValue (formatted_content = manifest_dict )
12101206 key = JumpStartVersionedModelId ("test-model" , "*" )
12111207
12121208 cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
@@ -1218,9 +1214,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
12181214
12191215
12201216@patch .object (JumpStartModelsCache , "_retrieval_function" )
1221- def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (
1222- retrieval_function : Mock
1223- ):
1217+ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver (retrieval_function : Mock ):
12241218 sm_version = Version (utils .get_sagemaker_version ())
12251219 new_sm_version = Version (str (sm_version .major + 1 ) + ".0.0" )
12261220 print (str (new_sm_version ))
@@ -1230,7 +1224,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
12301224 "model_id" : "test-model" ,
12311225 "version" : version ,
12321226 "min_version" : "2.49.0" ,
1233- "spec_key" : "spec_key"
1227+ "spec_key" : "spec_key" ,
12341228 }
12351229 for version in versions
12361230 ]
@@ -1241,9 +1235,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
12411235 manifest_dict [JumpStartVersionedModelId (header_obj .model_id , header_obj .version )] = (
12421236 header_obj
12431237 )
1244- retrieval_function .return_value = JumpStartCachedContentValue (
1245- formatted_content = manifest_dict
1246- )
1238+ retrieval_function .return_value = JumpStartCachedContentValue (formatted_content = manifest_dict )
12471239 key = JumpStartVersionedModelId ("test-model" , "*" )
12481240
12491241 cache = JumpStartModelsCache (s3_bucket_name = "some_bucket" )
0 commit comments