@@ -1023,10 +1023,11 @@ def _handle_ai_index_op(
10231023 model_cmd .prepend (drop_old_model_cmd )
10241024
10251025 self .add_prerequisite (model_cmd )
1026- self .set_attribute_value (
1027- 'generated_ai_model_type' ,
1028- model_typeshell ,
1029- )
1026+ if model_typeshell :
1027+ self .set_attribute_value (
1028+ 'generated_ai_model_type' ,
1029+ model_typeshell ,
1030+ )
10301031
10311032 def _delete_ai_model_type (
10321033 self ,
@@ -1100,16 +1101,19 @@ class EmbeddingModelDesc:
11001101 supports_shortening : bool
11011102
11021103
1103- def _try_load_reference_descs () -> Optional [
1104+ def _try_load_reference_descs (
1105+ * ,
1106+ context : sd .CommandContext ,
1107+ ) -> Optional [
11041108 tuple [
1105- Mapping [str , ProviderDesc ],
1106- Mapping [str , EmbeddingModelDesc ],
1109+ dict [str , ProviderDesc ],
1110+ dict [str , EmbeddingModelDesc ],
11071111 ]
11081112]:
1109- try :
1110- provider_descs : dict [str , ProviderDesc ] = {}
1111- embedding_model_descs : dict [str , EmbeddingModelDesc ] = {}
1113+ provider_descs : dict [str , ProviderDesc ] = {}
1114+ embedding_model_descs : dict [str , EmbeddingModelDesc ] = {}
11121115
1116+ try :
11131117 local_reference_path = os .path .join (
11141118 pathlib .Path (__file__ ).parent .parent ,
11151119 'server' ,
@@ -1133,32 +1137,12 @@ def _try_load_reference_descs() -> Optional[
11331137 max_output_dimensions = ref ["max_output_dimensions" ],
11341138 supports_shortening = ref ["supports_shortening" ],
11351139 )
1136- return (provider_descs , embedding_model_descs )
11371140
11381141 except Exception :
11391142 # Ignore failures
1140- return None
1141-
1143+ pass
11421144
1143- def _lookup_embedding_model_description (
1144- provider_name : str ,
1145- model_name : str ,
1146- ) -> Optional [tuple [ProviderDesc , EmbeddingModelDesc ]]:
1147- reference_descs = _try_load_reference_descs ()
1148-
1149- if reference_descs is None :
1150- return None
1151-
1152- provider_descs , embedding_model_descs = reference_descs
1153-
1154- provider_desc = provider_descs .get (provider_name )
1155- model_desc = embedding_model_descs .get (model_name )
1156- if provider_desc is None or model_desc is None :
1157- return None
1158- if provider_name != model_desc .model_provider :
1159- return None
1160-
1161- return (provider_desc , model_desc )
1145+ return (provider_descs , embedding_model_descs )
11621146
11631147
11641148def _create_ai_model_type (
@@ -1169,35 +1153,78 @@ def _create_ai_model_type(
11691153 span : Optional [parsing .Span ],
11701154) -> tuple [
11711155 sd .Command ,
1172- so .ObjectShell [s_types .Type ],
1156+ Optional [ so .ObjectShell [s_types .Type ] ],
11731157]:
11741158 model_name_parts = model_name .split (':' )
11751159 if len (model_name_parts ) > 2 :
11761160 raise errors .SchemaDefinitionError (
1177- f"Invalid model uri, ':' used more than once: { model_name } " ,
1161+ f"Invalid model uri, ':' used more than once: { uri_model_name } " ,
11781162 span = span
11791163 )
11801164
1181- provider_name = model_name_parts [0 ]
1182- model_name = model_name_parts [1 ]
1165+ uri_provider_name = model_name_parts [0 ]
1166+ uri_model_name = model_name_parts [1 ]
11831167
11841168 # Lookup the ai model description
1185- provider_model_desc = _lookup_embedding_model_description (
1186- provider_name , model_name
1169+ provider_descs , model_descs = _try_load_reference_descs (
1170+ context = context
11871171 )
1188- if not provider_model_desc :
1189- raise errors .SchemaDefinitionError (
1190- f"Invalid model uri, unknown provider and model: { model_name } " ,
1191- span = span
1192- )
11931172
1194- provider_desc , model_desc = provider_model_desc
1173+ if uri_model_name not in model_descs :
1174+ # No match in reference
1175+ return sd .DeltaRoot (), None
11951176
1196- # Ensure that a suitable model type exists for the ai index
1177+ model_desc = model_descs [uri_model_name ]
1178+
1179+ resolved_provider_name = (
1180+ provider_descs [uri_provider_name ].name
1181+ if uri_provider_name in provider_descs else
1182+ uri_provider_name
1183+ )
1184+
1185+ # First check if a model with a matching name is already in the schema
1186+ if models := get_defined_ext_ai_embedding_models (schema , uri_model_name ):
1187+ # If there is only 1 model, ensure that the provider name matches
1188+ if len (models ) == 1 :
1189+ model = next (iter (models .values ()))
1190+
1191+ model_provider_name = model .get_annotation (
1192+ schema , sn .QualName ("ext::ai" , "model_provider" )
1193+ )
1194+ # The URI will specify "openai" instead of "builtin::openai"
1195+ # We want to show only "openai" in the case of an error here.
1196+ model_provider_short_name = model_provider_name
1197+ for provider_short_name , provider_desc in provider_descs .items ():
1198+ if provider_desc .name == model_provider_name :
1199+ model_provider_short_name = provider_short_name
1200+ break
11971201
1202+ if model_provider_name != resolved_provider_name :
1203+ raise errors .SchemaDefinitionError (
1204+ f"An embedding model with the name '{ uri_model_name } ' exists "
1205+ f"but the provider specified by the index "
1206+ f"('{ uri_provider_name } ') differs from the one "
1207+ f"specified by the model ('{ model_provider_short_name } ')." ,
1208+ span = span ,
1209+ )
1210+
1211+ elif len (models ) > 1 :
1212+ models_dn = [
1213+ model .get_displayname (schema ) for model in models .values ()
1214+ ]
1215+ raise errors .SchemaDefinitionError (
1216+ f'expecting only one embedding model to be annotated '
1217+ f'with ext::ai::model_name={ model_name !r} : got multiple: '
1218+ f'{ ", " .join (models_dn )} ' ,
1219+ span = span ,
1220+ )
1221+
1222+ return sd .DeltaRoot (), model .as_shell (schema )
1223+
1224+ # Ensure that a suitable model type exists for the ai index
11981225 model_typename = sn .QualName (
11991226 '__ext_generated_types__' ,
1200- f'ai_embedding_{ provider_name } _{ model_name } '
1227+ f'ai_embedding_{ uri_provider_name } _{ uri_model_name } '
12011228 )
12021229
12031230 if model_type := schema .get (model_typename , None , type = s_types .Type ):
@@ -1238,7 +1265,7 @@ def _create_ai_model_type(
12381265 name = 'model_provider' ,
12391266 module = 'ext::ai' ,
12401267 ),
1241- value = qlast .Constant .string (provider_desc . name ),
1268+ value = qlast .Constant .string (resolved_provider_name ),
12421269 ),
12431270 qlast .AlterAnnotationValue (
12441271 name = qlast .ObjectRef (
0 commit comments