Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions enginetest/enginetests.go
Original file line number Diff line number Diff line change
Expand Up @@ -5871,6 +5871,18 @@ func TestIndexes(t *testing.T, h Harness) {
}
}

func TestVectorIndexes(t *testing.T, h Harness) {
for _, tt := range queries.VectorIndexQueries {
TestScript(t, h, tt)
}
}

func TestVectorFunctions(t *testing.T, h Harness) {
for _, tt := range queries.VectorFunctionQueries {
TestScript(t, h, tt)
}
}

func TestIndexPrefix(t *testing.T, h Harness) {
for _, tt := range queries.IndexPrefixQueries {
TestScript(t, h, tt)
Expand Down
8 changes: 8 additions & 0 deletions enginetest/memory_engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -890,6 +890,14 @@ func TestIndexes(t *testing.T) {
enginetest.TestIndexes(t, enginetest.NewDefaultMemoryHarness())
}

func TestVectorIndexes(t *testing.T) {
enginetest.TestVectorIndexes(t, enginetest.NewDefaultMemoryHarness())
}

func TestVectorFunctions(t *testing.T) {
enginetest.TestVectorFunctions(t, enginetest.NewDefaultMemoryHarness())
}

func TestIndexPrefix(t *testing.T) {
enginetest.TestIndexPrefix(t, enginetest.NewDefaultMemoryHarness())
}
Expand Down
58 changes: 58 additions & 0 deletions enginetest/queries/vector_function_queries.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package queries

import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

var VectorFunctionQueries = []ScriptTest{
{
Name: "basic usage of VEC_DISTANCE without index",
SetUpScript: []string{
"create table vectors (id int primary key, v json);",
`insert into vectors values (1, '[3.0,4.0]'), (2, '[0.0,0.0]'), (3, '[1.0,-1.0]'), (4, '[-2.0,0.0]');`,
},
Assertions: []ScriptTestAssertion{
{
Query: "select VEC_DISTANCE('[10.0]', '[20.0]');",
Expected: []sql.Row{{100.0}},
},
{
Query: "select VEC_DISTANCE_L2_SQUARED('[1.0, 2.0]', '[5.0, 5.0]');",
Expected: []sql.Row{{25.0}},
},
{
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
Expected: []sql.Row{
{2, types.MustJSON(`[0.0, 0.0]`)},
{3, types.MustJSON(`[1.0, -1.0]`)},
{4, types.MustJSON(`[-2.0, 0.0]`)},
{1, types.MustJSON(`[3.0, 4.0]`)},
},
},
{
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v)",
Expected: []sql.Row{
{4, types.MustJSON(`[-2.0, 0.0]`)},
{2, types.MustJSON(`[0.0, 0.0]`)},
{3, types.MustJSON(`[1.0, -1.0]`)},
{1, types.MustJSON(`[3.0, 4.0]`)},
},
},
},
},
}
81 changes: 81 additions & 0 deletions enginetest/queries/vector_index_queries.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package queries

import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/types"
)

