@@ -362,7 +362,7 @@ class Index(
362362 )
363363
364364 # If this is a generated AI index, track the created type
365- generated_ai_index_type = so .SchemaField (
365+ generated_ai_model_type = so .SchemaField (
366366 s_types .Type ,
367367 default = None ,
368368 )
@@ -981,6 +981,8 @@ def _handle_ai_index_op(
981981 self ,
982982 schema : s_schema .Schema ,
983983 context : sd .CommandContext ,
984+ * ,
985+ is_alter : bool = False ,
984986 ) -> None :
985987 if context .canonical :
986988 return
@@ -1005,15 +1007,63 @@ def _handle_ai_index_op(
10051007 )
10061008 model_name = embedding_model_expr .as_python_value ()
10071009 if ':' in model_name :
1008- model_command , model_typeshell = _create_ai_model_type (
1009- model_name , schema , context , span = self .span ,
1010+ pschema = schema
1011+
1012+ drop_old_model_cmd : Optional [sd .Command ] = None
1013+ if is_alter :
1014+ drop_old_model_cmd = self ._delete_ai_model_type (
1015+ self .scls , schema , context )
1016+ with context .suspend_dep_verification ():
1017+ pschema = drop_old_model_cmd .apply (pschema , context )
1018+
1019+ model_cmd , model_typeshell = _create_ai_model_type (
1020+ model_name , pschema , context , span = self .span ,
10101021 )
1011- self .add_prerequisite (model_command )
1022+ if drop_old_model_cmd :
1023+ model_cmd .prepend (drop_old_model_cmd )
1024+
1025+ self .add_prerequisite (model_cmd )
10121026 self .set_attribute_value (
1013- 'generated_ai_index_type ' ,
1027+ 'generated_ai_model_type ' ,
10141028 model_typeshell ,
10151029 )
10161030
1031+ def _delete_ai_model_type (
1032+ self ,
1033+ scls : Index ,
1034+ schema : s_schema .Schema ,
1035+ context : sd .CommandContext ,
1036+ ) -> sd .CommandGroup :
1037+
1038+ model_type : Optional [s_types .Type ] = (
1039+ scls .get_generated_ai_model_type (schema )
1040+ )
1041+ if not model_type :
1042+ return sd .DeltaRoot ()
1043+
1044+ delta = sd .DeltaRoot ()
1045+
1046+ alter_index = scls .init_delta_command (schema , sd .AlterObject )
1047+ alter_index .canonical = True
1048+
1049+ # Unset generated_ai_model_type, so the types can be dropped
1050+ alter_index .add (
1051+ sd .AlterObjectProperty (
1052+ property = 'generated_ai_model_type' , new_value = None
1053+ )
1054+ )
1055+ delta .add (alter_index )
1056+
1057+ drop_model_type = model_type .init_delta_command (
1058+ schema , sd .DeleteObject , if_exists = True , if_unused = True
1059+ )
1060+ subcmds = drop_model_type ._canonicalize (schema , context , model_type )
1061+ drop_model_type .update (subcmds )
1062+
1063+ delta .add (drop_model_type )
1064+
1065+ return delta
1066+
10171067 def ast_ignore_field_ownership (self , field : str ) -> bool :
10181068 """Whether to force generating an AST even though field isn't owned"""
10191069 return field == "deferred"
@@ -1118,7 +1168,7 @@ def _create_ai_model_type(
11181168 * ,
11191169 span : Optional [parsing .Span ],
11201170) -> tuple [
1121- Optional [ sd .Command ] ,
1171+ sd .Command ,
11221172 so .ObjectShell [s_types .Type ],
11231173]:
11241174 model_name_parts = model_name .split (':' )
@@ -1156,7 +1206,7 @@ def _create_ai_model_type(
11561206 name = model_typename ,
11571207 schemaclass = type (model_type ),
11581208 )
1159- return None , model_typeshell
1209+ return sd . DeltaRoot () , model_typeshell
11601210
11611211 else :
11621212 # If the model type does not exist, create it and also add a link
@@ -1641,7 +1691,7 @@ def _create_begin(
16411691 schema : s_schema .Schema ,
16421692 context : sd .CommandContext ,
16431693 ) -> s_schema .Schema :
1644- self ._handle_ai_index_op (schema , context )
1694+ self ._handle_ai_index_op (schema , context , is_alter = False )
16451695
16461696 schema = super ()._create_begin (schema , context )
16471697 referrer_ctx = self .get_referrer_context (context )
@@ -1925,6 +1975,8 @@ def _alter_begin(
19251975 schema : s_schema .Schema ,
19261976 context : sd .CommandContext ,
19271977 ) -> s_schema .Schema :
1978+ self ._handle_ai_index_op (schema , context , is_alter = True )
1979+
19281980 schema = super ()._alter_begin (schema , context )
19291981 referrer_ctx = self .get_referrer_context (context )
19301982 if (
@@ -2036,6 +2088,13 @@ def _delete_begin(
20362088 if not context .canonical :
20372089 for param in self .scls .get_params (schema ).objects (schema ):
20382090 self .add (param .init_delta_command (schema , sd .DeleteObject ))
2091+ if (
2092+ not context .canonical
2093+ and self .scls .get_generated_ai_model_type (schema )
2094+ ):
2095+ self .add_caused (
2096+ self ._delete_ai_model_type (self .scls , schema , context )
2097+ )
20392098 return schema
20402099
20412100 @classmethod
0 commit comments