Skip to content

Commit a941f5e

Browse files
committed
Generate model types.
1 parent 0fd0eeb commit a941f5e

File tree

6 files changed

+317
-3
lines changed

6 files changed

+317
-3
lines changed

edb/schema/ddl.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ def _filter(schema: s_schema.Schema, obj: so.Object) -> bool:
227227
# TODO: Fix this.
228228
if not include_derived_types:
229229
excluded_modules.add(sn.UnqualName('__derived__'))
230+
excluded_modules.add(sn.UnqualName('__ext_generated_types__'))
230231

231232
excluded_modules.add(sn.UnqualName('__ext_casts__'))
232233
excluded_modules.add(sn.UnqualName('__ext_index_matches__'))

edb/schema/expraliases.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,7 @@ def _create_alias_types(
576576
dict(
577577
alias_is_persistent=True,
578578
expr_type=s_types.ExprType.Select,
579+
is_schema_generated=True,
579580
from_alias=True,
580581
from_global=is_global,
581582
),

edb/schema/indexes.py

Lines changed: 274 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,11 @@
2929
TYPE_CHECKING,
3030
)
3131

32+
from dataclasses import dataclass
33+
import json
34+
import os
35+
import pathlib
36+
3237
from edb import edgeql
3338
from edb import errors
3439
from edb.common import parsing
@@ -356,6 +361,12 @@ class Index(
356361
allow_ddl_set=True,
357362
)
358363

