@@ -1538,7 +1538,7 @@ def wrapped_f(*args, **kwargs):
1538
1538
1539
1539
def _add_model_access_configs_to_model_data_sources (
1540
1540
model_data_sources : List [Dict [str , any ]],
1541
- model_access_configs : List [ ModelAccessConfig ],
1541
+ model_access_configs : Dict [ str , ModelAccessConfig ],
1542
1542
model_id : str ,
1543
1543
region : str ,
1544
1544
):
@@ -1548,25 +1548,29 @@ def _add_model_access_configs_to_model_data_sources(
1548
1548
return model_data_sources
1549
1549
1550
1550
acked_model_data_sources = []
1551
- acked_model_access_configs = 0
1552
1551
for model_data_source in model_data_sources :
1553
- hosting_eula_key = model_data_source .pop ("HostingEulaKey" , None )
1552
+ hosting_eula_key = model_data_source .get ("HostingEulaKey" )
1554
1553
if hosting_eula_key :
1555
- if not model_access_configs or acked_model_access_configs == len ( model_access_configs ):
1554
+ if not model_access_configs or not model_access_configs . get ( model_id ):
1556
1555
eula_message_template = "{model_source}{base_eula_message}{model_access_configs_message}"
1556
+ model_access_config_entry = (
1557
+ "\" {model_id}\" :ModelAccessConfig(accept_eula=True)" .format (model_id = model_id )
1558
+ )
1557
1559
raise ValueError (eula_message_template .format (
1558
1560
model_source = "Draft " if model_data_source .get ("ChannelName" ) else "" ,
1559
1561
base_eula_message = format_eula_message_from_specs (
1560
1562
model_id = model_id , region = region , hosting_eula_key = hosting_eula_key
1561
1563
),
1562
1564
model_access_configs_message = (
1563
- " Please add a ModelAccessConfig with AcceptEula=True"
1564
- " to model_access_configs to acknowledge the EULA."
1565
+ " Please add a ModelAccessConfig entry:"
1566
+ f" { model_access_config_entry } "
1567
+ "to model_access_configs to acknowledge the EULA."
1565
1568
)
1566
1569
))
1567
1570
acked_model_data_source = model_data_source .copy ()
1571
+ acked_model_data_source .pop ("HostingEulaKey" )
1568
1572
acked_model_data_source ["S3DataSource" ]["ModelAccessConfig" ] = (
1569
- camel_case_to_pascal_case (model_access_configs [ acked_model_access_configs ] .model_dump ())
1573
+ camel_case_to_pascal_case (model_access_configs . get ( model_id ) .model_dump ())
1570
1574
)
1571
1575
acked_model_data_sources .append (acked_model_data_source )
1572
1576
else :
0 commit comments