4848 _custom_speculative_decoding ,
4949 SPECULATIVE_DRAFT_MODEL ,
5050 _is_inferentia_or_trainium ,
51- _validate_and_set_eula_for_draft_model_sources ,
5251 _jumpstart_speculative_decoding ,
5352)
5453from sagemaker .serve .utils .predictors import (
@@ -837,9 +836,7 @@ def _is_gated_model(self, model=None) -> bool:
837836 return "private" in s3_uri
838837
839838 def _set_additional_model_source (
840- self ,
841- speculative_decoding_config : Optional [Dict [str , Any ]] = None ,
842- accept_eula : Optional [bool ] = None ,
839+ self , speculative_decoding_config : Optional [Dict [str , Any ]] = None
843840 ) -> None :
844841 """Set Additional Model Source to ``this`` model.
845842
@@ -849,6 +846,7 @@ def _set_additional_model_source(
849846 """
850847 if speculative_decoding_config :
851848 model_provider = _extract_speculative_draft_model_provider (speculative_decoding_config )
849+ accept_draft_model_eula = speculative_decoding_config .get ("AcceptEula" , False )
852850
853851 channel_name = _generate_channel_name (self .pysdk_model .additional_model_data_sources )
854852
@@ -865,20 +863,22 @@ def _set_additional_model_source(
865863 speculative_decoding_config
866864 )
867865 if deployment_config :
868- self .pysdk_model .set_deployment_config (
869- config_name = deployment_config .get ("DeploymentConfigName" ),
870- instance_type = deployment_config .get ("InstanceType" ),
871- )
866+ try :
867+ self .pysdk_model .set_deployment_config (
868+ config_name = deployment_config .get ("DeploymentConfigName" ),
869+ instance_type = deployment_config .get ("InstanceType" ),
870+ accept_draft_model_eula = accept_draft_model_eula ,
871+ )
872+ except ValueError as e :
873+ raise ValueError (
874+ f"{ e } If using speculative_decoding_config, "
875+ "accept the EULA by setting `AcceptEula`=True."
876+ )
872877 else :
873878 raise ValueError (
874879 "Cannot find deployment config compatible for optimization job."
875880 )
876881
877- _validate_and_set_eula_for_draft_model_sources (
878- pysdk_model = self .pysdk_model ,
879- accept_eula = speculative_decoding_config .get ("AcceptEula" ),
880- )
881-
882882 self .pysdk_model .env .update (
883883 {"OPTION_SPECULATIVE_DRAFT_MODEL" : f"{ SPECULATIVE_DRAFT_MODEL } /{ channel_name } /" }
884884 )
@@ -893,7 +893,7 @@ def _set_additional_model_source(
893893 )
894894 else :
895895 self .pysdk_model = _custom_speculative_decoding (
896- self .pysdk_model , speculative_decoding_config , accept_eula
896+ self .pysdk_model , speculative_decoding_config , accept_draft_model_eula
897897 )
898898
899899 def _find_compatible_deployment_config (
0 commit comments