Skip to content

Commit 3bcd756

Browse files
authored
Merge pull request #2721 from dolthub/nicktobey/vector2
Additional support for vector indexes.
2 parents b8ae9a1 + 07b686d commit 3bcd756

22 files changed

+342
-62
lines changed

enginetest/enginetests.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5870,6 +5870,18 @@ func TestIndexes(t *testing.T, h Harness) {
58705870
}
58715871
}
58725872

5873+
// TestVectorIndexes runs every script in queries.VectorIndexQueries through
// TestScript using the supplied harness.
func TestVectorIndexes(t *testing.T, h Harness) {
5874+
for _, tt := range queries.VectorIndexQueries {
5875+
TestScript(t, h, tt)
5876+
}
5877+
}
5878+
5879+
// TestVectorFunctions runs every script in queries.VectorFunctionQueries
// through TestScript using the supplied harness.
func TestVectorFunctions(t *testing.T, h Harness) {
5880+
for _, tt := range queries.VectorFunctionQueries {
5881+
TestScript(t, h, tt)
5882+
}
5883+
}
5884+
58735885
func TestIndexPrefix(t *testing.T, h Harness) {
58745886
for _, tt := range queries.IndexPrefixQueries {
58755887
TestScript(t, h, tt)

enginetest/memory_engine_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,14 @@ func TestIndexes(t *testing.T) {
890890
enginetest.TestIndexes(t, enginetest.NewDefaultMemoryHarness())
891891
}
892892

893+
// TestVectorIndexes runs enginetest.TestVectorIndexes with the harness
// returned by enginetest.NewDefaultMemoryHarness.
func TestVectorIndexes(t *testing.T) {
894+
enginetest.TestVectorIndexes(t, enginetest.NewDefaultMemoryHarness())
895+
}
896+
897+
// TestVectorFunctions runs enginetest.TestVectorFunctions with the harness
// returned by enginetest.NewDefaultMemoryHarness.
func TestVectorFunctions(t *testing.T) {
898+
enginetest.TestVectorFunctions(t, enginetest.NewDefaultMemoryHarness())
899+
}
900+
893901
func TestIndexPrefix(t *testing.T) {
894902
enginetest.TestIndexPrefix(t, enginetest.NewDefaultMemoryHarness())
895903
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/types"
20+
)
21+
22+
// VectorFunctionQueries holds script tests for the VEC_DISTANCE and
// VEC_DISTANCE_L2_SQUARED functions. The setup script creates a table with a
// json column and no vector index, so the ORDER BY assertions exercise the
// functions without index support.
var VectorFunctionQueries = []ScriptTest{
23+
{
24+
Name: "basic usage of VEC_DISTANCE without index",
25+
SetUpScript: []string{
26+
"create table vectors (id int primary key, v json);",
27+
`insert into vectors values (1, '[3.0,4.0]'), (2, '[0.0,0.0]'), (3, '[1.0,-1.0]'), (4, '[-2.0,0.0]');`,
28+
},
29+
Assertions: []ScriptTestAssertion{
30+
{
31+
// The expected value equals the squared difference: (20 - 10)^2 = 100.
Query: "select VEC_DISTANCE('[10.0]', '[20.0]');",
32+
Expected: []sql.Row{{100.0}},
33+
},
34+
{
35+
// Squared L2 distance: (5 - 1)^2 + (5 - 2)^2 = 16 + 9 = 25.
Query: "select VEC_DISTANCE_L2_SQUARED('[1.0, 2.0]', '[5.0, 5.0]');",
36+
Expected: []sql.Row{{25.0}},
37+
},
38+
{
39+
// Rows are expected nearest-first relative to the origin
// (squared distances 0, 2, 4, 25 for ids 2, 3, 4, 1).
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
40+
Expected: []sql.Row{
41+
{2, types.MustJSON(`[0.0, 0.0]`)},
42+
{3, types.MustJSON(`[1.0, -1.0]`)},
43+
{4, types.MustJSON(`[-2.0, 0.0]`)},
44+
{1, types.MustJSON(`[3.0, 4.0]`)},
45+
},
46+
},
47+
{
48+
// Rows are expected nearest-first relative to [-2, 0]
// (squared distances 0, 4, 10, 41 for ids 4, 2, 3, 1).
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v)",
49+
Expected: []sql.Row{
50+
{4, types.MustJSON(`[-2.0, 0.0]`)},
51+
{2, types.MustJSON(`[0.0, 0.0]`)},
52+
{3, types.MustJSON(`[1.0, -1.0]`)},
53+
{1, types.MustJSON(`[3.0, 4.0]`)},
54+
},
55+
},
56+
},
57+
},
58+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/types"
20+
)
21+
22+
// VectorIndexQueries holds script tests for a table whose json column has a
// vector index created with CREATE VECTOR INDEX. The assertions cover the
// SHOW CREATE TABLE output and, for several ORDER BY queries, both the
// expected row order and the ExpectedIndexes value.
var VectorIndexQueries = []ScriptTest{
23+
{
24+
Name: "basic vector index",
25+
SetUpScript: []string{
26+
"create table vectors (id int primary key, v json);",
27+
`insert into vectors values (1, '[4.0,3.0]'), (2, '[0.0,0.0]'), (3, '[-1.0,1.0]'), (4, '[0.0,-2.0]');`,
28+
`create vector index v_idx on vectors(v);`,
29+
},
30+
Assertions: []ScriptTestAssertion{
31+
{
32+
// The index is expected to be rendered as a VECTOR KEY entry.
Query: "show create table vectors",
33+
Expected: []sql.Row{
34+
{"vectors", "CREATE TABLE `vectors` (\n `id` int NOT NULL,\n `v` json,\n PRIMARY KEY (`id`),\n VECTOR KEY `v_idx` (`v`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
35+
},
36+
},
37+
{
38+
// Distance ordering with a LIMIT: rows are expected nearest-first
// relative to the origin, and v_idx is the expected index.
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v) limit 4",
39+
Expected: []sql.Row{
40+
{2, types.MustJSON(`[0.0, 0.0]`)},
41+
{3, types.MustJSON(`[-1.0, 1.0]`)},
42+
{4, types.MustJSON(`[0.0, -2.0]`)},
43+
{1, types.MustJSON(`[4.0, 3.0]`)},
44+
},
45+
ExpectedIndexes: []string{"v_idx"},
46+
},
47+
{
48+
// Only queries with a limit can use a vector index.
49+
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
50+
Expected: []sql.Row{
51+
{2, types.MustJSON(`[0.0, 0.0]`)},
52+
{3, types.MustJSON(`[-1.0, 1.0]`)},
53+
{4, types.MustJSON(`[0.0, -2.0]`)},
54+
{1, types.MustJSON(`[4.0, 3.0]`)},
55+
},
56+
ExpectedIndexes: nil,
57+
},
58+
{
59+
// Same shape as the first distance query, using the explicit
// VEC_DISTANCE_L2_SQUARED function and the target [0, -2].
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[0.0,-2.0]', v) limit 4",
60+
Expected: []sql.Row{
61+
{4, types.MustJSON(`[0.0, -2.0]`)},
62+
{2, types.MustJSON(`[0.0, 0.0]`)},
63+
{3, types.MustJSON(`[-1.0, 1.0]`)},
64+
{1, types.MustJSON(`[4.0, 3.0]`)},
65+
},
66+
ExpectedIndexes: []string{"v_idx"},
67+
},
68+
{
69+
// Ensure vector index is not used for range lookups.
70+
Query: "select * from vectors order by v limit 4",
71+
Expected: []sql.Row{
72+
{3, types.MustJSON(`[-1.0, 1.0]`)},
73+
{4, types.MustJSON(`[0.0, -2.0]`)},
74+
{2, types.MustJSON(`[0.0, 0.0]`)},
75+
{1, types.MustJSON(`[4.0, 3.0]`)},
76+
},
77+
ExpectedIndexes: []string{},
78+
},
79+
},
80+
},
81+
}

memory/index.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/dolthub/go-mysql-server/sql"
2222
"github.com/dolthub/go-mysql-server/sql/expression"
23+
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
2324
"github.com/dolthub/go-mysql-server/sql/fulltext"
2425
"github.com/dolthub/go-mysql-server/sql/types"
2526
)
@@ -38,7 +39,7 @@ type Index struct {
3839
Fulltext bool
3940
// If SupportedVectorFunction is non-nil, this index can be used to optimize ORDER BY
4041
// expressions on this type of distance function.
41-
SupportedVectorFunction expression.DistanceType
42+
SupportedVectorFunction vector.DistanceType
4243
CommentStr string
4344
PrefixLens []uint16
4445
fulltextInfo
@@ -121,11 +122,15 @@ func (idx *Index) IsFullText() bool {
121122
return idx.Fulltext
122123
}
123124

125+
// IsVector reports whether this index is a vector index, which is the case
// exactly when SupportedVectorFunction is non-nil.
func (idx *Index) IsVector() bool {
126+
return idx.SupportedVectorFunction != nil
127+
}
128+
124129
// CanSupportOrderBy reports whether this index can serve an ORDER BY on expr.
// It returns false when the index has no SupportedVectorFunction. Otherwise it
// returns true only if expr is a Distance expression and the index's
// SupportedVectorFunction reports (via CanEval) that it can evaluate that
// expression's DistanceType.
func (idx *Index) CanSupportOrderBy(expr sql.Expression) bool {
125130
if idx.SupportedVectorFunction == nil {
126131
return false
127132
}
128-
dist, isDist := expr.(*expression.Distance)
133+
dist, isDist := expr.(*vector.Distance)
129134
return isDist && idx.SupportedVectorFunction.CanEval(dist.DistanceType)
130135
}
131136

memory/table.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/dolthub/go-mysql-server/sql"
3030
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
3131
"github.com/dolthub/go-mysql-server/sql/expression"
32+
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
3233
"github.com/dolthub/go-mysql-server/sql/fulltext"
3334
"github.com/dolthub/go-mysql-server/sql/iters"
3435
"github.com/dolthub/go-mysql-server/sql/transform"
@@ -1597,7 +1598,6 @@ func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup
15971598

15981599
if lookup.VectorOrderAndLimit.OrderBy != nil {
15991600
return &vectorPartitionIter{
1600-
Column: lookup.Index.(*Index).Exprs[0],
16011601
OrderAndLimit: lookup.VectorOrderAndLimit,
16021602
}, nil
16031603
}
@@ -2011,6 +2011,11 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
20112011
}
20122012
}
20132013

