2929 TYPE_CHECKING ,
3030)
3131
32+ from dataclasses import dataclass
33+ import json
34+ import os
35+ import pathlib
36+
3237from edb import edgeql
3338from edb import errors
3439from edb .common import parsing
@@ -356,6 +361,12 @@ class Index(
356361 allow_ddl_set = True ,
357362 )
358363
364+ # If this is a generated AI index, track the created type
365+ generated_ai_index_type = so .SchemaField (
366+ s_types .Type ,
367+ default = None ,
368+ )
369+
359370 def __repr__ (self ) -> str :
360371 cls = self .__class__
361372 return '<{}.{} {!r} at 0x{:x}>' .format (
@@ -966,6 +977,43 @@ def canonicalize_attributes(
966977 )
967978 return schema
968979
980+ def _handle_ai_index_op (
981+ self ,
982+ schema : s_schema .Schema ,
983+ context : sd .CommandContext ,
984+ ) -> None :
985+ if context .canonical :
986+ return
987+
988+ # Get the provider and model from the kwargs
989+ kwargs : Optional [s_expr .ExpressionDict ] = (
990+ self .get_resolved_attribute_value (
991+ 'kwargs' , schema = schema , context = context
992+ )
993+ )
994+ if not kwargs :
995+ return
996+
997+ embedding_model_kwarg : Optional [s_expr .Expression ] = (
998+ kwargs .get ('embedding_model' )
999+ )
1000+ if not embedding_model_kwarg :
1001+ return
1002+
1003+ embedding_model_expr = s_expr .Expression .compiled (
1004+ embedding_model_kwarg , schema = schema , context = None
1005+ )
1006+ model_name = embedding_model_expr .as_python_value ()
1007+ if ':' in model_name :
1008+ model_command , model_typeshell = _create_ai_model_type (
1009+ model_name , schema , context , span = self .span ,
1010+ )
1011+ self .add_prerequisite (model_command )
1012+ self .set_attribute_value (
1013+ 'generated_ai_index_type' ,
1014+ model_typeshell ,
1015+ )
1016+
9691017 def ast_ignore_field_ownership (self , field : str ) -> bool :
9701018 """Whether to force generating an AST even though field isn't owned"""
9711019 return field == "deferred"
@@ -987,6 +1035,217 @@ def _append_subcmd_ast(
9871035 super ()._append_subcmd_ast (schema , node , subcmd , context )
9881036
9891037
1038+ @dataclass (kw_only = True , frozen = True )
1039+ class ProviderDesc :
1040+ name : str
1041+
1042+
1043+ @dataclass (kw_only = True , frozen = True )
1044+ class EmbeddingModelDesc :
1045+ model_name : str
1046+ model_provider : str
1047+ max_input_tokens : int
1048+ max_batch_tokens : int
1049+ max_output_dimensions : int
1050+ supports_shortening : bool
1051+
1052+
1053+ def _try_load_reference_descs () -> Optional [
1054+ tuple [
1055+ Mapping [str , ProviderDesc ],
1056+ Mapping [str , EmbeddingModelDesc ],
1057+ ]
1058+ ]:
1059+ try :
1060+ provider_descs : dict [str , ProviderDesc ] = {}
1061+ embedding_model_descs : dict [str , EmbeddingModelDesc ] = {}
1062+
1063+ local_reference_path = os .path .join (
1064+ pathlib .Path (__file__ ).parent .parent ,
1065+ 'server' ,
1066+ 'protocol' ,
1067+ 'ai_reference.json' ,
1068+ )
1069+ with open (local_reference_path ) as local_reference_file :
1070+ local_reference : dict [str , Any ] = json .load (local_reference_file )
1071+ if provider_ref := local_reference .get ("providers" ):
1072+ for name , ref in provider_ref .items ():
1073+ provider_descs [name ] = ProviderDesc (
1074+ name = ref ["name" ],
1075+ )
1076+ if embedding_models_ref := local_reference .get ("embedding_models" ):
1077+ for name , ref in embedding_models_ref .items ():
1078+ embedding_model_descs [name ] = EmbeddingModelDesc (
1079+ model_name = ref ["model_name" ],
1080+ model_provider = ref ["model_provider" ],
1081+ max_input_tokens = ref ["max_input_tokens" ],
1082+ max_batch_tokens = ref ["max_batch_tokens" ],
1083+ max_output_dimensions = ref ["max_output_dimensions" ],
1084+ supports_shortening = ref ["supports_shortening" ],
1085+ )
1086+ return (provider_descs , embedding_model_descs )
1087+
1088+ except Exception :
1089+ # Ignore failures
1090+ return None
1091+
1092+
1093+ def _lookup_embedding_model_description (
1094+ provider_name : str ,
1095+ model_name : str ,
1096+ ) -> Optional [tuple [ProviderDesc , EmbeddingModelDesc ]]:
1097+ reference_descs = _try_load_reference_descs ()
1098+
1099+ if reference_descs is None :
1100+ return None
1101+
1102+ provider_descs , embedding_model_descs = reference_descs
1103+
1104+ provider_desc = provider_descs .get (provider_name )
1105+ model_desc = embedding_model_descs .get (model_name )
1106+ if provider_desc is None or model_desc is None :
1107+ return None
1108+ if provider_name != model_desc .model_provider :
1109+ return None
1110+
1111+ return (provider_desc , model_desc )
1112+
1113+
1114+ def _create_ai_model_type (
1115+ model_name : str ,
1116+ schema : s_schema .Schema ,
1117+ context : sd .CommandContext ,
1118+ * ,
1119+ span : Optional [parsing .Span ],
1120+ ) -> tuple [
1121+ Optional [sd .Command ],
1122+ so .ObjectShell [s_types .Type ],
1123+ ]:
1124+ model_name_parts = model_name .split (':' )
1125+ if len (model_name_parts ) > 2 :
1126+ raise errors .SchemaDefinitionError (
1127+ f"Invalid model uri, ':' used more than once: { model_name } " ,
1128+ span = span
1129+ )
1130+
1131+ provider_name = model_name_parts [0 ]
1132+ model_name = model_name_parts [1 ]
1133+
1134+ # Lookup the ai model description
1135+ provider_model_desc = _lookup_embedding_model_description (
1136+ provider_name , model_name
1137+ )
1138+ if not provider_model_desc :
1139+ raise errors .SchemaDefinitionError (
1140+ f"Invalid model uri, unknown provider and model: { model_name } " ,
1141+ span = span
1142+ )
1143+
1144+ provider_desc , model_desc = provider_model_desc
1145+
1146+ # Ensure that a suitable model type exists for the ai index
1147+
1148+ model_typename = sn .QualName (
1149+ '__ext_generated_types__' ,
1150+ f'ai_embedding_{ provider_name } _{ model_name } '
1151+ )
1152+
1153+ if model_type := schema .get (model_typename , None , type = s_types .Type ):
1154+ # If the model type already exists, add a link to it
1155+ model_typeshell = so .ObjectShell (
1156+ name = model_typename ,
1157+ schemaclass = type (model_type ),
1158+ )
1159+ return None , model_typeshell
1160+
1161+ else :
1162+ # If the model type does not exist, create it and also add a link
1163+ model_ast = qlast .CreateObjectType (
1164+ name = qlast .ObjectRef (
1165+ name = model_typename .name ,
1166+ module = model_typename .module ,
1167+ itemclass = qltypes .SchemaObjectClass .TYPE ,
1168+ ),
1169+ abstract = True ,
1170+ bases = [
1171+ qlast .TypeName (
1172+ maintype = qlast .ObjectRef (
1173+ name = 'EmbeddingModel' ,
1174+ module = 'ext::ai' ,
1175+ ),
1176+ ),
1177+ ],
1178+ commands = [
1179+ qlast .AlterAnnotationValue (
1180+ name = qlast .ObjectRef (
1181+ name = 'model_name' ,
1182+ module = 'ext::ai' ,
1183+ ),
1184+ value = qlast .Constant .string (model_desc .model_name ),
1185+ ),
1186+ qlast .AlterAnnotationValue (
1187+ name = qlast .ObjectRef (
1188+ name = 'model_provider' ,
1189+ module = 'ext::ai' ,
1190+ ),
1191+ value = qlast .Constant .string (provider_desc .name ),
1192+ ),
1193+ qlast .AlterAnnotationValue (
1194+ name = qlast .ObjectRef (
1195+ name = 'embedding_model_max_input_tokens' ,
1196+ module = 'ext::ai' ,
1197+ ),
1198+ value = qlast .Constant .string (str (model_desc .max_input_tokens )),
1199+ ),
1200+ qlast .AlterAnnotationValue (
1201+ name = qlast .ObjectRef (
1202+ name = 'embedding_model_max_batch_tokens' ,
1203+ module = 'ext::ai' ,
1204+ ),
1205+ value = qlast .Constant .string (str (model_desc .max_batch_tokens )),
1206+ ),
1207+ qlast .AlterAnnotationValue (
1208+ name = qlast .ObjectRef (
1209+ name = 'embedding_model_max_output_dimensions' ,
1210+ module = 'ext::ai' ,
1211+ ),
1212+ value = qlast .Constant .string (str (model_desc .max_output_dimensions )),
1213+ ),
1214+ ],
1215+ )
1216+ if model_desc .supports_shortening :
1217+ model_ast .commands .append (
1218+ qlast .AlterAnnotationValue (
1219+ name = qlast .ObjectRef (
1220+ name = 'embedding_model_supports_shortening' ,
1221+ module = 'ext::ai' ,
1222+ ),
1223+ value = qlast .Constant .string ('true' ),
1224+ )
1225+ )
1226+ model_cmd = sd .compile_ddl (
1227+ schema ,
1228+ model_ast ,
1229+ )
1230+ model_cmd .set_attribute_value ('is_schema_generated' , True )
1231+
1232+ # Doing this since using model_cmd or a copy doesn't work
1233+ dummy_model_cmd = sd .compile_ddl (
1234+ schema ,
1235+ model_ast ,
1236+ )
1237+ dummy_model_cmd .set_attribute_value ('is_schema_generated' , True )
1238+ new_schema = dummy_model_cmd .apply (schema , context )
1239+ model_type = new_schema .get (model_typename , None , type = s_types .Type )
1240+ assert model_type is not None
1241+ model_typeshell = so .ObjectShell (
1242+ name = model_typename ,
1243+ schemaclass = type (model_type ),
1244+ )
1245+
1246+ return model_cmd , model_typeshell
1247+
1248+
9901249class CreateIndex (
9911250 IndexCommand ,
9921251 referencing .CreateReferencedInheritingObject [Index ],
@@ -1382,6 +1641,8 @@ def _create_begin(
13821641 schema : s_schema .Schema ,
13831642 context : sd .CommandContext ,
13841643 ) -> s_schema .Schema :
1644+ self ._handle_ai_index_op (schema , context )
1645+
13851646 schema = super ()._create_begin (schema , context )
13861647 referrer_ctx = self .get_referrer_context (context )
13871648 if (
@@ -1601,6 +1862,19 @@ def _get_referenced_embedding_model(
16011862 kwargs = self .scls .get_concrete_kwargs_as_values (schema )
16021863 model_name = kwargs ["embedding_model" ]
16031864
1865+ if ':' in model_name :
1866+ # We are doing URI lookup.
1867+ # If a new model type is needed, it will have been created in
1868+ # _handle_ai_index_op.
1869+ model_name_parts = model_name .split (':' )
1870+ if len (model_name_parts ) > 2 :
1871+ raise errors .SchemaDefinitionError (
1872+ f"Invalid model uri, ':' used more than once: { model_name } " ,
1873+ span = self .span
1874+ )
1875+
1876+ model_name = model_name_parts [1 ]
1877+
16041878 models = get_defined_ext_ai_embedding_models (schema , model_name )
16051879 if len (models ) == 0 :
16061880 raise errors .SchemaDefinitionError (
0 commit comments