var VectorIndexQueries = []ScriptTest{
{
Name: "basic vector index",
SetUpScript: []string{
"create table vectors (id int primary key, v json);",
`insert into vectors values (1, '[4.0,3.0]'), (2, '[0.0,0.0]'), (3, '[-1.0,1.0]'), (4, '[0.0,-2.0]');`,
`create vector index v_idx on vectors(v);`,
},
Assertions: []ScriptTestAssertion{
{
Query: "show create table vectors",
Expected: []sql.Row{
{"vectors", "CREATE TABLE `vectors` (\n `id` int NOT NULL,\n `v` json,\n PRIMARY KEY (`id`),\n VECTOR KEY `v_idx` (`v`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
},
},
{
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v) limit 4",
Expected: []sql.Row{
{2, types.MustJSON(`[0.0, 0.0]`)},
{3, types.MustJSON(`[-1.0, 1.0]`)},
{4, types.MustJSON(`[0.0, -2.0]`)},
{1, types.MustJSON(`[4.0, 3.0]`)},
},
ExpectedIndexes: []string{"v_idx"},
},
{
// Only queries with a limit can use a vector index.
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
Expected: []sql.Row{
{2, types.MustJSON(`[0.0, 0.0]`)},
{3, types.MustJSON(`[-1.0, 1.0]`)},
{4, types.MustJSON(`[0.0, -2.0]`)},
{1, types.MustJSON(`[4.0, 3.0]`)},
},
ExpectedIndexes: nil,
},
{
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[0.0,-2.0]', v) limit 4",
Expected: []sql.Row{
{4, types.MustJSON(`[0.0, -2.0]`)},
{2, types.MustJSON(`[0.0, 0.0]`)},
{3, types.MustJSON(`[-1.0, 1.0]`)},
{1, types.MustJSON(`[4.0, 3.0]`)},
},
ExpectedIndexes: []string{"v_idx"},
},
{
// Ensure vector index is not used for range lookups.
Query: "select * from vectors order by v limit 4",
Expected: []sql.Row{
{3, types.MustJSON(`[-1.0, 1.0]`)},
{4, types.MustJSON(`[0.0, -2.0]`)},
{2, types.MustJSON(`[0.0, 0.0]`)},
{1, types.MustJSON(`[4.0, 3.0]`)},
},
ExpectedIndexes: []string{},
},
},
},
}
9 changes: 7 additions & 2 deletions memory/index.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (

"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
"github.com/dolthub/go-mysql-server/sql/fulltext"
"github.com/dolthub/go-mysql-server/sql/types"
)
Expand All @@ -38,7 +39,7 @@ type Index struct {
Fulltext bool
// If SupportedVectorFunction is non-nil, this index can be used to optimize ORDER BY
// expressions on this type of distance function.
SupportedVectorFunction expression.DistanceType
SupportedVectorFunction vector.DistanceType
CommentStr string
PrefixLens []uint16
fulltextInfo
Expand Down Expand Up @@ -121,11 +122,15 @@ func (idx *Index) IsFullText() bool {
return idx.Fulltext
}

func (idx *Index) IsVector() bool {
return idx.SupportedVectorFunction != nil
}

func (idx *Index) CanSupportOrderBy(expr sql.Expression) bool {
if idx.SupportedVectorFunction == nil {
return false
}
dist, isDist := expr.(*expression.Distance)
dist, isDist := expr.(*vector.Distance)
return isDist && idx.SupportedVectorFunction.CanEval(dist.DistanceType)
}

Expand Down
11 changes: 8 additions & 3 deletions memory/table.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/analyzer/analyzererrors"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
"github.com/dolthub/go-mysql-server/sql/fulltext"
"github.com/dolthub/go-mysql-server/sql/iters"
"github.com/dolthub/go-mysql-server/sql/transform"
Expand Down Expand Up @@ -1597,7 +1598,6 @@ func (t *IndexedTable) LookupPartitions(ctx *sql.Context, lookup sql.IndexLookup

if lookup.VectorOrderAndLimit.OrderBy != nil {
return &vectorPartitionIter{
Column: lookup.Index.(*Index).Exprs[0],
OrderAndLimit: lookup.VectorOrderAndLimit,
}, nil
}
Expand Down Expand Up @@ -2011,6 +2011,11 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
}
}

var vectorFunction vector.DistanceType
if constraint == sql.IndexConstraint_Vector {
vectorFunction = vector.DistanceL2Squared{}
}