2014+
var vectorFunction vector.DistanceType
2015+
if constraint == sql.IndexConstraint_Vector {
2016+
vectorFunction = vector.DistanceL2Squared{}
2017+
}
2018+
20142019
return &Index{
20152020
DB: t.dbName(),
20162021
DriverName: "",
@@ -2021,7 +2026,7 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
20212026
Unique: constraint == sql.IndexConstraint_Unique,
20222027
Spatial: constraint == sql.IndexConstraint_Spatial,
20232028
Fulltext: constraint == sql.IndexConstraint_Fulltext,
2024-
SupportedVectorFunction: nil,
2029+
SupportedVectorFunction: vectorFunction,
20252030
CommentStr: comment,
20262031
PrefixLens: prefixLengths,
20272032
}, nil
@@ -2121,7 +2126,7 @@ func (t *Table) CreateFulltextIndex(ctx *sql.Context, indexDef sql.IndexDef, key
21212126
return nil
21222127
}
21232128

2124-
func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType expression.DistanceType) error {
2129+
func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType vector.DistanceType) error {
21252130
if len(idx.Columns) > 1 {
21262131
return fmt.Errorf("vector indexes must have exactly one column")
21272132
}

sql/analyzer/index_analyzer_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ func (i *dummyIdx) Table() string { return i.table }
148148
// Stub methods on dummyIdx: each one ignores its receiver state and returns a
// constant (false, or the empty string for Comment).
func (i *dummyIdx) IsUnique() bool { return false }
149149
func (i *dummyIdx) IsSpatial() bool { return false }
150150
func (i *dummyIdx) IsFullText() bool { return false }
151+
func (i *dummyIdx) IsVector() bool { return false }
151152
func (i *dummyIdx) Comment() string { return "" }
152153
func (i *dummyIdx) IsGenerated() bool { return false }
153154
func (i *dummyIdx) CanSupportOrderBy(sql.Expression) bool { return false }

sql/analyzer/replace_order_by_distance.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package analyzer
33
import (
44
"github.com/dolthub/go-mysql-server/sql"
55
"github.com/dolthub/go-mysql-server/sql/expression"
6+
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
67
"github.com/dolthub/go-mysql-server/sql/plan"
78
"github.com/dolthub/go-mysql-server/sql/transform"
89
)
@@ -12,9 +13,9 @@ func replaceIdxOrderByDistance(ctx *sql.Context, a *Analyzer, n sql.Node, scope
1213
return replaceIdxOrderByDistanceHelper(ctx, scope, n, nil)
1314
}
1415

15-
func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode plan.Sortable) (sql.Node, transform.TreeIdentity, error) {
16+
func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode *plan.TopN) (sql.Node, transform.TreeIdentity, error) {
1617
switch n := node.(type) {
17-
case plan.Sortable:
18+
case *plan.TopN:
1819
sortNode = n // lowest parent sort node
1920
case *plan.ResolvedTable:
2021
if sortNode == nil {
@@ -40,6 +41,11 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
4041
if err != nil {
4142
return nil, transform.SameTree, err
4243
}
44+
45+
// Column references have not been assigned their final indexes yet, so do that for the ORDER BY expression now.
46+
// We can safely do this because an expression that references other tables won't pass `isSortFieldsValidPrefix` below.
47+
sortNode = offsetAssignIndexes(sortNode).(*plan.TopN)
48+
4349
sfExprs := normalizeExpressions(tableAliases, sortNode.GetSortFields().ToExpressions()...)
4450
sfAliases := aliasedExpressionsInNode(sortNode)
4551

@@ -49,18 +55,21 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
4955
if len(sfExprs) != 1 {
5056
return n, transform.SameTree, nil
5157
}
52-
distance, isDistance := sfExprs[0].(*expression.Distance)
58+
distance, isDistance := sfExprs[0].(*vector.Distance)
5359
if !isDistance {
5460
return n, transform.SameTree, nil
5561
}
5662
var column sql.Expression
63+
var literal sql.Expression
5764
_, leftIsLiteral := distance.LeftChild.(*expression.Literal)
5865
if leftIsLiteral {
5966
column = distance.RightChild
67+
literal = distance.LeftChild
6068
} else {
6169
_, rightIsLiteral := distance.RightChild.(*expression.Literal)
6270
if rightIsLiteral {
6371
column = distance.LeftChild
72+
literal = distance.RightChild
6473
} else {
6574
return n, transform.SameTree, nil
6675
}
@@ -82,17 +91,15 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
8291
return n, transform.SameTree, nil
8392
}
8493

85-
var limit sql.Expression
86-
if topn, ok := sortNode.(*plan.TopN); ok {
87-
limit = topn.Limit
88-
}
94+
limit := sortNode.Limit
8995

9096
lookup := sql.IndexLookup{
9197
Index: idx,
9298
Ranges: sql.MySQLRangeCollection{},
9399
VectorOrderAndLimit: sql.OrderAndLimit{
94100
OrderBy: distance,
95101
Limit: limit,
102+
Literal: literal,
96103
},
97104
}
98105
nn, err := plan.NewStaticIndexedAccessForTableNode(n, lookup)
@@ -108,7 +115,7 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
108115
var err error
109116
same := transform.SameTree
110117
switch c := child.(type) {
111-
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
118+
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.TopN, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
112119
newChildren[i], same, err = replaceIdxOrderByDistanceHelper(ctx, scope, child, sortNode)
113120
default:
114121
newChildren[i] = c

sql/analyzer/replace_sort.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,10 @@ func replaceIdxSortHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, so
107107
if idxCandidate.IsSpatial() {
108108
continue
109109
}
110+
if idxCandidate.IsVector() {
111+
// TODO: It's possible that we may be able to use vector indexes for point lookups, but not range lookups
112+
continue
113+
}
110114
if isSortFieldsValidPrefix(sfExprs, sfAliases, idxCandidate.Expressions()) {
111115
idx = idxCandidate
112116
break

sql/analyzer/validate_create_table.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,7 +505,7 @@ func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Sc
505505
if !strings.EqualFold(col.Name, oldColName) {
506506
continue
507507
}
508-
if types.IsJSON(newCol.Type) {
508+
if types.IsJSON(newCol.Type) && !index.IsVector() {
509509
return nil, sql.ErrJSONIndex.New(col.Name)
510510
}
511511
var prefixLen int64
@@ -883,7 +883,7 @@ func validateIndex(ctx *sql.Context, colMap map[string]*sql.Column, idxDef *sql.
883883
return sql.ErrDuplicateColumn.New(schCol.Name)
884884
}
885885
seenCols[schCol.Name] = struct{}{}
886-
if types.IsJSON(schCol.Type) {
886+
if types.IsJSON(schCol.Type) && !idxDef.IsVector() {
887887
return sql.ErrJSONIndex.New(schCol.Name)
888888
}
889889

0 commit comments

Comments
 (0)