1919import pandas as pd
2020from mock import MagicMock , Mock
2121import pytest
22+ from sagemaker_core .shapes import ModelAccessConfig
2223from sagemaker .async_inference .async_inference_config import AsyncInferenceConfig
2324from sagemaker .jumpstart .artifacts .environment_variables import (
2425 _retrieve_default_environment_variables ,
5455 get_base_deployment_configs ,
5556 get_base_spec_with_prototype_configs_with_missing_benchmarks ,
5657 append_instance_stat_metrics ,
58+ append_gated_draft_model_specs_to_jumpstart_model_spec ,
5759)
5860import boto3
5961
@@ -772,6 +774,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
772774
773775 init_args_to_skip : Set [str ] = set (["model_reference_arn" ])
774776 deploy_args_to_skip : Set [str ] = set (["kwargs" , "model_reference_arn" ])
777+ deploy_args_removed_at_deploy_time : Set [str ] = set (["model_access_configs" ])
775778
776779 parent_class_init = Model .__init__
777780 parent_class_init_args = set (signature (parent_class_init ).parameters .keys ())
@@ -798,8 +801,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
798801 js_class_deploy = JumpStartModel .deploy
799802 js_class_deploy_args = set (signature (js_class_deploy ).parameters .keys ())
800803
801- assert js_class_deploy_args - parent_class_deploy_args == set ()
802- assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
804+ assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set ()
805+ assert (parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time ==
806+ deploy_args_to_skip )
803807
804808 @mock .patch (
805809 "sagemaker.jumpstart.model.get_jumpstart_configs" , side_effect = lambda * args , ** kwargs : {}
@@ -1762,6 +1766,89 @@ def test_model_set_deployment_config(
17621766 endpoint_logging = False ,
17631767 )
17641768
1769+ @mock .patch (
1770+ "sagemaker.jumpstart.model.get_jumpstart_configs" , side_effect = lambda * args , ** kwargs : {}
1771+ )
1772+ @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
1773+ @mock .patch ("sagemaker.jumpstart.factory.model.Session" )
1774+ @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
1775+ @mock .patch ("sagemaker.jumpstart.model.Model.deploy" )
1776+ @mock .patch ("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME" , region )
1777+ def test_model_set_deployment_config_and_deploy_for_gated_draft_model (
1778+ self ,
1779+ mock_model_deploy : mock .Mock ,
1780+ mock_get_model_specs : mock .Mock ,
1781+ mock_session : mock .Mock ,
1782+ mock_get_manifest : mock .Mock ,
1783+ mock_get_jumpstart_configs : mock .Mock ,
1784+ ):
1785+ # WHERE
1786+ mock_get_model_specs .side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
1787+ mock_get_manifest .side_effect = (
1788+ lambda region , model_type , * args , ** kwargs :
1789+ get_prototype_manifest (region , model_type )
1790+ )
1791+ mock_model_deploy .return_value = default_predictor
1792+
1793+ model_id = "pytorch-eqa-bert-base-cased"
1794+
1795+ mock_session .return_value = sagemaker_session
1796+
1797+ model = JumpStartModel (model_id = model_id )
1798+
1799+ assert model .config_name is None
1800+
1801+ # WHEN
1802+ model .deploy (model_access_configs = {"pytorch-eqa-bert-base-cased" :ModelAccessConfig (accept_eula = True )})
1803+
1804+ # THEN
1805+ mock_model_deploy .assert_called_once_with (
1806+ initial_instance_count = 1 ,
1807+ instance_type = "ml.p2.xlarge" ,
1808+ tags = [
1809+ {"Key" : JumpStartTag .MODEL_ID , "Value" : "pytorch-eqa-bert-base-cased" },
1810+ {"Key" : JumpStartTag .MODEL_VERSION , "Value" : "1.0.0" },
1811+ ],
1812+ wait = True ,
1813+ endpoint_logging = False ,
1814+ )
1815+
1816+ @mock .patch (
1817+ "sagemaker.jumpstart.model.get_jumpstart_configs" , side_effect = lambda * args , ** kwargs : {}
1818+ )
1819+ @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
1820+ @mock .patch ("sagemaker.jumpstart.factory.model.Session" )
1821+ @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs" )
1822+ @mock .patch ("sagemaker.jumpstart.model.Model.deploy" )
1823+ @mock .patch ("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME" , region )
1824+ def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs (
1825+ self ,
1826+ mock_model_deploy : mock .Mock ,
1827+ mock_get_model_specs : mock .Mock ,
1828+ mock_session : mock .Mock ,
1829+ mock_get_manifest : mock .Mock ,
1830+ mock_get_jumpstart_configs : mock .Mock ,
1831+ ):
1832+ # WHERE
1833+ mock_get_model_specs .side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
1834+ mock_get_manifest .side_effect = (
1835+ lambda region , model_type , * args , ** kwargs :
1836+ get_prototype_manifest (region , model_type )
1837+ )
1838+ mock_model_deploy .return_value = default_predictor
1839+
1840+ model_id = "pytorch-eqa-bert-base-cased"
1841+
1842+ mock_session .return_value = sagemaker_session
1843+
1844+ model = JumpStartModel (model_id = model_id )
1845+
1846+ assert model .config_name is None
1847+
1848+ # WHEN / THEN
1849+ with self .assertRaises (ValueError ):
1850+ model .deploy ()
1851+
17651852 @mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
17661853 @mock .patch (
17671854 "sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
@@ -1810,6 +1897,7 @@ def test_model_deployment_config_additional_model_data_source(
18101897 "S3Uri" : "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/" ,
18111898 "ModelAccessConfig" : {"AcceptEula" : False },
18121899 },
1900+ "HostingEulaKey" : None ,
18131901 }
18141902 ],
18151903 )
0 commit comments