Skip to content

Commit dbbc9ed

Browse files
author
Eli Davidson
committed
run linter
1 parent 8219160 commit dbbc9ed

File tree

3 files changed

+22
-13
lines changed

3 files changed

+22
-13
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1631,7 +1631,9 @@ def get_draft_model_content_bucket(provider: Dict, region: str) -> str:
16311631
return get_jumpstart_content_bucket(region=region)
16321632
return neo_bucket
16331633

1634+
16341635
def get_latest_version(versions: List[str]) -> Optional[str]:
1635-
try: return None if not versions else max(versions, key=Version)
1636+
try:
1637+
return None if not versions else max(versions, key=Version)
16361638
except InvalidVersion as e:
16371639
return max(versions)

tests/unit/sagemaker/jumpstart/test_cache.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1124,9 +1124,10 @@ def test_jumpstart_local_metadata_override_specs_not_exist_both_directories(
11241124
]
11251125
)
11261126

1127+
11271128
@patch.object(JumpStartModelsCache, "_retrieval_function")
11281129
def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
1129-
retrieval_function: Mock
1130+
retrieval_function: Mock
11301131
):
11311132
sm_version = Version(utils.get_sagemaker_version())
11321133
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
@@ -1150,7 +1151,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11501151
"spec_key": "spec_key"
11511152
}
11521153
)
1153-
1154+
11541155
manifest_dict = {}
11551156
for header in manifest:
11561157
header_obj = JumpStartModelHeader(header)
@@ -1163,15 +1164,16 @@ def test_jumpstart_cache_handles_versioning_correctly_for_open_source_weights(
11631164
key = JumpStartVersionedModelId("test-model", "*")
11641165

11651166
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1166-
result = cache._get_open_weight_manifest_key_from_model_id( key = key, value = None )
1167+
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)
11671168

11681169
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
11691170

11701171
assert result == assert_key
11711172

1173+
11721174
@patch.object(JumpStartModelsCache, "_retrieval_function")
11731175
def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
1174-
retrieval_function: Mock
1176+
retrieval_function: Mock
11751177
):
11761178
sm_version = Version(utils.get_sagemaker_version())
11771179
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
@@ -1195,7 +1197,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
11951197
"spec_key": "spec_key"
11961198
}
11971199
)
1198-
1200+
11991201
manifest_dict = {}
12001202
for header in manifest:
12011203
header_obj = JumpStartModelHeader(header)
@@ -1208,7 +1210,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
12081210
key = JumpStartVersionedModelId("test-model", "*")
12091211

12101212
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1211-
result = cache._get_proprietary_manifest_key_from_model_id( key = key, value = None )
1213+
result = cache._get_proprietary_manifest_key_from_model_id(key=key, value=None)
12121214

12131215
assert_key = JumpStartVersionedModelId("test-model", "2.16.0")
12141216

@@ -1217,7 +1219,7 @@ def test_jumpstart_cache_handles_versioning_correctly_for_proprietary_weights(
12171219

12181220
@patch.object(JumpStartModelsCache, "_retrieval_function")
12191221
def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
1220-
retrieval_function: Mock
1222+
retrieval_function: Mock
12211223
):
12221224
sm_version = Version(utils.get_sagemaker_version())
12231225
new_sm_version = Version(str(sm_version.major + 1) + ".0.0")
@@ -1232,7 +1234,7 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
12321234
}
12331235
for version in versions
12341236
]
1235-
1237+
12361238
manifest_dict = {}
12371239
for header in manifest:
12381240
header_obj = JumpStartModelHeader(header)
@@ -1245,8 +1247,8 @@ def test_jumpstart_cache_handles_versioning_correctly_non_sem_ver(
12451247
key = JumpStartVersionedModelId("test-model", "*")
12461248

12471249
cache = JumpStartModelsCache(s3_bucket_name="some_bucket")
1248-
result = cache._get_open_weight_manifest_key_from_model_id( key = key, value = None )
1250+
result = cache._get_open_weight_manifest_key_from_model_id(key=key, value=None)
12491251

12501252
assert_key = JumpStartVersionedModelId("test-model", "abc")
12511253

1252-
assert result == assert_key
1254+
assert result == assert_key

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2143,18 +2143,23 @@ def test_add_instance_rate_stats_to_benchmark_metrics_client_ex(
21432143
def test_has_instance_rate_stat(stats, expected):
21442144
assert utils.has_instance_rate_stat(stats) is expected
21452145

2146+
21462147
def test_get_latest_version():
21472148
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0"]) == "2.16.0"
21482149

2150+
21492151
def test_get_latest_version_empty_list_is_none():
2150-
assert utils.get_latest_version([]) == None
2152+
assert utils.get_latest_version([]) is None
2153+
21512154

21522155
def test_get_latest_version_none_is_none():
2153-
assert utils.get_latest_version(None) == None
2156+
assert utils.get_latest_version(None) is None
2157+
21542158

21552159
def test_get_latest_version_with_invalid_sem_ver():
21562160
assert utils.get_latest_version(["2.9.1", "2.16.0", "1.0.0", "abc"]) == "abc"
21572161

2162+
21582163
@pytest.mark.parametrize(
21592164
"data, expected",
21602165
[(None, []), ([], []), (get_base_deployment_configs_metadata(), get_base_deployment_configs())],

0 commit comments

Comments
 (0)