@@ -1538,7 +1538,7 @@ def wrapped_f(*args, **kwargs):
15381538
15391539def _add_model_access_configs_to_model_data_sources (
15401540 model_data_sources : List [Dict [str , any ]],
1541- model_access_configs : List [ ModelAccessConfig ],
1541+ model_access_configs : Dict [ str , ModelAccessConfig ],
15421542 model_id : str ,
15431543 region : str ,
15441544):
@@ -1548,25 +1548,29 @@ def _add_model_access_configs_to_model_data_sources(
15481548 return model_data_sources
15491549
15501550 acked_model_data_sources = []
1551- acked_model_access_configs = 0
15521551 for model_data_source in model_data_sources :
1553- hosting_eula_key = model_data_source .pop ("HostingEulaKey" , None )
1552+ hosting_eula_key = model_data_source .get ("HostingEulaKey" )
15541553 if hosting_eula_key :
1555- if not model_access_configs or acked_model_access_configs == len ( model_access_configs ):
1554+ if not model_access_configs or not model_access_configs . get ( model_id ):
15561555 eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}"
1556+ model_access_config_entry = (
1557+ "\" {model_id}\" :ModelAccessConfig(accept_eula=True)" .format (model_id = model_id )
1558+ )
15571559 raise ValueError (eula_message_template .format (
15581560 model_source = "Draft " if model_data_source .get ("ChannelName" ) else "" ,
15591561 base_eula_message = format_eula_message_from_specs (
15601562 model_id = model_id , region = region , hosting_eula_key = hosting_eula_key
15611563 ),
15621564 model_access_configs_message = (
1563- " Please add a ModelAccessConfig with AcceptEula=True"
1564- " to model_access_configs to acknowledge the EULA."
1565+ " Please add a ModelAccessConfig entry:"
1566+ f" { model_access_config_entry } "
1567+ "to model_access_configs to acknowledge the EULA."
15651568 )
15661569 ))
15671570 acked_model_data_source = model_data_source .copy ()
1571+ acked_model_data_source .pop ("HostingEulaKey" )
15681572 acked_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = (
1569- camel_case_to_pascal_case (model_access_configs [ acked_model_access_configs ] .model_dump ())
1573+ camel_case_to_pascal_case (model_access_configs . get ( model_id ) .model_dump ())
15701574 )
15711575 acked_model_data_sources .append (acked_model_data_source )
15721576 else :
0 commit comments