Skip to content

Commit f8fd094

Browse files
author
Eli Davidson
committed
lint
1 parent dbbc9ed commit f8fd094

File tree

1 file changed

+11
-19
lines changed

1 file changed

+11
-19
lines changed

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,7 +1127,7 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11271127

11281128
@patch.object(JumpStartModelsCache, "_retrieval_function")
11291129
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1130-
retrieval_function: Mock
1130+
retrieval_function: Mock,
11311131
):
11321132
sm_version = Version(utils.get_sagemaker_version())
11331133
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
@@ -1138,7 +1138,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11381138
"model_id": "test-model",
11391139
"version": version,
11401140
"min_version": "2.49.0",
1141-
"spec_key": "spec_key"
1141+
"spec_key": "spec_key",
11421142
}
11431143
for version in versions
11441144
]
@@ -1148,7 +1148,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11481148
"model_id": "test-model",
11491149
"version": "3.0.0",
11501150
"min_version": str(new_sm_version),
1151-
"spec_key": "spec_key"
1151+
"spec_key": "spec_key",
11521152
}
11531153
)
11541154

@@ -1158,9 +1158,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11581158
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
11591159
header_obj
11601160
)
1161-
retrieval_function.return_value = JumpStartCachedContentValue(
1162-
formatted_content=manifest_dict
1163-
)
1161+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
11641162
key = JumpStartVersionedModelId("test-model", "*")
11651163

11661164
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
@@ -1173,7 +1171,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11731171

11741172
@patch.object(JumpStartModelsCache, "_retrieval_function")
11751173
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1176-
retrieval_function: Mock
1174+
retrieval_function: Mock,
11771175
):
11781176
sm_version = Version(utils.get_sagemaker_version())
11791177
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
@@ -1184,7 +1182,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
11841182
"model_id": "test-model",
11851183
"version": version,
11861184
"min_version": "2.49.0",
1187-
"spec_key": "spec_key"
1185+
"spec_key": "spec_key",
11881186
}
11891187
for version in versions
11901188
]
@@ -1194,7 +1192,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
11941192
"model_id": "test-model",
11951193
"version": "3.0.0",
11961194
"min_version": str(new_sm_version),
1197-
"spec_key": "spec_key"
1195+
"spec_key": "spec_key",
11981196
}
11991197
)
12001198

@@ -1204,9 +1202,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
12041202
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
12051203
header_obj
12061204
)
1207-
retrieval_function.return_value = JumpStartCachedContentValue(
1208-
formatted_content=manifest_dict
1209-
)
1205+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
12101206
key = JumpStartVersionedModelId("test-model", "*")
12111207

12121208
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
@@ -1218,9 +1214,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
12181214

12191215

12201216
@patch.object(JumpStartModelsCache, "_retrieval_function")
1221-
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1222-
retrieval_function: Mock
1223-
):
1217+
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(retrieval_function: Mock):
12241218
sm_version = Version(utils.get_sagemaker_version())
12251219
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
12261220
print(str(new_sm_version))
@@ -1230,7 +1224,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
12301224
"model_id": "test-model",
12311225
"version": version,
12321226
"min_version": "2.49.0",
1233-
"spec_key": "spec_key"
1227+
"spec_key": "spec_key",
12341228
}
12351229
for version in versions
12361230
]
@@ -1241,9 +1235,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
12411235
manifest_dict[JumpStartVersionedModelId(header_obj.model_id, header_obj.version)] = (
12421236
header_obj
12431237
)
1244-
retrieval_function.return_value = JumpStartCachedContentValue(
1245-
formatted_content=manifest_dict
1246-
)
1238+
retrieval_function.return_value = JumpStartCachedContentValue(formatted_content=manifest_dict)
12471239
key = JumpStartVersionedModelId("test-model", "*")
12481240

12491241
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")

0 commit comments

Comments
 (0)