364+
# If this is a generated AI index, track the created type
365+
generated_ai_index_type = so.SchemaField(
366+
s_types.Type,
367+
default=None,
368+
)
369+
359370
def __repr__(self) -> str:
360371
cls = self.__class__
361372
return '<{}.{} {!r} at 0x{:x}>'.format(
@@ -966,6 +977,43 @@ def canonicalize_attributes(
966977
)
967978
return schema
968979

980+
def _handle_ai_index_op(
981+
self,
982+
schema: s_schema.Schema,
983+
context: sd.CommandContext,
984+
) -> None:
985+
if context.canonical:
986+
return
987+
988+
# Get the provider and model from the kwargs
989+
kwargs: Optional[s_expr.ExpressionDict] = (
990+
self.get_resolved_attribute_value(
991+
'kwargs', schema=schema, context=context
992+
)
993+
)
994+
if not kwargs:
995+
return
996+
997+
embedding_model_kwarg: Optional[s_expr.Expression] = (
998+
kwargs.get('embedding_model')
999+
)
1000+
if not embedding_model_kwarg:
1001+
return
1002+
1003+
embedding_model_expr = s_expr.Expression.compiled(
1004+
embedding_model_kwarg, schema=schema, context=None
1005+
)
1006+
model_name = embedding_model_expr.as_python_value()
1007+
if ':' in model_name:
1008+
model_command, model_typeshell = _create_ai_model_type(
1009+
model_name, schema, context, span=self.span,
1010+
)
1011+
self.add_prerequisite(model_command)
1012+
self.set_attribute_value(
1013+
'generated_ai_index_type',
1014+
model_typeshell,
1015+
)
1016+
9691017
def ast_ignore_field_ownership(self, field: str) -> bool:
    """Tell whether an AST must be generated for *field* even if unowned.

    Only the ``deferred`` field is forced.
    """
    is_forced = field == "deferred"
    return is_forced
@@ -987,6 +1035,217 @@ def _append_subcmd_ast(
9871035
super()._append_subcmd_ast(schema, node, subcmd, context)
9881036

9891037

1038+
@dataclass(kw_only=True, frozen=True)
1039+
class ProviderDesc:
1040+
name: str
1041+
1042+
1043+
@dataclass(kw_only=True, frozen=True)
1044+
class EmbeddingModelDesc:
1045+
model_name: str
1046+
model_provider: str
1047+
max_input_tokens: int
1048+
max_batch_tokens: int
1049+
max_output_dimensions: int
1050+
supports_shortening: bool
1051+
1052+
1053+
def _try_load_reference_descs() -> Optional[
1054+
tuple[
1055+
Mapping[str, ProviderDesc],
1056+
Mapping[str, EmbeddingModelDesc],
1057+
]
1058+
]:
1059+
try:
1060+
provider_descs: dict[str, ProviderDesc] = {}
1061+
embedding_model_descs: dict[str, EmbeddingModelDesc] = {}
1062+
1063+
local_reference_path = os.path.join(
1064+
pathlib.Path(__file__).parent.parent,
1065+
'server',
1066+
'protocol',
1067+
'ai_reference.json',
1068+
)
1069+
with open(local_reference_path) as local_reference_file:
1070+
local_reference: dict[str, Any] = json.load(local_reference_file)
1071+
if provider_ref := local_reference.get("providers"):
1072+
for name, ref in provider_ref.items():
1073+
provider_descs[name] = ProviderDesc(
1074+
name=ref["name"],
1075+
)
1076+
if embedding_models_ref := local_reference.get("embedding_models"):
1077+
for name, ref in embedding_models_ref.items():
1078+
embedding_model_descs[name] = EmbeddingModelDesc(
1079+
model_name=ref["model_name"],
1080+
model_provider=ref["model_provider"],
1081+
max_input_tokens=ref["max_input_tokens"],
1082+
max_batch_tokens=ref["max_batch_tokens"],
1083+
max_output_dimensions=ref["max_output_dimensions"],
1084+
supports_shortening=ref["supports_shortening"],
1085+
)
1086+
return (provider_descs, embedding_model_descs)
1087+
1088+
except Exception:
1089+
# Ignore failures
1090+
return None
1091+
1092+
1093+
def _lookup_embedding_model_description(
    provider_name: str,
    model_name: str,
) -> Optional[tuple[ProviderDesc, EmbeddingModelDesc]]:
    """Find the reference descriptions for a provider/model pair.

    Returns ``(provider_desc, model_desc)`` when the reference data could
    be loaded, both names are present in it, and the model's
    ``model_provider`` equals *provider_name*.  Returns ``None`` otherwise.
    """
    loaded = _try_load_reference_descs()
    if loaded is None:
        return None

    providers, models = loaded
    provider = providers.get(provider_name)
    model = models.get(model_name)

    both_known = provider is not None and model is not None
    if not both_known:
        return None

    # The model must actually belong to the requested provider.
    if provider_name != model.model_provider:
        return None

    return (provider, model)
1112+
1113+
1114+
def _create_ai_model_type(
    model_name: str,
    schema: s_schema.Schema,
    context: sd.CommandContext,
    *,
    span: Optional[parsing.Span],
) -> tuple[
    Optional[sd.Command],
    so.ObjectShell[s_types.Type],
]:
    """Ensure a generated embedding model type exists for a model URI.

    *model_name* is a ``provider:model`` URI.  The pair is looked up via
    ``_lookup_embedding_model_description``; the generated type is named
    ``__ext_generated_types__::ai_embedding_{provider}_{model}``.

    Args:
        model_name: Model URI of the form ``provider:model``.
        schema: Schema to look the type up in and compile DDL against.
        context: Command context used when applying the DDL.
        span: Source span attached to raised errors.

    Returns:
        ``(command, shell)``.  ``command`` is ``None`` when the type is
        already in *schema*; otherwise it is the compiled command creating
        an abstract type extending ``ext::ai::EmbeddingModel`` with
        annotations taken from the model description.  ``shell`` is an
        object shell referring to the type by name.

    Raises:
        errors.SchemaDefinitionError: If the URI does not consist of
            exactly two ``':'``-separated parts, or if the provider/model
            pair is unknown.
    """
    # Keep the full URI for error messages; ``model_name`` is narrowed to
    # just the model part below.
    model_uri = model_name
    model_name_parts = model_uri.split(':')
    if len(model_name_parts) > 2:
        raise errors.SchemaDefinitionError(
            f"Invalid model uri, ':' used more than once: {model_uri}",
            span=span
        )
    if len(model_name_parts) < 2:
        raise errors.SchemaDefinitionError(
            f"Invalid model uri, expected 'provider:model': {model_uri}",
            span=span
        )

    provider_name, model_name = model_name_parts

    # Lookup the ai model description
    provider_model_desc = _lookup_embedding_model_description(
        provider_name, model_name
    )
    if not provider_model_desc:
        raise errors.SchemaDefinitionError(
            f"Invalid model uri, unknown provider and model: {model_uri}",
            span=span
        )

    provider_desc, model_desc = provider_model_desc

    # Ensure that a suitable model type exists for the ai index

    model_typename = sn.QualName(
        '__ext_generated_types__',
        f'ai_embedding_{provider_name}_{model_name}'
    )

    model_type = schema.get(model_typename, None, type=s_types.Type)
    if model_type:
        # The model type already exists: refer to it, nothing to create.
        existing_typeshell = so.ObjectShell(
            name=model_typename,
            schemaclass=type(model_type),
        )
        return None, existing_typeshell

    # The model type does not exist: build the DDL that creates it.
    # Each pair is an ``ext::ai`` annotation name and its string value.
    annotations = [
        ('model_name', model_desc.model_name),
        ('model_provider', provider_desc.name),
        (
            'embedding_model_max_input_tokens',
            str(model_desc.max_input_tokens),
        ),
        (
            'embedding_model_max_batch_tokens',
            str(model_desc.max_batch_tokens),
        ),
        (
            'embedding_model_max_output_dimensions',
            str(model_desc.max_output_dimensions),
        ),
    ]
    if model_desc.supports_shortening:
        # This annotation is only set when shortening is supported.
        annotations.append(('embedding_model_supports_shortening', 'true'))

    model_ast = qlast.CreateObjectType(
        name=qlast.ObjectRef(
            name=model_typename.name,
            module=model_typename.module,
            itemclass=qltypes.SchemaObjectClass.TYPE,
        ),
        abstract=True,
        bases=[
            qlast.TypeName(
                maintype=qlast.ObjectRef(
                    name='EmbeddingModel',
                    module='ext::ai',
                ),
            ),
        ],
        commands=[
            qlast.AlterAnnotationValue(
                name=qlast.ObjectRef(
                    name=annotation_name,
                    module='ext::ai',
                ),
                value=qlast.Constant.string(annotation_value),
            )
            for annotation_name, annotation_value in annotations
        ],
    )
    model_cmd = sd.compile_ddl(
        schema,
        model_ast,
    )
    model_cmd.set_attribute_value('is_schema_generated', True)

    # Doing this since using model_cmd or a copy doesn't work
    dummy_model_cmd = sd.compile_ddl(
        schema,
        model_ast,
    )
    dummy_model_cmd.set_attribute_value('is_schema_generated', True)
    # Apply the throwaway command only to learn the created object's
    # schema class; the resulting schema is not returned to the caller.
    new_schema = dummy_model_cmd.apply(schema, context)
    model_type = new_schema.get(model_typename, None, type=s_types.Type)
    assert model_type is not None
    model_typeshell = so.ObjectShell(
        name=model_typename,
        schemaclass=type(model_type),
    )

    return model_cmd, model_typeshell
1247+
1248+
9901249
class CreateIndex(
9911250
IndexCommand,
9921251
referencing.CreateReferencedInheritingObject[Index],
@@ -1382,6 +1641,8 @@ def _create_begin(
13821641
schema: s_schema.Schema,
13831642
context: sd.CommandContext,
13841643
) -> s_schema.Schema:
1644+
self._handle_ai_index_op(schema, context)
1645+
13851646
schema = super()._create_begin(schema, context)
13861647
referrer_ctx = self.get_referrer_context(context)
13871648
if (
@@ -1601,6 +1862,19 @@ def _get_referenced_embedding_model(
16011862
kwargs = self.scls.get_concrete_kwargs_as_values(schema)
16021863
model_name = kwargs["embedding_model"]
16031864

1865+
if ':' in model_name:
1866+
# We are doing URI lookup.
1867+
# If a new model type is needed, it will have been created in
1868+
# _handle_ai_index_op.
1869+
model_name_parts = model_name.split(':')
1870+
if len(model_name_parts) > 2:
1871+
raise errors.SchemaDefinitionError(
1872+
f"Invalid model uri, ':' used more than once: {model_name}",
1873+
span=self.span
1874+
)
1875+
1876+
model_name = model_name_parts[1]
1877+
16041878
models = get_defined_ext_ai_embedding_models(schema, model_name)
16051879
if len(models) == 0:
16061880
raise errors.SchemaDefinitionError(

edb/schema/objtypes.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,11 +538,17 @@ def _get_ast(
538538
*,
539539
parent_node: Optional[qlast.DDLOperation] = None,
540540
) -> Optional[qlast.DDLOperation]:
541-
if (self.get_attribute_value('expr_type')
542-
and not self.get_attribute_value('expr')):
541+
if (
542+
(
543543
# This is a nested view type, e.g
544544
# __FooAlias_bar produced by FooAlias := (SELECT Foo { bar: ... })
545545
# and should obviously not appear as a top level definition.
546+
self.get_attribute_value('expr_type')
547+
and not self.get_attribute_value('expr')
548+
)
549+
# Or another generated type which should not appear.
550+
or self.get_orig_attribute_value('is_schema_generated')
551+
):
546552
return None
547553
else:
548554
return super()._get_ast(schema, context, parent_node=parent_node)
@@ -599,6 +605,24 @@ class AlterObjectType(
599605
):
600606
astnode = qlast.AlterObjectType
601607

608+
def _get_ast(
609+
self,
610+
schema: s_schema.Schema,
611+
context: sd.CommandContext,
612+
*,
613+
parent_node: Optional[qlast.DDLOperation] = None,
614+
) -> Optional[qlast.DDLOperation]:
615+
if (
616+
hasattr(self, 'scls')
617+
and self.scls.get_is_schema_generated(schema)
618+
):
619+
# This is another automatically generated type.
620+
# Appropriate DDL will be generated from the corresponding
621+
# AlterObject node.
622+
return None
623+
else:
624+
return super()._get_ast(schema, context, parent_node=parent_node)
625+
602626
def _alter_begin(
603627
self,
604628
schema: s_schema.Schema,

edb/schema/schema.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@
9696
sn.UnqualName('__derived__'),
9797
sn.UnqualName('__ext_casts__'),
9898
sn.UnqualName('__ext_index_matches__'),
99+
sn.UnqualName('__ext_generated_types__'),
99100
)
100101

101102
# Specifies the order of processing of files and directories in lib/

0 commit comments

Comments
 (0)