19
19
import pandas as pd
20
20
from mock import MagicMock , Mock
21
21
import pytest
22
+ from sagemaker_core .shapes import ModelAccessConfig
22
23
from sagemaker .async_inference .async_inference_config import AsyncInferenceConfig
23
24
from sagemaker .jumpstart .artifacts .environment_variables import (
24
25
_retrieve_default_environment_variables ,
54
55
get_base_deployment_configs ,
55
56
get_base_spec_with_prototype_configs_with_missing_benchmarks ,
56
57
append_instance_stat_metrics ,
58
+ append_gated_draft_model_specs_to_jumpstart_model_spec ,
57
59
)
58
60
import boto3
59
61
@@ -772,6 +774,7 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
772
774
773
775
init_args_to_skip : Set [str ] = set (["model_reference_arn" ])
774
776
deploy_args_to_skip : Set [str ] = set (["kwargs" , "model_reference_arn" ])
777
+ deploy_args_removed_at_deploy_time : Set [str ] = set (["model_access_configs" ])
775
778
776
779
parent_class_init = Model .__init__
777
780
parent_class_init_args = set (signature (parent_class_init ).parameters .keys ())
@@ -798,8 +801,9 @@ def test_jumpstart_model_kwargs_match_parent_class(self):
798
801
js_class_deploy = JumpStartModel .deploy
799
802
js_class_deploy_args = set (signature (js_class_deploy ).parameters .keys ())
800
803
801
- assert js_class_deploy_args - parent_class_deploy_args == set ()
802
- assert parent_class_deploy_args - js_class_deploy_args == deploy_args_to_skip
804
+ assert js_class_deploy_args - parent_class_deploy_args - deploy_args_removed_at_deploy_time == set ()
805
+ assert (parent_class_deploy_args - js_class_deploy_args - deploy_args_removed_at_deploy_time ==
806
+ deploy_args_to_skip )
803
807
804
808
@mock .patch (
805
809
"sagemaker.jumpstart.model.get_jumpstart_configs" , side_effect = lambda * args , ** kwargs : {}
@@ -1762,6 +1766,89 @@ def test_model_set_deployment_config(
1762
1766
endpoint_logging = False ,
1763
1767
)
1764
1768
1769
@mock.patch(
    "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
)
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_set_deployment_config_and_deploy_for_gated_draft_model(
    self,
    mock_model_deploy: mock.Mock,
    mock_get_model_specs: mock.Mock,
    mock_session: mock.Mock,
    mock_get_manifest: mock.Mock,
    mock_get_jumpstart_configs: mock.Mock,
):
    """Deploy with ``model_access_configs`` and check the call made to ``Model.deploy``.

    Model specs come from ``append_gated_draft_model_specs_to_jumpstart_model_spec``.
    The test passes a ``ModelAccessConfig(accept_eula=True)`` entry keyed by the
    model id and asserts that the patched parent ``Model.deploy`` is invoked exactly
    once with the expected keyword arguments, none of which is
    ``model_access_configs``.
    """

    def _prototype_manifest(region, model_type, *args, **kwargs):
        # Only the region and model type are forwarded; extra arguments are ignored.
        return get_prototype_manifest(region, model_type)

    # WHERE
    gated_model_id = "pytorch-eqa-bert-base-cased"

    mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
    mock_get_manifest.side_effect = _prototype_manifest
    mock_model_deploy.return_value = default_predictor
    mock_session.return_value = sagemaker_session

    js_model = JumpStartModel(model_id=gated_model_id)

    assert js_model.config_name is None

    # WHEN
    access_configs = {gated_model_id: ModelAccessConfig(accept_eula=True)}
    js_model.deploy(model_access_configs=access_configs)

    # THEN
    expected_deploy_kwargs = {
        "initial_instance_count": 1,
        "instance_type": "ml.p2.xlarge",
        "tags": [
            {"Key": JumpStartTag.MODEL_ID, "Value": gated_model_id},
            {"Key": JumpStartTag.MODEL_VERSION, "Value": "1.0.0"},
        ],
        "wait": True,
        "endpoint_logging": False,
    }
    mock_model_deploy.assert_called_once_with(**expected_deploy_kwargs)
1815
+
1816
@mock.patch(
    "sagemaker.jumpstart.model.get_jumpstart_configs", side_effect=lambda *args, **kwargs: {}
)
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest")
@mock.patch("sagemaker.jumpstart.factory.model.Session")
@mock.patch("sagemaker.jumpstart.accessors.JumpStartModelsAccessor.get_model_specs")
@mock.patch("sagemaker.jumpstart.model.Model.deploy")
@mock.patch("sagemaker.jumpstart.factory.model.JUMPSTART_DEFAULT_REGION_NAME", region)
def test_model_set_deployment_config_and_deploy_for_gated_draft_model_no_model_access_configs(
    self,
    mock_model_deploy: mock.Mock,
    mock_get_model_specs: mock.Mock,
    mock_session: mock.Mock,
    mock_get_manifest: mock.Mock,
    mock_get_jumpstart_configs: mock.Mock,
):
    """Expect ``deploy()`` to raise ``ValueError`` when called with no arguments.

    Model specs come from ``append_gated_draft_model_specs_to_jumpstart_model_spec``;
    the test calls ``deploy`` without ``model_access_configs`` and requires the call
    to fail with ``ValueError``.
    """

    def _prototype_manifest(region, model_type, *args, **kwargs):
        # Only the region and model type are forwarded; extra arguments are ignored.
        return get_prototype_manifest(region, model_type)

    # WHERE
    mock_get_model_specs.side_effect = append_gated_draft_model_specs_to_jumpstart_model_spec
    mock_get_manifest.side_effect = _prototype_manifest
    mock_model_deploy.return_value = default_predictor
    mock_session.return_value = sagemaker_session

    js_model = JumpStartModel(model_id="pytorch-eqa-bert-base-cased")

    assert js_model.config_name is None

    # WHEN / THEN
    self.assertRaises(ValueError, js_model.deploy)
1851
+
1765
1852
@mock .patch ("sagemaker.jumpstart.accessors.JumpStartModelsAccessor._get_manifest" )
1766
1853
@mock .patch (
1767
1854
"sagemaker.jumpstart.factory.model.get_default_jumpstart_session_with_user_agent_suffix"
@@ -1810,6 +1897,7 @@ def test_model_deployment_config_additional_model_data_source(
1810
1897
"S3Uri" : "s3://sagemaker-sd-models-prod-us-west-2/key/to/draft/model/artifact/" ,
1811
1898
"ModelAccessConfig" : {"AcceptEula" : False },
1812
1899
},
1900
+ "HostingEulaKey" : None ,
1813
1901
}
1814
1902
],
1815
1903
)
0 commit comments