@@ -172,6 +172,60 @@ def _extract_speculative_draft_model_provider(
172172 return "sagemaker"
173173
174174
175+ def _extract_additional_model_data_source_s3_uri (
176+ additional_model_data_source : Optional [Dict ] = None ,
177+ ) -> Optional [str ]:
178+ """Extracts model data source s3 uri from a model data source in Pascal case.
179+
180+ Args:
181+ additional_model_data_source (Optional[Dict]): A model data source.
182+
183+ Returns:
184+ str: S3 uri of the model resources.
185+ """
186+ if (
187+ additional_model_data_source is None
188+ or additional_model_data_source .get ("S3DataSource" , None ) is None
189+ ):
190+ return None
191+
192+ return additional_model_data_source .get ("S3DataSource" ).get ("S3Uri" , None )
193+
194+
195+ def _extract_deployment_config_additional_model_data_source_s3_uri (
196+ additional_model_data_source : Optional [Dict ] = None ,
197+ ) -> Optional [str ]:
198+ """Extracts model data source s3 uri from a model data source in snake case.
199+
200+ Args:
201+ additional_model_data_source (Optional[Dict]): A model data source.
202+
203+ Returns:
204+ str: S3 uri of the model resources.
205+ """
206+ if (
207+ additional_model_data_source is None
208+ or additional_model_data_source .get ("s3_data_source" , None ) is None
209+ ):
210+ return None
211+
212+ return additional_model_data_source .get ("s3_data_source" ).get ("s3_uri" , None )
213+
214+
215+ def _is_draft_model_gated (
216+ draft_model_config : Optional [Dict ] = None ,
217+ ) -> bool :
218+ """Extracts model gated-ness from draft model data source.
219+
220+ Args:
221+ draft_model_config (Optional[Dict]): A model data source.
222+
223+ Returns:
224+ bool: Whether the draft model is gated or not.
225+ """
226+ return draft_model_config .get ("hosting_eula_key" , None )
227+
228+
175229def _extracts_and_validates_speculative_model_source (
176230 speculative_decoding_config : Dict ,
177231) -> str :
@@ -289,7 +343,7 @@ def _extract_optimization_config_and_env(
289343 if optimization_config :
290344 return optimization_config , quantization_override_env , compilation_override_env
291345
292- return None , None , None
346+ return {} , None , None
293347
294348
295349def _custom_speculative_decoding (
@@ -310,6 +364,8 @@ def _custom_speculative_decoding(
310364 speculative_decoding_config
311365 )
312366
367+ accept_eula = speculative_decoding_config .get ("AcceptEula" , False )
368+
313369 if _is_s3_uri (additional_model_source ):
314370 channel_name = _generate_channel_name (model .additional_model_data_sources )
315371 speculative_draft_model = f"{ SPECULATIVE_DRAFT_MODEL } /{ channel_name } "
@@ -326,3 +382,78 @@ def _custom_speculative_decoding(
326382 )
327383
328384 return model
385+
386+
387+ def _validate_and_set_eula_for_draft_model_sources (
388+ pysdk_model : Model ,
389+ accept_eula : bool = False ,
390+ ):
391+ """Validates whether the EULA has been accepted for gated additional draft model sources.
392+
393+ If accepted, updates the model data source's model access config.
394+
395+ Args:
396+ pysdk_model (Model): The model whose additional model data sources to check.
397+ accept_eula (bool): EULA acceptance for the draft model.
398+ """
399+ if not pysdk_model :
400+ return
401+
402+ deployment_config_draft_model_sources = (
403+ pysdk_model .deployment_config .get ("DeploymentArgs" , {})
404+ .get ("AdditionalDataSources" , {})
405+ .get ("speculative_decoding" , [])
406+ if pysdk_model .deployment_config
407+ else None
408+ )
409+ pysdk_model_additional_model_sources = pysdk_model .additional_model_data_sources
410+
411+ if not deployment_config_draft_model_sources or not pysdk_model_additional_model_sources :
412+ return
413+
414+ # Gated/ungated classification is only available through deployment_config.
415+ # Thus we must check each draft model in the deployment_config and see if it is set
416+ # as an additional model data source on the PySDK model itself.
417+ model_access_config_updated = False
418+ for source in deployment_config_draft_model_sources :
419+ if source .get ("channel_name" ) != "draft_model" :
420+ continue
421+
422+ if not _is_draft_model_gated (source ):
423+ continue
424+
425+ deployment_config_draft_model_source_s3_uri = (
426+ _extract_deployment_config_additional_model_data_source_s3_uri (source )
427+ )
428+
429+ # If EULA is accepted, proceed with modifying the draft model data source
430+ for additional_source in pysdk_model_additional_model_sources :
431+ if additional_source .get ("ChannelName" ) != "draft_model" :
432+ continue
433+
434+ # Verify the pysdk model source and deployment config model source match
435+ pysdk_model_source_s3_uri = _extract_additional_model_data_source_s3_uri (
436+ additional_source
437+ )
438+ if deployment_config_draft_model_source_s3_uri not in pysdk_model_source_s3_uri :
439+ continue
440+
441+ if not accept_eula :
442+ raise ValueError (
443+ "Gated draft model requires accepting end-user license agreement (EULA)."
444+ )
445+
446+ # Set ModelAccessConfig.AcceptEula to True
447+ updated_source = additional_source .copy ()
448+ updated_source ["S3DataSource" ]["ModelAccessConfig" ] = {"AcceptEula" : True }
449+
450+ index = pysdk_model .additional_model_data_sources .index (additional_source )
451+ pysdk_model .additional_model_data_sources [index ] = updated_source
452+
453+ model_access_config_updated = True
454+ break
455+
456+ if model_access_config_updated :
457+ break
458+
459+ return
0 commit comments