Skip to content

Commit 9555784

Browse files
authored
Merge pull request #2661 from dolthub/nicktobey/vector
Example in-memory Vector index using the existing index APIs.
2 parents 5aacdb1 + af74afc commit 9555784

20 files changed

+1587
-860
lines changed

memory/index.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,11 @@ type Index struct {
3636
Unique bool
3737
Spatial bool
3838
Fulltext bool
39-
CommentStr string
40-
PrefixLens []uint16
39+
// If SupportedVectorFunction is non-nil, this index can be used to optimize ORDER BY
40+
// expressions on this type of distance function.
41+
SupportedVectorFunction expression.DistanceType
42+
CommentStr string
43+
PrefixLens []uint16
4144
fulltextInfo
4245
}
4346

@@ -118,6 +121,14 @@ func (idx *Index) IsFullText() bool {
118121
return idx.Fulltext
119122
}
120123

124+
func (idx *Index) CanSupportOrderBy(expr sql.Expression) bool {
125+
if idx.SupportedVectorFunction == nil {
126+
return false
127+
}
128+
dist, isDist := expr.(*expression.Distance)
129+
return isDist && idx.SupportedVectorFunction.CanEval(dist.DistanceType)
130+
}
131+
121132
func (idx *Index) Comment() string {
122133
return idx.CommentStr
123134
}

memory/table.go

Lines changed: 99 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ import (
3030
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
3131
"github.com/dolthub/go-mysql-server/sql/expression"
3232
"github.com/dolthub/go-mysql-server/sql/fulltext"
33+
"github.com/dolthub/go-mysql-server/sql/iters"
3334
"github.com/dolthub/go-mysql-server/sql/transform"
3435
"github.com/dolthub/go-mysql-server/sql/types"
3536
)
@@ -345,6 +346,36 @@ type spatialRangePartitionIter struct {
345346
minX, minY, maxX, maxY float64
346347
}
347348

349+
// vectorPartitionIter is the sql.PartitionIter for vector indexes.
350+
// Because it only ever has one partition, it also implements sql.Partition
351+
// and returns itself in calls to Next.
352+
type vectorPartitionIter struct {
353+
Column sql.Expression
354+
sql.OrderAndLimit
355+
visited bool
356+
}
357+
358+
var _ sql.PartitionIter = (*vectorPartitionIter)(nil)
359+
var _ sql.Partition = (*vectorPartitionIter)(nil)
360+
361+
// Key returns the key used to distinguish partitions. Since it only ever has one partition,
362+
// this value is unused.
363+
func (v *vectorPartitionIter) Key() []byte {
364+
return nil
365+
}
366+
367+
func (v *vectorPartitionIter) Close(_ *sql.Context) error {
368+
return nil
369+
}
370+
371+
func (v *vectorPartitionIter) Next(_ *sql.Context) (sql.Partition, error) {
372+
if v.visited {
373+
return nil, io.EOF
374+
}
375+
v.visited = true
376+
return v, nil
377+
}
378+
348379
var _ sql.PartitionIter = (*spatialRangePartitionIter)(nil)
349380

