Skip to content

Commit 3b147cd

Browse files
committed
add UTs for JumpStart deployment
1 parent 09a54dc commit 3b147cd

File tree

5 files changed

+345
-8
lines changed

5 files changed

+345
-8
lines changed

src/sagemaker/jumpstart/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,12 +817,14 @@ def deploy(
817817
f"{EndpointType.INFERENCE_COMPONENT_BASED} is not supported for Proprietary models."
818818
)
819819

820+
print(self.additional_model_data_sources)
820821
self.additional_model_data_sources = _add_model_access_configs_to_model_data_sources(
821822
self.additional_model_data_sources,
822823
deploy_kwargs.model_access_configs,
823824
deploy_kwargs.model_id,
824825
deploy_kwargs.region,
825826
)
827+
print(self.additional_model_data_sources)
826828

827829
try:
828830
predictor = super(JumpStartModel, self).deploy(**deploy_kwargs.to_kwargs_dict())

src/sagemaker/jumpstart/utils.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1544,15 +1544,21 @@ def _add_model_access_configs_to_model_data_sources(
15441544
region: str,
15451545
):
15461546
"""Sets AcceptEula to True for gated speculative decoding models"""
1547+
print(model_data_sources)
15471548

15481549
if not model_data_sources:
15491550
return model_data_sources
15501551

15511552
acked_model_data_sources = []
15521553
for model_data_source in model_data_sources:
15531554
hosting_eula_key = model_data_source.get("HostingEulaKey")
1555+
mutable_model_data_source = model_data_source.copy()
15541556
if hosting_eula_key:
1555-
if not model_access_configs or not model_access_configs.get(model_id):
1557+
if (
1558+
not model_access_configs
1559+
or not model_access_configs.get(model_id)
1560+
or not model_access_configs.get(model_id).accept_eula
1561+
):
15561562
eula_message_template = (
15571563
"{model_source}{base_eula_message}{model_access_configs_message}"
15581564
)
@@ -1572,14 +1578,14 @@ def _add_model_access_configs_to_model_data_sources(
15721578
),
15731579
)
15741580
)
1575-
acked_model_data_source = model_data_source.copy()
1576-
acked_model_data_source.pop("HostingEulaKey")
1577-
acked_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
1581+
mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is applied
1582+
mutable_model_data_source["S3DataSource"]["ModelAccessConfig"] = (
15781583
camel_case_to_pascal_case(model_access_configs.get(model_id).model_dump())
15791584
)
1580-
acked_model_data_sources.append(acked_model_data_source)
1585+
acked_model_data_sources.append(mutable_model_data_source)
15811586
else:
1582-
acked_model_data_sources.append(model_data_source)
1587+
mutable_model_data_source.pop("HostingEulaKey") # pop when model access config is not applicable
1588+
acked_model_data_sources.append(mutable_model_data_source)
15831589
return acked_model_data_sources
15841590

15851591

tests/unit/sagemaker/jumpstart/model/test_model.py

Lines changed: 90 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import pandas as pd
2020
from mock import MagicMock, Mock
2121
import pytest
22+
from sagemaker_core.shapes import ModelAccessConfig
2223
from sagemaker.async_inference.async_inference_config import AsyncInferenceConfig
2324
from sagemaker.jumpstart.artifacts.environment_variables import (
2425
_retrieve_default_environment_variables,
@@ -54,6 +55,7 @@
5455
get_base_deployment_configs,
5556
get_base_spec_with_prototype_configs_with_missing_benchmarks,
5657
append_instance_stat_metrics,
58+
append_gated_draft_model_specs_to_jumpstart_model_spec,
5759
)
5860
import boto3
5961

@@ -772,6 +774,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
772774

773775
init_args_to_skip: Set[str] = set(["model_reference_arn"])
774776
deploy_args_to_skip: Set[str] = set(["kwargs", "model_reference_arn"])
777+
deploy_args_removed_at_deploy_time: Set[str] = set(["model_access_configs"])
775778

