Skip to content

Commit 65cb5b3

Browse files
committed
fix formatting issues
1 parent 3b147cd commit 65cb5b3

File tree

5 files changed

+117
-134
lines changed

5 files changed

+117
-134
lines changed

src/sagemaker/jumpstart/utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1578,13 +1578,17 @@ def _add_model_access_configs_to_model_data_sources(
15781578
),
15791579
)
15801580
)
1581-
mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is applied
1581+
mutable_model_data_source.pop(
1582+
"HostingEulaKey"
1583+
) # pop when model access config is applied
15821584
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
15831585
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
15841586
)
15851587
acked_model_data_sources.append(mutable_model_data_source)
15861588
else:
1587-
mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is not applicable
1589+
mutable_model_data_source.pop(
1590+
"HostingEulaKey"
1591+
) # pop when model access config is not applicable
15881592
acked_model_data_sources.append(mutable_model_data_source)
15891593
return acked_model_data_sources
15901594

src/sagemaker/serve/validations/check_optimization_configurations.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -801,9 +801,14 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
801801
js_class_deploy = JumpStartModel.deploy
802802
js_class_deploy_args = set(signature(js_class_deploy).parameters.keys())
803803

804-
assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set()
805-
assert (parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time ==
806-
deploy_args_to_skip)
804+
assert (
805+
js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time
806+
== set()
807+
)
808+
assert (
809+
parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time
810+
== deploy_args_to_skip
811+
)
807812