350381
func (i spatialRangePartitionIter) Close(ctx *sql.Context) error {
@@ -515,6 +546,25 @@ func (t *Table) PartitionRows(ctx *sql.Context, partition sql.Partition) (sql.Ro
515546
filters = append(t.filters, r.rang)
516547
}
517548

549+
if vectorPartition, ok := partition.(*vectorPartitionIter); ok {
550+
// Assume only one partition for now
551+
rows := data.partitions[string(data.partitionKeys[0])]
552+
553+
sf := sql.SortFields{
554+
{Column: vectorPartition.OrderBy, Order: sql.Ascending},
555+
}
556+
557+
if vectorPartition.Limit != nil {
558+
limit, err := iters.GetInt64Value(ctx, vectorPartition.Limit)
559+
if err != nil {
560+
return nil, err
561+
}
562+
return iters.NewTopRowsIter(sf, limit, vectorPartition.CalcFoundRows, sql.RowsToRowIter(rows...), 0), nil
563+
}
564+
565+
return iters.NewSortIter(sf, sql.RowsToRowIter(rows...)), nil
566+
}
567+
518568
rows, ok := data.partitions[string(partition.Key())]
519569
if !ok {
520570
return nil, sql.ErrPartitionNotFound.New(partition.Key())
@@ -1544,6 +1594,14 @@ var _ sql.StatisticsTable = (*IndexedTable)(nil)
15441594

15451595
func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup) (sql.PartitionIter, error) {
15461596
memIdx := lookup.Index.(*Index)
1597+
1598+
if lookup.VectorOrderAndLimit.OrderBy != nil {
1599+
return &vectorPartitionIter{
1600+
Column: lookup.Index.(*Index).Exprs[0],
1601+
OrderAndLimit: lookup.VectorOrderAndLimit,
1602+
}, nil
1603+
}
1604+
15471605
lookupRanges, ok := lookup.Ranges.(sql.MySQLRangeCollection)
15481606
if !ok {
15491607
return nil, fmt.Errorf("expected MySQL ranges in memory indexed table")
@@ -1639,6 +1697,10 @@ func (t *IndexedTable) PartitionRows(ctx *sql.Context, partition sql.Partition)
16391697
return iter, nil
16401698
}
16411699

1700+
if _, ok := partition.(*vectorPartitionIter); ok {
1701+
return iter, nil
1702+
}
1703+
16421704
if t.Lookup.Index != nil {
16431705
idx := t.Lookup.Index.(*Index)
16441706
sf := make(sql.SortFields, len(idx.Exprs))
@@ -1950,17 +2012,18 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
19502012
}
19512013

19522014
return &Index{
1953-
DB: t.dbName(),
1954-
DriverName: "",
1955-
Tbl: t,
1956-
TableName: t.name,
1957-
Exprs: exprs,
1958-
Name: name,
1959-
Unique: constraint == sql.IndexConstraint_Unique,
1960-
Spatial: constraint == sql.IndexConstraint_Spatial,
1961-
Fulltext: constraint == sql.IndexConstraint_Fulltext,
1962-
CommentStr: comment,
1963-
PrefixLens: prefixLengths,
2015+
DB: t.dbName(),
2016+
DriverName: "",
2017+
Tbl: t,
2018+
TableName: t.name,
2019+
Exprs: exprs,
2020+
Name: name,
2021+
Unique: constraint == sql.IndexConstraint_Unique,
2022+
Spatial: constraint == sql.IndexConstraint_Spatial,
2023+
Fulltext: constraint == sql.IndexConstraint_Fulltext,
2024+
SupportedVectorFunction: nil,
2025+
CommentStr: comment,
2026+
PrefixLens: prefixLengths,
19642027
}, nil
19652028
}
19662029

@@ -2058,6 +2121,31 @@ func (t *Table) CreateFulltextIndex(ctx *sql.Context, indexDef sql.IndexDef, key
20582121
return nil
20592122
}
20602123

2124+
func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType expression.DistanceType) error {
2125+
if len(idx.Columns) > 1 {
2126+
return fmt.Errorf("vector indexes must have exactly one column")
2127+
}
2128+
2129+
sess := SessionFromContext(ctx)
2130+
data := sess.tableData(t)
2131+
2132+
if data.indexes == nil {
2133+
data.indexes = make(map[string]sql.Index)
2134+
}
2135+
2136+
index, err := t.createIndex(data, idx.Name, idx.Columns, idx.Constraint, idx.Comment)
2137+
if err != nil {
2138+
return err
2139+
}
2140+
index.(*Index).SupportedVectorFunction = distanceType
2141+
2142+
// Store the computed index name in the case of an empty index name being passed in
2143+
data.indexes[strings.ToLower(index.ID())] = index
2144+
sess.putTable(data)
2145+
2146+
return nil
2147+
}
2148+
20612149
// ModifyStoredCollation implements sql.CollationAlterableTable
20622150
func (t *Table) ModifyStoredCollation(ctx *sql.Context, collation sql.CollationID) error {
20632151
return fmt.Errorf("converting the collations of columns is not yet supported")

sql/analyzer/index_analyzer_test.go

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,16 @@ func (i dummyIdx) Expressions() []string {
142142
}
143143
return exprs
144144
}
145-
func (i *dummyIdx) ID() string { return i.id }
146-
func (i *dummyIdx) Database() string { return i.database }
147-
func (i *dummyIdx) Table() string { return i.table }
148-
func (i *dummyIdx) IsUnique() bool { return false }
149-
func (i *dummyIdx) IsSpatial() bool { return false }
150-
func (i *dummyIdx) IsFullText() bool { return false }
151-
func (i *dummyIdx) Comment() string { return "" }
152-
func (i *dummyIdx) IsGenerated() bool { return false }
145+
func (i *dummyIdx) ID() string { return i.id }
146+
func (i *dummyIdx) Database() string { return i.database }
147+
func (i *dummyIdx) Table() string { return i.table }
148+
func (i *dummyIdx) IsUnique() bool { return false }
149+
func (i *dummyIdx) IsSpatial() bool { return false }
150+
func (i *dummyIdx) IsFullText() bool { return false }
151+
func (i *dummyIdx) Comment() string { return "" }
152+
func (i *dummyIdx) IsGenerated() bool { return false }
153+
func (i *dummyIdx) CanSupportOrderBy(sql.Expression) bool { return false }
154+
153155
func (i *dummyIdx) IndexType() string { return "BTREE" }
154156
func (i *dummyIdx) PrefixLengths() []uint16 { return nil }
155157

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
package analyzer
2+
3+
import (
4+
"github.com/dolthub/go-mysql-server/sql"
5+
"github.com/dolthub/go-mysql-server/sql/expression"
6+
"github.com/dolthub/go-mysql-server/sql/plan"
7+
"github.com/dolthub/go-mysql-server/sql/transform"
8+
)
9+
10+
// replaceIdxSort applies an IndexAccess when there is an `OrderBy` over a prefix of any columns with Indexes
11+
func replaceIdxOrderByDistance(ctx *sql.Context, a *Analyzer, n sql.Node, scope *plan.Scope, sel RuleSelector, qFlags *sql.QueryFlags) (sql.Node, transform.TreeIdentity, error) {
12+
return replaceIdxOrderByDistanceHelper(ctx, scope, n, nil)
13+
}
14+
15+
func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode plan.Sortable) (sql.Node, transform.TreeIdentity, error) {
16+
switch n := node.(type) {
17+
case plan.Sortable:
18+
sortNode = n // lowest parent sort node
19+
case *plan.ResolvedTable:
20+
if sortNode == nil {
21+
return n, transform.SameTree, nil
22+
}
23+
24+
table := n.UnderlyingTable()
25+
idxTbl, ok := table.(sql.IndexAddressableTable)
26+
if !ok {
27+
return n, transform.SameTree, nil
28+
}
29+
if indexSearchable, ok := table.(sql.IndexSearchableTable); ok && indexSearchable.SkipIndexCosting() {
30+
return n, transform.SameTree, nil
31+
}
32+
33+
tableAliases, err := getTableAliases(sortNode, scope)
34+
if err != nil {
35+
return n, transform.SameTree, nil
36+
}
37+
38+
var idx sql.Index
39+
idxs, err := idxTbl.GetIndexes(ctx)
40+
if err != nil {
41+
return nil, transform.SameTree, err
42+
}
43+
sfExprs := normalizeExpressions(tableAliases, sortNode.GetSortFields().ToExpressions()...)
44+
sfAliases := aliasedExpressionsInNode(sortNode)
45+
46+
// TODO: Instead of checking both sides of the expression,
47+
// use a previous pass to normalize distance functions so
48+
// that the literal is always on the same side.
49+
if len(sfExprs) != 1 {
50+
return n, transform.SameTree, nil
51+
}
52+
distance, isDistance := sfExprs[0].(*expression.Distance)
53+
if !isDistance {
54+
return n, transform.SameTree, nil
55+
}
56+
var column sql.Expression
57+
_, leftIsLiteral := distance.LeftChild.(*expression.Literal)
58+
if leftIsLiteral {
59+
column = distance.RightChild
60+
} else {
61+
_, rightIsLiteral := distance.RightChild.(*expression.Literal)
62+
if rightIsLiteral {
63+
column = distance.LeftChild
64+
} else {
65+
return n, transform.SameTree, nil
66+
}
67+
}
68+
69+
for _, idxCandidate := range idxs {
70+
if idxCandidate.IsSpatial() {
71+
continue
72+
}
73+
if !idxCandidate.CanSupportOrderBy(distance) {
74+
continue
75+
}
76+
if isSortFieldsValidPrefix([]sql.Expression{column}, sfAliases, idxCandidate.Expressions()) {
77+
idx = idxCandidate
78+
break
79+
}
80+
}
81+
if idx == nil {
82+
return n, transform.SameTree, nil
83+
}
84+
85+
var limit sql.Expression
86+
if topn, ok := sortNode.(*plan.TopN); ok {
87+
limit = topn.Limit
88+
}
89+
90+
lookup := sql.IndexLookup{
91+
Index: idx,
92+
Ranges: sql.MySQLRangeCollection{},
93+
VectorOrderAndLimit: sql.OrderAndLimit{
94+
OrderBy: distance,
95+
Limit: limit,
96+
},
97+
}
98+
nn, err := plan.NewStaticIndexedAccessForTableNode(n, lookup)
99+
if err != nil {
100+
return nil, transform.SameTree, err
101+
}
102+
return nn, transform.NewTree, err
103+
}
104+
105+
allSame := transform.SameTree
106+
newChildren := make([]sql.Node, len(node.Children()))
107+
for i, child := range node.Children() {
108+
var err error
109+
same := transform.SameTree
110+
switch c := child.(type) {
111+
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
112+
newChildren[i], same, err = replaceIdxOrderByDistanceHelper(ctx, scope, child, sortNode)
113+
default:
114+
newChildren[i] = c
115+
}
116+
if err != nil {
117+
return nil, transform.SameTree, err
118+
}
119+
allSame = allSame && same
120+
}
121+
122+
if allSame {
123+
return node, transform.SameTree, nil
124+
}
125+
126+
// if sort node was replaced with indexed access, drop sort node
127+
if node == sortNode {
128+
return newChildren[0], transform.NewTree, nil
129+
}
130+
131+
newNode, err := node.WithChildren(newChildren...)
132+
if err != nil {
133+
return nil, transform.SameTree, err
134+
}
135+
return newNode, transform.NewTree, nil
136+
}

sql/analyzer/rule_ids.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ const (
111111
replaceAggId // replaceAgg
112112
replaceIdxSortId // replaceIdxSort
113113
insertTopNId // insertTopN
114+
replaceIdxOrderByDistanceId // replaceIdxOrderByDistance
114115
applyHashInId // applyHashIn
115116
resolveInsertRowsId // resolveInsertRows
116117
resolvePreparedInsertId // resolvePreparedInsert

0 commit comments

Comments
 (0)