Skip to content

Commit 15b92fe

Browse files
committed
Handle pre-existing model types.
1 parent 97c3e9d commit 15b92fe

File tree

1 file changed

+74
-47
lines changed

1 file changed

+74
-47
lines changed

edb/schema/indexes.py

Lines changed: 74 additions & 47 deletions
Original file line number | Diff line number | Diff line change
@@ -1023,10 +1023,11 @@ def _handle_ai_index_op(
10231023
model_cmd.prepend(drop_old_model_cmd)
10241024

10251025
self.add_prerequisite(model_cmd)
1026-
self.set_attribute_value(
1027-
'generated_ai_model_type',
1028-
model_typeshell,
1029-
)
1026+
if model_typeshell:
1027+
self.set_attribute_value(
1028+
'generated_ai_model_type',
1029+
model_typeshell,
1030+
)
10301031

10311032
def _delete_ai_model_type(
10321033
self,
@@ -1100,16 +1101,19 @@ class EmbeddingModelDesc:
11001101
supports_shortening: bool
11011102

11021103

1103-
def _try_load_reference_descs() -> Optional[
1104+
def _try_load_reference_descs(
1105+
*,
1106+
context: sd.CommandContext,
1107+
) -> Optional[
11041108
tuple[
1105-
Mapping[str, ProviderDesc],
1106-
Mapping[str, EmbeddingModelDesc],
1109+
dict[str, ProviderDesc],
1110+
dict[str, EmbeddingModelDesc],
11071111
]
11081112
]:
1109-
try:
1110-
provider_descs: dict[str, ProviderDesc] = {}
1111-
embedding_model_descs: dict[str, EmbeddingModelDesc] = {}
1113+
provider_descs: dict[str, ProviderDesc] = {}
1114+
embedding_model_descs: dict[str, EmbeddingModelDesc] = {}
11121115

1116+
try:
11131117
local_reference_path = os.path.join(
11141118
pathlib.Path(__file__).parent.parent,
11151119
'server',
@@ -1133,32 +1137,12 @@ def _try_load_reference_descs() -> Optional[
11331137
max_output_dimensions=ref["max_output_dimensions"],
11341138
supports_shortening=ref["supports_shortening"],
11351139
)
1136-
return (provider_descs, embedding_model_descs)
11371140

11381141
except Exception:
11391142
# Ignore failures
1140-
return None
1141-
1143+
pass
11421144

1143-
def _lookup_embedding_model_description(
1144-
provider_name: str,
1145-
model_name: str,
1146-
) -> Optional[tuple[ProviderDesc, EmbeddingModelDesc]]:
1147-
reference_descs = _try_load_reference_descs()
1148-
1149-
if reference_descs is None:
1150-
return None
1151-
1152-
provider_descs, embedding_model_descs = reference_descs
1153-
1154-
provider_desc = provider_descs.get(provider_name)
1155-
model_desc = embedding_model_descs.get(model_name)
1156-
if provider_desc is None or model_desc is None:
1157-
return None
1158-
if provider_name != model_desc.model_provider:
1159-
return None
1160-
1161-
return (provider_desc, model_desc)
1145+
return (provider_descs, embedding_model_descs)
11621146

11631147

11641148
def _create_ai_model_type(
@@ -1169,35 +1153,78 @@ def _create_ai_model_type(
11691153
span: Optional[parsing.Span],
11701154
) -> tuple[
11711155
sd.Command,
1172-
so.ObjectShell[s_types.Type],
1156+
Optional[so.ObjectShell[s_types.Type]],
11731157
]:
11741158
model_name_parts = model_name.split(':')
11751159
if len(model_name_parts) > 2:
11761160
raise errors.SchemaDefinitionError(
1177-
f"Invalid model uri, ':' used more than once: {model_name}",
1161+
f"Invalid model uri, ':' used more than once: {uri_model_name}",
11781162
span=span
11791163
)
11801164

1181-
provider_name = model_name_parts[0]
1182-
model_name = model_name_parts[1]
1165+
uri_provider_name = model_name_parts[0]
1166+
uri_model_name = model_name_parts[1]
11831167

11841168
# Lookup the ai model description
1185-
provider_model_desc = _lookup_embedding_model_description(
1186-
provider_name, model_name
1169+
provider_descs, model_descs = _try_load_reference_descs(
1170+
context=context
11871171
)
1188-
if not provider_model_desc:
1189-
raise errors.SchemaDefinitionError(
1190-
f"Invalid model uri, unknown provider and model: {model_name}",
1191-
span=span
1192-
)
11931172

1194-
provider_desc, model_desc = provider_model_desc
1173+
if uri_model_name not in model_descs:
1174+
# No match in reference
1175+
return sd.DeltaRoot(), None
11951176

1196-
# Ensure that a suitable model type exists for the ai index
1177+
model_desc = model_descs[uri_model_name]
1178+
1179+
resolved_provider_name = (
1180+
provider_descs[uri_provider_name].name
1181+
if uri_provider_name in provider_descs else
1182+
uri_provider_name
1183+
)
1184+
1185+
# First check if a model with a matching name is already in the schema
1186+
if models := get_defined_ext_ai_embedding_models(schema, uri_model_name):
1187+
# If there is only 1 model, ensure that the provider name matches
1188+
if len(models) == 1:
1189+
model = next(iter(models.values()))
1190+
1191+
model_provider_name = model.get_annotation(
1192+
schema, sn.QualName("ext::ai", "model_provider")
1193+
)
1194+
# The URI will specify "openai" instead of "builtin::openai"
1195+
# We want to show only "openai" in the case of an error here.
1196+
model_provider_short_name = model_provider_name
1197+
for provider_short_name, provider_desc in provider_descs.items():
1198+
if provider_desc.name == model_provider_name:
1199+
model_provider_short_name = provider_short_name
1200+
break
11971201

1202+
if model_provider_name != resolved_provider_name:
1203+
raise errors.SchemaDefinitionError(
1204+
f"An embedding model with the name '{uri_model_name}' exists "
1205+
f"but the provider specified by the index "
1206+
f"('{uri_provider_name}') differs from the one "
1207+
f"specified by the model ('{model_provider_short_name}').",
1208+
span=span,
1209+
)
1210+
1211+
elif len(models) > 1:
1212+
models_dn = [
1213+
model.get_displayname(schema) for model in models.values()
1214+
]
1215+
raise errors.SchemaDefinitionError(
1216+
f'expecting only one embedding model to be annotated '
1217+
f'with ext::ai::model_name={model_name!r}: got multiple: '
1218+
f'{", ".join(models_dn)}',
1219+
span=span,
1220+
)
1221+
1222+
return sd.DeltaRoot(), model.as_shell(schema)
1223+
1224+
# Ensure that a suitable model type exists for the ai index
11981225
model_typename = sn.QualName(
11991226
'__ext_generated_types__',
1200-
f'ai_embedding_{provider_name}_{model_name}'
1227+
f'ai_embedding_{uri_provider_name}_{uri_model_name}'
12011228
)
12021229

12031230
if model_type := schema.get(model_typename, None, type=s_types.Type):
@@ -1238,7 +1265,7 @@ def _create_ai_model_type(
12381265
name='model_provider',
12391266
module='ext::ai',
12401267
),
1241-
value=qlast.Constant.string(provider_desc.name),
1268+
value=qlast.Constant.string(resolved_provider_name),
12421269
),
12431270
qlast.AlterAnnotationValue(
12441271
name=qlast.ObjectRef(

0 commit comments

Comments
 (0)