808813
@mock.patch(
809814
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
@@ -1775,18 +1780,17 @@ def test_model_set_deployment_config(
17751780
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
17761781
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
17771782
def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
1778-
self,
1779-
mock_model_deploy: mock.Mock,
1780-
mock_get_model_specs: mock.Mock,
1781-
mock_session: mock.Mock,
1782-
mock_get_manifest: mock.Mock,
1783-
mock_get_jumpstart_configs: mock.Mock,
1783+
self,
1784+
mock_model_deploy: mock.Mock,
1785+
mock_get_model_specs: mock.Mock,
1786+
mock_session: mock.Mock,
1787+
mock_get_manifest: mock.Mock,
1788+
mock_get_jumpstart_configs: mock.Mock,
17841789
):
17851790
# WHERE
17861791
mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
17871792
mock_get_manifest.side_effect = (
1788-
lambda region, model_type, *args, **kwargs:
1789-
get_prototype_manifest(region, model_type)
1793+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
17901794
)
17911795
mock_model_deploy.return_value = default_predictor
17921796

@@ -1799,7 +1803,11 @@ def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
17991803
assert model.config_name is None
18001804

18011805
# WHEN
1802-
model.deploy(model_access_configs={"pytorch-eqa-bert-base-cased":ModelAccessConfig(accept_eula=True)})
1806+
model.deploy(
1807+
model_access_configs={
1808+
"pytorch-eqa-bert-base-cased": ModelAccessConfig(accept_eula=True)
1809+
}
1810+
)
18031811

18041812
# THEN
18051813
mock_model_deploy.assert_called_once_with(
@@ -1822,18 +1830,17 @@ def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
18221830
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
18231831
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
18241832
def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs(
1825-
self,
1826-
mock_model_deploy: mock.Mock,
1827-
mock_get_model_specs: mock.Mock,
1828-
mock_session: mock.Mock,
1829-
mock_get_manifest: mock.Mock,
1830-
mock_get_jumpstart_configs: mock.Mock,
1833+
self,
1834+
mock_model_deploy: mock.Mock,
1835+
mock_get_model_specs: mock.Mock,
1836+
mock_session: mock.Mock,
1837+
mock_get_manifest: mock.Mock,
1838+
mock_get_jumpstart_configs: mock.Mock,
18311839
):
18321840
# WHERE
18331841
mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
18341842
mock_get_manifest.side_effect = (
1835-
lambda region, model_type, *args, **kwargs:
1836-
get_prototype_manifest(region, model_type)
1843+
lambda region, model_type, *args, **kwargs: get_prototype_manifest(region, model_type)
18371844
)
18381845
mock_model_deploy.return_value = default_predictor
18391846

tests/unit/sagemaker/jumpstart/test_utils.py

Lines changed: 66 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -2175,48 +2175,46 @@ class TestAcceptEulaModelAccessConfig(TestCase):
21752175
MOCK_PUBLIC_MODEL_ID = "mock_public_model_id"
21762176
MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
21772177
{
2178-
'ChannelName': 'draft_model',
2179-
'S3DataSource': {
2180-
'CompressionType': 'None',
2181-
'S3DataType': 'S3Prefix',
2182-
'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/'
2178+
"ChannelName": "draft_model",
2179+
"S3DataSource": {
2180+
"CompressionType": "None",
2181+
"S3DataType": "S3Prefix",
2182+
"S3Uri": "s3://jumpstart_bucket/path/to/public/resources/",
21832183
},
2184-
'HostingEulaKey': None
2184+
"HostingEulaKey": None,
21852185
}
21862186
]
21872187
MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
21882188
{
2189-
'ChannelName': 'draft_model',
2190-
'S3DataSource': {
2191-
'CompressionType': 'None',
2192-
'S3DataType': 'S3Prefix',
2193-
'S3Uri': 's3://jumpstart_bucket/path/to/public/resources/'
2194-
}
2189+
"ChannelName": "draft_model",
2190+
"S3DataSource": {
2191+
"CompressionType": "None",
2192+
"S3DataType": "S3Prefix",
2193+
"S3Uri": "s3://jumpstart_bucket/path/to/public/resources/",
2194+
},
21952195
}
21962196
]
21972197
MOCK_GATED_MODEL_ID = "mock_gated_model_id"
21982198
MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL = [
21992199
{
2200-
'ChannelName': 'draft_model',
2201-
'S3DataSource': {
2202-
'CompressionType': 'None',
2203-
'S3DataType': 'S3Prefix',
2204-
'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/'
2200+
"ChannelName": "draft_model",
2201+
"S3DataSource": {
2202+
"CompressionType": "None",
2203+
"S3DataType": "S3Prefix",
2204+
"S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/",
22052205
},
2206-
'HostingEulaKey': "fmhMetadata/eula/llama3_2Eula.txt"
2206+
"HostingEulaKey": "fmhMetadata/eula/llama3_2Eula.txt",
22072207
}
22082208
]
22092209
MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL = [
22102210
{
2211-
'ChannelName': 'draft_model',
2212-
'S3DataSource': {
2213-
'CompressionType': 'None',
2214-
'S3DataType': 'S3Prefix',
2215-
'S3Uri': 's3://jumpstart_bucket/path/to/gated/resources/',
2216-
'ModelAccessConfig': {
2217-
"AcceptEula": True
2218-
}
2219-
}
2211+
"ChannelName": "draft_model",
2212+
"S3DataSource": {
2213+
"CompressionType": "None",
2214+
"S3DataType": "S3Prefix",
2215+
"S3Uri": "s3://jumpstart_bucket/path/to/gated/resources/",
2216+
"ModelAccessConfig": {"AcceptEula": True},
2217+
},
22202218
}
22212219
]
22222220

@@ -2232,14 +2230,17 @@ def test_public_additional_model_data_source_should_pass_through(self):
22322230
)
22332231

22342232
# THEN
2235-
assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2233+
assert (
2234+
additional_model_data_sources
2235+
== self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2236+
)
22362237

22372238
def test_multiple_public_additional_model_data_source_should_pass_through_both(self):
22382239
# WHERE / WHEN
22392240
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
22402241
model_data_sources=(
2241-
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
22422242
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2243+
+ self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
22432244
),
22442245
model_access_configs=None,
22452246
model_id=self.MOCK_PUBLIC_MODEL_ID,
@@ -2248,23 +2249,24 @@ def test_multiple_public_additional_model_data_source_should_pass_through_both(s
22482249

22492250
# THEN
22502251
assert additional_model_data_sources == (
2251-
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
22522252
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2253+
+ self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
22532254
)
22542255

22552256
def test_public_additional_model_data_source_with_model_access_config_should_ignored_it(self):
22562257
# WHERE / WHEN
22572258
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
22582259
model_data_sources=self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
2259-
model_access_configs={
2260-
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
2261-
},
2260+
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
22622261
model_id=self.MOCK_GATED_MODEL_ID,
22632262
region=JUMPSTART_DEFAULT_REGION_NAME,
22642263
)
22652264

22662265
# THEN
2267-
assert additional_model_data_sources == self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2266+
assert (
2267+
additional_model_data_sources
2268+
== self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2269+
)
22682270

22692271
def test_no_additional_model_data_source_should_pass_through(self):
22702272
# WHERE / WHEN
@@ -2284,62 +2286,65 @@ def test_gated_additional_model_data_source_should_accept_it(self):
22842286
# WHERE / WHEN
22852287
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
22862288
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
2287-
model_access_configs={
2288-
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
2289-
},
2289+
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
22902290
model_id=self.MOCK_GATED_MODEL_ID,
22912291
region=JUMPSTART_DEFAULT_REGION_NAME,
22922292
)
22932293

