@@ -30,6 +30,7 @@ import (
3030 "github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
3131 "github.com/dolthub/go-mysql-server/sql/expression"
3232 "github.com/dolthub/go-mysql-server/sql/fulltext"
33+ "github.com/dolthub/go-mysql-server/sql/iters"
3334 "github.com/dolthub/go-mysql-server/sql/transform"
3435 "github.com/dolthub/go-mysql-server/sql/types"
3536)
@@ -345,6 +346,36 @@ type spatialRangePartitionIter struct {
345346 minX , minY , maxX , maxY float64
346347}
347348
349+ // vectorPartitionIter is the sql.PartitionIter for vector indexes.
350+ // Because it only ever has one partition, it also implements sql.Partition
351+ // and returns itself in calls to Next.
352+ type vectorPartitionIter struct {
353+ Column sql.Expression
354+ sql.OrderAndLimit
355+ visited bool
356+ }
357+
358+ var _ sql.PartitionIter = (* vectorPartitionIter )(nil )
359+ var _ sql.Partition = (* vectorPartitionIter )(nil )
360+
361+ // Key returns the key used to distinguish partitions. Since it only ever has one partition,
362+ // this value is unused.
363+ func (v * vectorPartitionIter ) Key () []byte {
364+ return nil
365+ }
366+
367+ func (v * vectorPartitionIter ) Close (_ * sql.Context ) error {
368+ return nil
369+ }
370+
371+ func (v * vectorPartitionIter ) Next (_ * sql.Context ) (sql.Partition , error ) {
372+ if v .visited {
373+ return nil , io .EOF
374+ }
375+ v .visited = true
376+ return v , nil
377+ }
378+
348379var _ sql.PartitionIter = (* spatialRangePartitionIter )(nil )
349380
350381func (i spatialRangePartitionIter ) Close (ctx * sql.Context ) error {
@@ -515,6 +546,25 @@ func (t *Table) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.Ro
515546 filters = append (t .filters , r .rang )
516547 }
517548
549+ if vectorPartition , ok := partition .(* vectorPartitionIter ); ok {
550+ // Assume only one partition for now
551+ rows := data .partitions [string (data .partitionKeys [0 ])]
552+
553+ sf := sql.SortFields {
554+ {Column : vectorPartition .OrderBy , Order : sql .Ascending },
555+ }
556+
557+ if vectorPartition .Limit != nil {
558+ limit , err := iters .GetInt64Value (ctx , vectorPartition .Limit )
559+ if err != nil {
560+ return nil , err
561+ }
562+ return iters .NewTopRowsIter (sf , limit , vectorPartition .CalcFoundRows , sql .RowsToRowIter (rows ... ), 0 ), nil
563+ }
564+
565+ return iters .NewSortIter (sf , sql .RowsToRowIter (rows ... )), nil
566+ }
567+
518568 rows , ok := data .partitions [string (partition .Key ())]
519569 if ! ok {
520570 return nil , sql .ErrPartitionNotFound .New (partition .Key ())
@@ -1544,6 +1594,14 @@ var _ sql.StatisticsTable = (*IndexedTable)(nil)
15441594
15451595func (t * IndexedTable ) LookupPartitions (ctx * sql.Context , lookup sql.IndexLookup ) (sql.PartitionIter , error ) {
15461596 memIdx := lookup .Index .(* Index )
1597+
1598+ if lookup .VectorOrderAndLimit .OrderBy != nil {
1599+ return & vectorPartitionIter {
1600+ Column : lookup .Index .(* Index ).Exprs [0 ],
1601+ OrderAndLimit : lookup .VectorOrderAndLimit ,
1602+ }, nil
1603+ }
1604+
15471605 lookupRanges , ok := lookup .Ranges .(sql.MySQLRangeCollection )
15481606 if ! ok {
15491607 return nil , fmt .Errorf ("expected MySQL ranges in memory indexed table" )
@@ -1639,6 +1697,10 @@ func (t *IndexedTable) PartitionRows(ctx *sql.Context, partition sql.Partition)
16391697 return iter , nil
16401698 }
16411699
1700+ if _ , ok := partition .(* vectorPartitionIter ); ok {
1701+ return iter , nil
1702+ }
1703+
16421704 if t .Lookup .Index != nil {
16431705 idx := t .Lookup .Index .(* Index )
16441706 sf := make (sql.SortFields , len (idx .Exprs ))
@@ -1950,17 +2012,18 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
19502012 }
19512013
19522014 return & Index {
1953- DB : t .dbName (),
1954- DriverName : "" ,
1955- Tbl : t ,
1956- TableName : t .name ,
1957- Exprs : exprs ,
1958- Name : name ,
1959- Unique : constraint == sql .IndexConstraint_Unique ,
1960- Spatial : constraint == sql .IndexConstraint_Spatial ,
1961- Fulltext : constraint == sql .IndexConstraint_Fulltext ,
1962- CommentStr : comment ,
1963- PrefixLens : prefixLengths ,
2015+ DB : t .dbName (),
2016+ DriverName : "" ,
2017+ Tbl : t ,
2018+ TableName : t .name ,
2019+ Exprs : exprs ,
2020+ Name : name ,
2021+ Unique : constraint == sql .IndexConstraint_Unique ,
2022+ Spatial : constraint == sql .IndexConstraint_Spatial ,
2023+ Fulltext : constraint == sql .IndexConstraint_Fulltext ,
2024+ SupportedVectorFunction : nil ,
2025+ CommentStr : comment ,
2026+ PrefixLens : prefixLengths ,
19642027 }, nil
19652028}
19662029
@@ -2058,6 +2121,31 @@ func (t *Table) CreateFulltextIndex(ctx *sql.Context, indexDef sql.IndexDef, key
20582121 return nil
20592122}
20602123
2124+ func (t * Table ) CreateVectorIndex (ctx * sql.Context , idx sql.IndexDef , distanceType expression.DistanceType ) error {
2125+ if len (idx .Columns ) > 1 {
2126+ return fmt .Errorf ("vector indexes must have exactly one column" )
2127+ }
2128+
2129+ sess := SessionFromContext (ctx )
2130+ data := sess .tableData (t )
2131+
2132+ if data .indexes == nil {
2133+ data .indexes = make (map [string ]sql.Index )
2134+ }
2135+
2136+ index , err := t .createIndex (data , idx .Name , idx .Columns , idx .Constraint , idx .Comment )
2137+ if err != nil {
2138+ return err
2139+ }
2140+ index .(* Index ).SupportedVectorFunction = distanceType
2141+
2142+ // Store the computed index name in the case of an empty index name being passed in
2143+ data .indexes [strings .ToLower (index .ID ())] = index
2144+ sess .putTable (data )
2145+
2146+ return nil
2147+ }
2148+
20612149// ModifyStoredCollation implements sql.CollationAlterableTable
20622150func (t * Table ) ModifyStoredCollation (ctx * sql.Context , collation sql.CollationID ) error {
20632151 return fmt .Errorf ("converting the collations of columns is not yet supported" )
0 commit comments