return &Index{
DB: t.dbName(),
DriverName: "",
Expand All @@ -2021,7 +2026,7 @@ func (t *Table) createIndex(data *TableData, name string, columns []sql.IndexCol
Unique: constraint == sql.IndexConstraint_Unique,
Spatial: constraint == sql.IndexConstraint_Spatial,
Fulltext: constraint == sql.IndexConstraint_Fulltext,
SupportedVectorFunction: nil,
SupportedVectorFunction: vectorFunction,
CommentStr: comment,
PrefixLens: prefixLengths,
}, nil
Expand Down Expand Up @@ -2121,7 +2126,7 @@ func (t *Table) CreateFulltextIndex(ctx *sql.Context, indexDef sql.IndexDef, key
return nil
}

func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType expression.DistanceType) error {
func (t *Table) CreateVectorIndex(ctx *sql.Context, idx sql.IndexDef, distanceType vector.DistanceType) error {
if len(idx.Columns) > 1 {
return fmt.Errorf("vector indexes must have exactly one column")
}
Expand Down
1 change: 1 addition & 0 deletions sql/analyzer/index_analyzer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ func (i *dummyIdx) Table() string { return i.table }
func (i *dummyIdx) IsUnique() bool { return false }
func (i *dummyIdx) IsSpatial() bool { return false }
func (i *dummyIdx) IsFullText() bool { return false }
func (i *dummyIdx) IsVector() bool { return false }
func (i *dummyIdx) Comment() string { return "" }
func (i *dummyIdx) IsGenerated() bool { return false }
func (i *dummyIdx) CanSupportOrderBy(sql.Expression) bool { return false }
Expand Down
23 changes: 15 additions & 8 deletions sql/analyzer/replace_order_by_distance.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package analyzer
import (
"github.com/dolthub/go-mysql-server/sql"
"github.com/dolthub/go-mysql-server/sql/expression"
"github.com/dolthub/go-mysql-server/sql/expression/function/vector"
"github.com/dolthub/go-mysql-server/sql/plan"
"github.com/dolthub/go-mysql-server/sql/transform"
)
Expand All @@ -12,9 +13,9 @@ func replaceIdxOrderByDistance(ctx *sql.Context, a *Analyzer, n sql.Node, scope
return replaceIdxOrderByDistanceHelper(ctx, scope, n, nil)
}

func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode plan.Sortable) (sql.Node, transform.TreeIdentity, error) {
func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, sortNode *plan.TopN) (sql.Node, transform.TreeIdentity, error) {
switch n := node.(type) {
case plan.Sortable:
case *plan.TopN:
sortNode = n // lowest parent sort node
case *plan.ResolvedTable:
if sortNode == nil {
Expand All @@ -40,6 +41,11 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
if err != nil {
return nil, transform.SameTree, err
}

// Column references have not been assigned their final indexes yet, so do that for the ORDER BY expression now.
// We can safely do this because an expression that references other tables won't pass `isSortFieldsValidPrefix` below.
sortNode = offsetAssignIndexes(sortNode).(*plan.TopN)

sfExprs := normalizeExpressions(tableAliases, sortNode.GetSortFields().ToExpressions()...)
sfAliases := aliasedExpressionsInNode(sortNode)

Expand All @@ -49,18 +55,21 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
if len(sfExprs) != 1 {
return n, transform.SameTree, nil
}
distance, isDistance := sfExprs[0].(*expression.Distance)
distance, isDistance := sfExprs[0].(*vector.Distance)
if !isDistance {
return n, transform.SameTree, nil
}
var column sql.Expression
var literal sql.Expression
_, leftIsLiteral := distance.LeftChild.(*expression.Literal)
if leftIsLiteral {
column = distance.RightChild
literal = distance.LeftChild
} else {
_, rightIsLiteral := distance.RightChild.(*expression.Literal)
if rightIsLiteral {
column = distance.LeftChild
literal = distance.RightChild
} else {
return n, transform.SameTree, nil
}
Expand All @@ -82,17 +91,15 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
return n, transform.SameTree, nil
}

var limit sql.Expression
if topn, ok := sortNode.(*plan.TopN); ok {
limit = topn.Limit
}
limit := sortNode.Limit

lookup := sql.IndexLookup{
Index: idx,
Ranges: sql.MySQLRangeCollection{},
VectorOrderAndLimit: sql.OrderAndLimit{
OrderBy: distance,
Limit: limit,
Literal: literal,
},
}
nn, err := plan.NewStaticIndexedAccessForTableNode(n, lookup)
Expand All @@ -108,7 +115,7 @@ func replaceIdxOrderByDistanceHelper(ctx *sql.Context, scope *plan.Scope, node s
var err error
same := transform.SameTree
switch c := child.(type) {
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
case *plan.Project, *plan.TableAlias, *plan.ResolvedTable, *plan.Filter, *plan.Limit, *plan.TopN, *plan.Offset, *plan.Sort, *plan.IndexedTableAccess:
newChildren[i], same, err = replaceIdxOrderByDistanceHelper(ctx, scope, child, sortNode)
default:
newChildren[i] = c
Expand Down
4 changes: 4 additions & 0 deletions sql/analyzer/replace_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,10 @@ func replaceIdxSortHelper(ctx *sql.Context, scope *plan.Scope, node sql.Node, so
if idxCandidate.IsSpatial() {
continue
}
if idxCandidate.IsVector() {
// TODO: It's possible that we may be able to use vector indexes for point lookups, but not range lookups
continue
}
if isSortFieldsValidPrefix(sfExprs, sfAliases, idxCandidate.Expressions()) {
idx = idxCandidate
break
Expand Down
4 changes: 2 additions & 2 deletions sql/analyzer/validate_create_table.go
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ func validateModifyColumn(ctx *sql.Context, initialSch sql.Schema, schema sql.Sc
if !strings.EqualFold(col.Name, oldColName) {
continue
}
if types.IsJSON(newCol.Type) {
if types.IsJSON(newCol.Type) && !index.IsVector() {
return nil, sql.ErrJSONIndex.New(col.Name)
}
var prefixLen int64
Expand Down Expand Up @@ -883,7 +883,7 @@ func validateIndex(ctx *sql.Context, colMap map[string]*sql.Column, idxDef *sql.
return sql.ErrDuplicateColumn.New(schCol.Name)
}
seenCols[schCol.Name] = struct{}{}
if types.IsJSON(schCol.Type) {
if types.IsJSON(schCol.Type) && !idxDef.IsVector() {
return sql.ErrJSONIndex.New(schCol.Name)
}

Expand Down
Loading
Loading