22942294
# THEN
2295-
assert additional_model_data_sources == self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2295+
assert (
2296+
additional_model_data_sources
2297+
== self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2298+
)
22962299

22972300
def test_multiple_gated_additional_model_data_source_should_accept_both(self):
22982301
# WHERE / WHEN
22992302
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
23002303
model_data_sources=(
2301-
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
23022304
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2305+
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
23032306
),
23042307
model_access_configs={
2305-
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True),
2306-
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
2308+
self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True),
2309+
self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True),
23072310
},
23082311
model_id=self.MOCK_GATED_MODEL_ID,
23092312
region=JUMPSTART_DEFAULT_REGION_NAME,
23102313
)
23112314

23122315
# THEN
23132316
assert additional_model_data_sources == (
2314-
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
23152317
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2318+
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
23162319
)
23172320

23182321
# Mixed Positive Cases
23192322

2320-
def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(self):
2323+
def test_multiple_mixed_additional_model_data_source_should_pass_through_one_accept_the_other(
2324+
self,
2325+
):
23212326
# WHERE / WHEN
23222327
additional_model_data_sources = utils._add_model_access_configs_to_model_data_sources(
23232328
model_data_sources=(
2324-
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2325-
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2329+
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2330+
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
23262331
),
2327-
model_access_configs={
2328-
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=True)
2329-
},
2332+
model_access_configs={self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=True)},
23302333
model_id=self.MOCK_GATED_MODEL_ID,
23312334
region=JUMPSTART_DEFAULT_REGION_NAME,
23322335
)
23332336

23342337
# THEN
23352338
assert additional_model_data_sources == (
2336-
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL +
2337-
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2339+
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
2340+
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_POST_CALL
23382341
)
23392342

23402343
# Test Gated Negative Tests
23412344

2342-
def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(self):
2345+
def test_gated_additional_model_data_source_no_model_access_config_should_raise_value_error(
2346+
self,
2347+
):
23432348
# WHERE / WHEN / THEN
23442349
with self.assertRaises(ValueError):
23452350
utils._add_model_access_configs_to_model_data_sources(
@@ -2354,33 +2359,37 @@ def test_multiple_mixed_additional_no_model_data_source_should_raise_value_error
23542359
with self.assertRaises(ValueError):
23552360
utils._add_model_access_configs_to_model_data_sources(
23562361
model_data_sources=(
2357-
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL +
2358-
self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2362+
self.MOCK_PUBLIC_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
2363+
+ self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL
23592364
),
23602365
model_access_configs=None,
23612366
model_id=self.MOCK_GATED_MODEL_ID,
23622367
region=JUMPSTART_DEFAULT_REGION_NAME,
23632368
)
23642369

2365-
def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(self):
2370+
def test_gated_additional_model_data_source_wrong_model_access_config_should_raise_value_error(
2371+
self,
2372+
):
23662373
# WHERE / WHEN / THEN
23672374
with self.assertRaises(ValueError):
23682375
utils._add_model_access_configs_to_model_data_sources(
23692376
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
23702377
model_access_configs={
2371-
self.MOCK_PUBLIC_MODEL_ID:ModelAccessConfig(accept_eula=True)
2378+
self.MOCK_PUBLIC_MODEL_ID: ModelAccessConfig(accept_eula=True)
23722379
},
23732380
model_id=self.MOCK_GATED_MODEL_ID,
23742381
region=JUMPSTART_DEFAULT_REGION_NAME,
23752382
)
23762383

2377-
def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error(self):
2384+
def test_gated_additional_model_data_source_false_model_access_config_should_raise_value_error(
2385+
self,
2386+
):
23782387
# WHERE / WHEN / THEN
23792388
with self.assertRaises(ValueError):
23802389
utils._add_model_access_configs_to_model_data_sources(
23812390
model_data_sources=self.MOCK_GATED_DEPLOY_CONFIG_ADDITIONAL_MODEL_DATA_SOURCE_PRE_CALL,
23822391
model_access_configs={
2383-
self.MOCK_GATED_MODEL_ID:ModelAccessConfig(accept_eula=False)
2392+
self.MOCK_GATED_MODEL_ID: ModelAccessConfig(accept_eula=False)
23842393
},
23852394
model_id=self.MOCK_GATED_MODEL_ID,
23862395
region=JUMPSTART_DEFAULT_REGION_NAME,

0 commit comments

Comments
 (0)