Skip to content

Commit c688058

Browse files
committed
Merge branch 'main' into zachmu/enginetests5
2 parents 8da487f + 81b13e8 commit c688058

24 files changed

+461
-130
lines changed

enginetest/enginetests.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5870,6 +5870,18 @@ func TestIndexes(t *testing.T, h Harness) {
58705870
}
58715871
}
58725872

5873+
func TestVectorIndexes(t *testing.T, h Harness) {
5874+
for _, tt := range queries.VectorIndexQueries {
5875+
TestScript(t, h, tt)
5876+
}
5877+
}
5878+
5879+
func TestVectorFunctions(t *testing.T, h Harness) {
5880+
for _, tt := range queries.VectorFunctionQueries {
5881+
TestScript(t, h, tt)
5882+
}
5883+
}
5884+
58735885
func TestIndexPrefix(t *testing.T, h Harness) {
58745886
for _, tt := range queries.IndexPrefixQueries {
58755887
TestScript(t, h, tt)

enginetest/memory_engine_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -890,6 +890,14 @@ func TestIndexes(t *testing.T) {
890890
enginetest.TestIndexes(t, enginetest.NewDefaultMemoryHarness())
891891
}
892892

893+
func TestVectorIndexes(t *testing.T) {
894+
enginetest.TestVectorIndexes(t, enginetest.NewDefaultMemoryHarness())
895+
}
896+
897+
func TestVectorFunctions(t *testing.T) {
898+
enginetest.TestVectorFunctions(t, enginetest.NewDefaultMemoryHarness())
899+
}
900+
893901
func TestIndexPrefix(t *testing.T) {
894902
enginetest.TestIndexPrefix(t, enginetest.NewDefaultMemoryHarness())
895903
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/types"
20+
)
21+
22+
var VectorFunctionQueries = []ScriptTest{
23+
{
24+
Name: "basic usage of VEC_DISTANCE without index",
25+
SetUpScript: []string{
26+
"create table vectors (id int primary key, v json);",
27+
`insert into vectors values (1, '[3.0,4.0]'), (2, '[0.0,0.0]'), (3, '[1.0,-1.0]'), (4, '[-2.0,0.0]');`,
28+
},
29+
Assertions: []ScriptTestAssertion{
30+
{
31+
Query: "select VEC_DISTANCE('[10.0]', '[20.0]');",
32+
Expected: []sql.Row{{100.0}},
33+
},
34+
{
35+
Query: "select VEC_DISTANCE_L2_SQUARED('[1.0, 2.0]', '[5.0, 5.0]');",
36+
Expected: []sql.Row{{25.0}},
37+
},
38+
{
39+
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
40+
Expected: []sql.Row{
41+
{2, types.MustJSON(`[0.0, 0.0]`)},
42+
{3, types.MustJSON(`[1.0, -1.0]`)},
43+
{4, types.MustJSON(`[-2.0, 0.0]`)},
44+
{1, types.MustJSON(`[3.0, 4.0]`)},
45+
},
46+
},
47+
{
48+
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v)",
49+
Expected: []sql.Row{
50+
{4, types.MustJSON(`[-2.0, 0.0]`)},
51+
{2, types.MustJSON(`[0.0, 0.0]`)},
52+
{3, types.MustJSON(`[1.0, -1.0]`)},
53+
{1, types.MustJSON(`[3.0, 4.0]`)},
54+
},
55+
},
56+
},
57+
},
58+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/types"
20+
)
21+
22+
var VectorIndexQueries = []ScriptTest{
23+
{
24+
Name: "basic vector index",
25+
SetUpScript: []string{
26+
"create table vectors (id int primary key, v json);",
27+
`insert into vectors values (1, '[4.0,3.0]'), (2, '[0.0,0.0]'), (3, '[-1.0,1.0]'), (4, '[0.0,-2.0]');`,
28+
`create vector index v_idx on vectors(v);`,
29+
},
30+
Assertions: []ScriptTestAssertion{
31+
{
32+
Query: "show create table vectors",
33+
Expected: []sql.Row{
34+
{"vectors", "CREATE TABLE `vectors` (\n `id` int NOT NULL,\n `v` json,\n PRIMARY KEY (`id`),\n VECTOR KEY `v_idx` (`v`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
35+
},
36+
},
37+
{
38+
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v) limit 4",
39+
Expected: []sql.Row{
40+
{2, types.MustJSON(`[0.0, 0.0]`)},
41+
{3, types.MustJSON(`[-1.0, 1.0]`)},
42+
{4, types.MustJSON(`[0.0, -2.0]`)},
43+
{1, types.MustJSON(`[4.0, 3.0]`)},
44+
},
45+
ExpectedIndexes: []string{"v_idx"},
46+
},
47+
{
48+
// Only queries with a limit can use a vector index.
49+
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
50+
Expected: []sql.Row{
51+
{2, types.MustJSON(`[0.0, 0.0]`)},
52+
{3, types.MustJSON(`[-1.0, 1.0]`)},
53+
{4, types.MustJSON(`[0.0, -2.0]`)},
54+
{1, types.MustJSON(`[4.0, 3.0]`)},
55+
},
56+
ExpectedIndexes: nil,
57+
},
58+
{
59+
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[0.0,-2.0]', v) limit 4",
60+
Expected: []sql.Row{
61+
{4, types.MustJSON(`[0.0, -2.0]`)},
62+
{2, types.MustJSON(`[0.0, 0.0]`)},
63+
{3, types.MustJSON(`[-1.0, 1.0]`)},
64+
{1, types.MustJSON(`[4.0, 3.0]`)},
65+
},
66+
ExpectedIndexes: []string{"v_idx"},
67+
},
68+
{
69+
// Ensure vector index is not used for range lookups.
70+
Query: "select * from vectors order by v limit 4",
71+
Expected: []sql.Row{
72+
{3, types.MustJSON(`[-1.0, 1.0]`)},
73+
{4, types.MustJSON(`[0.0, -2.0]`)},
74+
{2, types.MustJSON(`[0.0, 0.0]`)},
75+
{1, types.MustJSON(`[4.0, 3.0]`)},
76+
},
77+
ExpectedIndexes: []string{},
78+
},
79+
},
80+
},
81+
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ require (
66
github.com/dolthub/go-icu-regex v0.0.0-20240916130659-0118adc6b662
77
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71
88
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81
9-
github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683
9+
github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9
1010
github.com/go-kit/kit v0.10.0
1111
github.com/go-sql-driver/mysql v1.7.2-0.20231213112541-0004702b931d
1212
github.com/gocraft/dbr/v2 v2.7.2

go.sum

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,8 @@ github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71 h1:bMGS25NWAGTE
5858
github.com/dolthub/jsonpath v0.0.2-0.20240227200619-19675ab05c71/go.mod h1:2/2zjLQ/JOOSbbSboojeg+cAwcRV0fDLzIiWch/lhqI=
5959
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81 h1:7/v8q9XGFa6q5Ap4Z/OhNkAMBaK5YeuEzwJt+NZdhiE=
6060
github.com/dolthub/sqllogictest/go v0.0.0-20201107003712-816f3ae12d81/go.mod h1:siLfyv2c92W1eN/R4QqG/+RjjX5W2+gCTRjZxBjI3TY=
61-
github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683 h1:2/RJeUfNAXS7mbBnEr9C36htiCJHk5XldDPzhxtEsME=
62-
github.com/dolthub/vitess v0.0.0-20241104125316-860772ba6683/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
61+
github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9 h1:s36zDuLPuZRWC0nBCJs2Z8joP19eKEtcsIsuE8K9Kx0=
62+
github.com/dolthub/vitess v0.0.0-20241111235433-a20a5ab9d7c9/go.mod h1:uBvlRluuL+SbEWTCZ68o0xvsdYZER3CEG/35INdzfJM=
6363
github.com/dustin/go-humanize v0.0.0-20171111073723-bb3d318650d4/go.mod h1:HtrtbFcZ19U5GC7JDqmcUSB87Iq5E25KnS6fMYU6eOk=
6464
github.com/eapache/go-resiliency v1.1.0/go.mod h1:kFI+JgMyC7bLPUVY133qvEBtVayf5mFgVsvEsIPBvNs=
6565
github.com/eapache/go-xerial-snappy v0.0.0-20180814174437-776d5712da21/go.mod h1:+020luEh2TKB4/GOp8oxxtq0Daoen/Cii55CzbTV6DU=

memory/index.go

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import (
2020

2121
"github.com/dolthub/go-mysql-server/sql"
2222
"github.com/dolthub/go-mysql-server/sql/expression"
23+
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
2324
"github.com/dolthub/go-mysql-server/sql/fulltext"
2425
"github.com/dolthub/go-mysql-server/sql/types"
2526
)
@@ -38,7 +39,7 @@ type Index struct {
3839
Fulltext bool
3940
// If SupportedVectorFunction is non-nil, this index can be used to optimize ORDER BY
4041
// expressions on this type of distance function.
41-
SupportedVectorFunction expression.DistanceType
42+
SupportedVectorFunction vector.DistanceType
4243
CommentStr string
4344
PrefixLens []uint16
4445
fulltextInfo
@@ -121,11 +122,15 @@ func (idx *Index) IsFullText() bool {
121122
return idx.Fulltext
122123
}
123124

125+
func (idx *Index) IsVector() bool {
126+
return idx.SupportedVectorFunction != nil
127+
}
128+
124129
func (idx *Index) CanSupportOrderBy(expr sql.Expression) bool {
125130
if idx.SupportedVectorFunction == nil {
126131
return false
127132
}
128-
dist, isDist := expr.(*expression.Distance)
133+
dist, isDist := expr.(*vector.Distance)
129134
return isDist && idx.SupportedVectorFunction.CanEval(dist.DistanceType)
130135
}
131136

memory/table.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ import (
2929
"github.com/dolthub/go-mysql-server/sql"
3030
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
3131
"github.com/dolthub/go-mysql-server/sql/expression"
32+
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
3233
"github.com/dolthub/go-mysql-server/sql/fulltext"
3334
"github.com/dolthub/go-mysql-server/sql/iters"
3435
"github.com/dolthub/go-mysql-server/sql/transform"
@@ -1597,7 +1598,6 @@ func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup
15971598

15981599
if lookup.VectorOrderAndLimit.OrderBy != nil {
15991600
return &vectorPartitionIter{
1600-
Column: lookup.Index.(*Index).Exprs[0],
16011601
OrderAndLimit: lookup.VectorOrderAndLimit,
16021602
}, nil
16031603
}
@@ -2011,6 +2011,11 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
20112011
}
20122012
}
20132013

2014+
var vectorFunction vector.DistanceType
2015+
if constraint == sql.IndexConstraint_Vector {
2016+
vectorFunction = vector.DistanceL2Squared{}
2017+
}
2018+
20142019
return &Index{
20152020
DB: t.dbName(),
20162021
DriverName: "",
@@ -2021,7 +2026,7 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
20212026
Unique: constraint == sql.IndexConstraint_Unique,
20222027
Spatial: constraint == sql.IndexConstraint_Spatial,
20232028
Fulltext: constraint == sql.IndexConstraint_Fulltext,
2024-
SupportedVectorFunction: nil,
2029+
SupportedVectorFunction: vectorFunction,
20252030
CommentStr: comment,
20262031
PrefixLens: prefixLengths,
20272032
}, nil
@@ -2121,7 +2126,7 @@ func (t *Table) CreateFulltextIndex(ctx *sql.Context, indexDef sql.IndexDef, key
21212126
return nil
21222127
}
21232128

2124-
func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType expression.DistanceType) error {
2129+
func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType vector.DistanceType) error {
21252130
if len(idx.Columns) > 1 {
21262131
return fmt.Errorf("vector indexes must have exactly one column")
21272132
}

sql/analyzer/index_analyzer_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ func (i *dummyIdx) Table() string { return i.table }
148148
func (i *dummyIdx) IsUnique() bool { return false }
149149
func (i *dummyIdx) IsSpatial() bool { return false }
150150
func (i *dummyIdx) IsFullText() bool { return false }
151+
func (i *dummyIdx) IsVector() bool { return false }
151152
func (i *dummyIdx) Comment() string { return "" }
152153
func (i *dummyIdx) IsGenerated() bool { return false }
153154
func (i *dummyIdx) CanSupportOrderBy(sql.Expression) bool { return false }

sql/analyzer/replace_order_by_distance.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package analyzer
33
import (
44
"github.com/dolthub/go-mysql-server/sql"
55
"github.com/dolthub/go-mysql-server/sql/expression"
6+
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
67
"github.com/dolthub/go-mysql-server/sql/plan"
78
"github.com/dolthub/go-mysql-server/sql/transform"
89
)
@@ -12,9 +13,9 @@ func replaceIdxOrderByDistance(ctx *sql.Context, a *Analyzer, n sql.Node, scope
1213
return replaceIdxOrderByDistanceHelper(ctx, scope, n, nil)
1314
}
1415

15-
func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode plan.Sortable) (sql.Node, transform.TreeIdentity, error) {
16+
func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode *plan.TopN) (sql.Node, transform.TreeIdentity, error) {
1617
switch n := node.(type) {
17-
case plan.Sortable:
18+
case *plan.TopN:
1819
sortNode = n // lowest parent sort node
1920
case *plan.ResolvedTable:
2021
if sortNode == nil {
@@ -40,6 +41,11 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
4041
if err != nil {
4142
return nil, transform.SameTree, err
4243
}
44+
45+
// Column references have not been assigned their final indexes yet, so do that for the ORDER BY expression now.
46+
// We can safely do this because an expression that references other tables won't pass `isSortFieldsValidPrefix` below.
47+
sortNode = offsetAssignIndexes(sortNode).(*plan.TopN)
48+
4349
sfExprs := normalizeExpressions(tableAliases, sortNode.GetSortFields().ToExpressions()...)
4450
sfAliases := aliasedExpressionsInNode(sortNode)
4551

@@ -49,18 +55,21 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
4955
if len(sfExprs) != 1 {
5056
return n, transform.SameTree, nil
5157
}
52-
distance, isDistance := sfExprs[0].(*expression.Distance)
58+
distance, isDistance := sfExprs[0].(*vector.Distance)
5359
if !isDistance {
5460
return n, transform.SameTree, nil
5561
}
5662
var column sql.Expression
63+
var literal sql.Expression
5764
_, leftIsLiteral := distance.LeftChild.(*expression.Literal)
5865
if leftIsLiteral {
5966
column = distance.RightChild
67+
literal = distance.LeftChild
6068
} else {
6169
_, rightIsLiteral := distance.RightChild.(*expression.Literal)
6270
if rightIsLiteral {
6371
column = distance.LeftChild
72+
literal = distance.RightChild
6473
} else {
6574
return n, transform.SameTree, nil
6675
}
@@ -82,17 +91,15 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
8291
return n, transform.SameTree, nil
8392
}
8493

85-
var limit sql.Expression
86-
if topn, ok := sortNode.(*plan.TopN); ok {
87-
limit = topn.Limit
88-
}
94+
limit := sortNode.Limit
8995

9096
lookup := sql.IndexLookup{
9197
Index: idx,
9298
Ranges: sql.MySQLRangeCollection{},
9399
VectorOrderAndLimit: sql.OrderAndLimit{
94100
OrderBy: distance,
95101
Limit: limit,
102+
Literal: literal,
96103
},
97104
}
98105
nn, err := plan.NewStaticIndexedAccessForTableNode(n, lookup)
@@ -108,7 +115,7 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
108115
var err error
109116
same := transform.SameTree
110117
switch c := child.(type) {
111-
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
118+
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.TopN, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
112119
newChildren[i], same, err = replaceIdxOrderByDistanceHelper(ctx, scope, child, sortNode)
113120
default:
114121
newChildren[i] = c

0 commit comments

Comments
 (0)