776779
parent_class_init = Model.__init__
777780
parent_class_init_args = set(signature(parent_class_init).parameters.keys())
@@ -798,8 +801,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
798801
js_class_deploy = JumpStartModel.deploy
799802
js_class_deploy_args = set(signature(js_class_deploy).parameters.keys())
800803

801-
assert js_class_deploy_args - parent_class_deploy_args == set()
802-
assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
804+
assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set()
805+
assert (parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time ==
806+
deploy_args_to_skip)
803807

804808
@mock.patch(
805809
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
@@ -1762,6 +1766,89 @@ def test_model_set_deployment_config(
17621766
endpoint_logging=False,
17631767
)
17641768

1769+
@mock.patch(
1770+
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
1771+
)
1772+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
1773+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1774+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1775+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
1776+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1777+
def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
1778+
self,
1779+
mock_model_deploy: mock.Mock,
1780+
mock_get_model_specs: mock.Mock,
1781+
mock_session: mock.Mock,
1782+
mock_get_manifest: mock.Mock,
1783+
mock_get_jumpstart_configs: mock.Mock,
1784+
):
1785+
# WHERE
1786+
mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
1787+
mock_get_manifest.side_effect = (
1788+
lambda region, model_type, *args, **kwargs:
1789+
get_prototype_manifest(region, model_type)
1790+
)
1791+
mock_model_deploy.return_value = default_predictor
1792+
1793+
model_id = "pytorch-eqa-bert-base-cased"
1794+
1795+
mock_session.return_value = sagemaker_session
1796+
1797+
model = JumpStartModel(model_id=model_id)
1798+
1799+
assert model.config_name is None
1800+
1801+
# WHEN
1802+
model.deploy(model_access_configs={"pytorch-eqa-bert-base-cased":ModelAccessConfig(accept_eula=True)})
1803+
1804+
# THEN
1805+
mock_model_deploy.assert_called_once_with(
1806+
initial_instance_count=1,
1807+
instance_type="ml.p2.xlarge",
1808+
tags=[
1809+
{"Key": JumpStartTag.MODEL_ID, "Value": "pytorch-eqa-bert-base-cased"},
1810+
{"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
1811+
],
1812+
wait=True,
1813+
endpoint_logging=False,
1814+
)
1815+
1816+
@mock.patch(
1817+
"sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
1818+
)
1819+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
1820+
@mock.patch("sagemaker.jumpstart.factory.model.Session")
1821+
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
1822+
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
1823+
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
1824+
def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs(
1825+
self,
1826+
mock_model_deploy: mock.Mock,
1827+
mock_get_model_specs: mock.Mock,
1828+
mock_session: mock.Mock,
1829+
mock_get_manifest: mock.Mock,
1830+
mock_get_jumpstart_configs: mock.Mock,
1831+
):
1832+
# WHERE
1833+
mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
1834+
mock_get_manifest.side_effect = (
1835+
lambda region, model_type, *args, **kwargs:
1836+
get_prototype_manifest(region, model_type)
1837+
)
1838+
mock_model_deploy.return_value = default_predictor
1839+
1840+
model_id = "pytorch-eqa-bert-base-cased"
1841+
1842+
mock_session.return_value = sagemaker_session
1843+
1844+
model = JumpStartModel(model_id=model_id)
1845+
1846+
assert model.config_name is None
1847+
1848+
# WHEN / THEN
1849+
with self.assertRaises(ValueError):
1850+
model.deploy()
1851+
17651852
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
17661853
@mock.patch(
17671854
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
@@ -1810,6 +1897,7 @@ def test_model_deployment_config_additional_model_data_source(
18101897
"S3Uri": "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/",
18111898
"ModelAccessConfig": {"AcceptEula": False},
18121899
},
1900+
"HostingEulaKey": None,
18131901
}
18141902
],
18151903
)

0 commit comments

Comments
 (0)