Skip to content

Commit ea88909

Browse files
committed
Add integration tests for vector indexes.
1 parent fe514c9 commit ea88909

File tree

4 files changed

+142
-0
lines changed

4 files changed

+142
-0
lines changed

enginetest/enginetests.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5870,6 +5870,18 @@ func TestIndexes(t *testing.T, h Harness) {
58705870
}
58715871
}
58725872

5873+
func TestVectorIndexes(t *testing.T, h Harness) {
5874+
for _, tt := range queries.VectorIndexQueries {
5875+
TestScript(t, h, tt)
5876+
}
5877+
}
5878+
5879+
func TestVectorFunctions(t *testing.T, h Harness) {
5880+
for _, tt := range queries.VectorFunctionQueries {
5881+
TestScript(t, h, tt)
5882+
}
5883+
}
5884+
58735885
func TestIndexPrefix(t *testing.T, h Harness) {
58745886
for _, tt := range queries.IndexPrefixQueries {
58755887
TestScript(t, h, tt)

enginetest/memory_engine_test.go

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -902,6 +902,14 @@ func TestIndexes(t *testing.T) {
902902
enginetest.TestIndexes(t, enginetest.NewDefaultMemoryHarness())
903903
}
904904

905+
func TestVectorIndexes(t *testing.T) {
906+
enginetest.TestVectorIndexes(t, enginetest.NewDefaultMemoryHarness())
907+
}
908+
909+
func TestVectorFunctions(t *testing.T) {
910+
enginetest.TestVectorFunctions(t, enginetest.NewDefaultMemoryHarness())
911+
}
912+
905913
func TestIndexPrefix(t *testing.T) {
906914
enginetest.TestIndexPrefix(t, enginetest.NewDefaultMemoryHarness())
907915
}
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/types"
20+
)
21+
22+
var VectorFunctionQueries = []ScriptTest{
23+
{
24+
Name: "basic usage of VEC_DISTANCE without index",
25+
SetUpScript: []string{
26+
"create table vectors (id int primary key, v json);",
27+
`insert into vectors values (1, '[3.0,4.0]'), (2, '[0.0,0.0]'), (3, '[1.0,-1.0]'), (4, '[-2.0,0.0]');`,
28+
},
29+
Assertions: []ScriptTestAssertion{
30+
{
31+
Query: "select VEC_DISTANCE('[10.0]', '[20.0]');",
32+
Expected: []sql.Row{{100.0}},
33+
},
34+
{
35+
Query: "select VEC_DISTANCE_L2_SQUARED('[1.0, 2.0]', '[5.0, 5.0]');",
36+
Expected: []sql.Row{{25.0}},
37+
},
38+
{
39+
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
40+
Expected: []sql.Row{
41+
{2, types.MustJSON(`[0.0, 0.0]`)},
42+
{3, types.MustJSON(`[1.0, -1.0]`)},
43+
{4, types.MustJSON(`[-2.0, 0.0]`)},
44+
{1, types.MustJSON(`[3.0, 4.0]`)},
45+
},
46+
},
47+
{
48+
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v)",
49+
Expected: []sql.Row{
50+
{4, types.MustJSON(`[-2.0, 0.0]`)},
51+
{2, types.MustJSON(`[0.0, 0.0]`)},
52+
{3, types.MustJSON(`[1.0, -1.0]`)},
53+
{1, types.MustJSON(`[3.0, 4.0]`)},
54+
},
55+
},
56+
},
57+
},
58+
}
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/types"
20+
)
21+
22+
var VectorIndexQueries = []ScriptTest{
23+
{
24+
Name: "basic vector index",
25+
SetUpScript: []string{
26+
"create table vectors (id int primary key, v json);",
27+
`insert into vectors values (1, '[3.0,4.0]'), (2, '[0.0,0.0]'), (3, '[1.0,-1.0]'), (4, '[-2.0,0.0]');`,
28+
},
29+
Assertions: []ScriptTestAssertion{
30+
{
31+
Query: `create vector index v_idx on vectors(v);`,
32+
Expected: []sql.Row{
33+
{types.OkResult{RowsAffected: 0}},
34+
},
35+
},
36+
{
37+
Query: "show create table vectors",
38+
Expected: []sql.Row{
39+
{"vectors", "CREATE TABLE `vectors` (\n `id` int NOT NULL,\n `v` json,\n PRIMARY KEY (`id`),\n VECTOR KEY `v_idx` (`v`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"},
40+
},
41+
},
42+
{
43+
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
44+
Expected: []sql.Row{
45+
{2, types.MustJSON(`[0.0, 0.0]`)},
46+
{3, types.MustJSON(`[1.0, -1.0]`)},
47+
{4, types.MustJSON(`[-2.0, 0.0]`)},
48+
{1, types.MustJSON(`[3.0, 4.0]`)},
49+
},
50+
ExpectedIndexes: []string{"v_idx"},
51+
},
52+
{
53+
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v)",
54+
Expected: []sql.Row{
55+
{4, types.MustJSON(`[-2.0, 0.0]`)},
56+
{2, types.MustJSON(`[0.0, 0.0]`)},
57+
{3, types.MustJSON(`[1.0, -1.0]`)},
58+
{1, types.MustJSON(`[3.0, 4.0]`)},
59+
},
60+
ExpectedIndexes: []string{"v_idx"},
61+
},
62+
},
63+
},
64+
}

0 commit comments

Comments
 (0)