Skip to content

Commit 7a21abe

Browse files
committed
Add initial vector type implementation and tests.
1 parent 628730f commit 7a21abe

File tree

11 files changed

+726
-2
lines changed

11 files changed

+726
-2
lines changed

enginetest/enginetests.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6128,6 +6128,15 @@ func TestVectorFunctions(t *testing.T, h Harness) {
61286128
for _, tt := range queries.VectorFunctionQueries {
61296129
TestScript(t, h, tt)
61306130
}
6131+
for _, testCase := range queries.VectorFunctionTestCases {
6132+
TestScript(t, h, queries.MakeVectorFunctionTest(testCase))
6133+
}
6134+
}
6135+
6136+
func TestVectorType(t *testing.T, h Harness) {
6137+
for _, tt := range queries.VectorDDLQueries {
6138+
TestScript(t, h, tt)
6139+
}
61316140
}
61326141

61336142
func TestIndexPrefix(t *testing.T, h Harness) {

enginetest/memory_engine_test.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -922,6 +922,10 @@ func TestVectorFunctions(t *testing.T) {
922922
enginetest.TestVectorFunctions(t, enginetest.NewDefaultMemoryHarness())
923923
}
924924

925+
func TestVectorType(t *testing.T) {
926+
enginetest.TestVectorType(t, enginetest.NewDefaultMemoryHarness())
927+
}
928+
925929
func TestIndexPrefix(t *testing.T) {
926930
enginetest.TestIndexPrefix(t, enginetest.NewDefaultMemoryHarness())
927931
}

enginetest/queries/create_table_queries.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,6 +318,12 @@ var CreateTableQueries = []WriteQueryTest{
318318
SelectQuery: `SHOW CREATE TABLE t1`,
319319
ExpectedSelect: []sql.Row{{"t1", "CREATE TABLE `t1` (\n `pk` bit(2) DEFAULT b'10'\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
320320
},
321+
{
	// Creates a table with two VECTOR columns of different dimensions and
	// checks how SHOW CREATE TABLE reports them.
	WriteQuery:          "CREATE TABLE embeddings (id INT, vector_col VECTOR(128) NOT NULL, small_vec VECTOR(1))",
	ExpectedWriteResult: []sql.Row{{types.NewOkResult(0)}},
	SelectQuery:         "SHOW CREATE TABLE embeddings",
	ExpectedSelect: []sql.Row{{
		"embeddings",
		"CREATE TABLE `embeddings` (\n" +
			" `id` int,\n" +
			" `vector_col` VECTOR(128) NOT NULL,\n" +
			" `small_vec` VECTOR(1)\n" +
			") ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin",
	}},
},
321327
}
322328

323329
var CreateTableScriptTests = []ScriptTest{

enginetest/queries/type_wire_queries.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -816,4 +816,21 @@ var TypeWireTests = []TypeWireTest{
816816
{{"1", "[[\"a\",1]]"}, {"2", "[{\"key1\":\"value1\",\"key2\":\"value2\"}]"}},
817817
},
818818
},
819+
{
	// Wire-level checks for VECTOR columns: two rows are inserted and then
	// read back through three SELECT shapes (all columns, vector columns
	// only, and a single filtered row).
	Name: "VECTOR",
	SetUpScript: []string{
		`CREATE TABLE test (pk INT PRIMARY KEY, v1 VECTOR(2), v2 VECTOR(3));`,
		// One conversion call per vector column, and each row tuple is closed
		// before the next one starts. Row 1 goes through STRING_TO_VECTOR,
		// row 2 relies on assigning the string literal directly.
		`INSERT INTO test VALUES (1, STRING_TO_VECTOR('[1.0, 2.0]'), STRING_TO_VECTOR('[1.0, 2.0, 3.0]')), (2, '[4.0, 5.0]', '[4.0, 5.0, 6.0]');`,
	},
	Queries: []string{
		`SELECT * FROM test ORDER BY pk;`,
		`SELECT v1, v2 FROM test ORDER BY pk;`,
		`SELECT pk, v1 FROM test WHERE pk = 1;`,
	},
	// Results[i] is the expected output of Queries[i].
	Results: [][]sql.Row{
		{{"1", "[1,2]", "[1,2,3]"}, {"2", "[4,5]", "[4,5,6]"}},
		{{"[1,2]", "[1,2,3]"}, {"[4,5]", "[4,5,6]"}},
		{{"1", "[1,2]"}},
	},
},
819836
}
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
// Copyright 2022 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package queries
16+
17+
import (
18+
"github.com/dolthub/go-mysql-server/sql"
19+
"github.com/dolthub/go-mysql-server/sql/plan"
20+
"github.com/dolthub/go-mysql-server/sql/types"
21+
)
22+
23+
// VectorDDLQueries tests VECTOR type creation, insertion, and querying
24+
var VectorDDLQueries = []ScriptTest{
25+
{
26+
Name: "basic VECTOR type creation and manipulation",
27+
SetUpScript: []string{
28+
`CREATE TABLE test_vectors (
29+
id INT PRIMARY KEY,
30+
small_vec VECTOR(2),
31+
medium_vec VECTOR(10),
32+
large_vec VECTOR(1000)
33+
)`,
34+
`INSERT INTO test_vectors VALUES
35+
(1, STRING_TO_VECTOR('[1.0, 2.0]'), STRING_TO_VECTOR('[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]'), NULL),
36+
(2, STRING_TO_VECTOR('[3.5, 4.5]'), NULL, NULL
37+
)`,
38+
},
39+
Assertions: []ScriptTestAssertion{
40+
{
41+
Query: `SHOW CREATE TABLE vectors`,
42+
Expected: []sql.Row{{"vectors", "CREATE TABLE `vectors` (\n `id` int NOT NULL,\n `small_vec` VECTOR(2),\n `medium_vec` VECTOR(10),\n `large_vec` VECTOR(1000), \n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
43+
},
44+
{
45+
Query: `SELECT id, small_vec, medium_vec FROM vectors WHERE id = 2`,
46+
Expected: []sql.Row{{2, []float32{3.5, 4.5}, nil}},
47+
},
48+
{
49+
Query: `SELECT id, small_vec, medium_vec FROM vectors WHERE id = 1`,
50+
Expected: []sql.Row{{1, []float32{1.0, 2.0}, []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}}},
51+
},
52+
{
53+
Query: `SELECT id, small_vec FROM test_vectors ORDER BY id`,
54+
Expected: []sql.Row{
55+
{1, []float32{1.0, 2.0}},
56+
{2, []float32{3.5, 4.5}},
57+
},
58+
},
59+
{
60+
Query: `UPDATE test_vectors SET small_vec = '[10.0, 20.0]' WHERE id = 1`,
61+
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}},
62+
},
63+
{
64+
Query: `SELECT small_vec FROM test_vectors WHERE id = 1`,
65+
Expected: []sql.Row{{[]float32{10.0, 20.0}}},
66+
},
67+
{
68+
Query: `INSERT INTO test_vectors VALUES
69+
(3, 0x00F0803F00004040, NULL, NULL)`,
70+
Expected: []sql.Row{{types.NewOkResult(1)}},
71+
},
72+
{
73+
Query: `SELECT small_vec FROM test_vectors WHERE id = 3`,
74+
Expected: []sql.Row{{[]float32{1.0, 2.0}}},
75+
},
76+
},
77+
},
78+
{
79+
Name: "VECTOR type error conditions",
80+
SetUpScript: []string{
81+
`CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(3))`,
82+
},
83+
Assertions: []ScriptTestAssertion{
84+
{
85+
Query: `INSERT INTO error_vectors VALUES (1, '[1.0, 2.0]')`,
86+
ExpectedErrStr: "",
87+
},
88+
{
89+
Query: `INSERT INTO error_vectors VALUES (1, STRING_TO_VECTOR('[1.0, 2.0]'))`,
90+
ExpectedErrStr: "VECTOR dimension mismatch: expected 3, got 2",
91+
},
92+
{
93+
Query: `INSERT INTO error_vectors VALUES (2, STRING_TO_VECTOR('[1.0, 2.0, 3.0, 4.0]'))`,
94+
ExpectedErrStr: "VECTOR dimension mismatch: expected 3, got 4",
95+
},
96+
{
97+
Query: `INSERT INTO error_vectors VALUES (3, STRING_TO_VECTOR('[1.0, invalid, 3.0]'))`,
98+
ExpectedErrStr: "invalid VECTOR JSON format: invalid character 'i' looking for beginning of value",
99+
},
100+
{
101+
Query: `INSERT INTO error_vectors VALUES (4, STRING_TO_VECTOR('invalid_json'))`,
102+
ExpectedErrStr: "VECTOR must be in JSON array format",
103+
},
104+
{
105+
Query: `INSERT INTO error_vectors VALUES (5, STRING_TO_VECTOR('[1.0, "not an array"]'))`,
106+
ExpectedErrStr: "VECTOR must be in JSON array format",
107+
},
108+
{
109+
Query: `INSERT INTO error_vectors VALUES (5, STRING_TO_VECTOR('"not an array"'))`,
110+
ExpectedErrStr: "VECTOR must be in JSON array format",
111+
},
112+
{
113+
Query: `CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(-3))`,
114+
ExpectedErr: sql.ErrInvalidColTypeDefinition,
115+
},
116+
{
117+
Query: `CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(0))`,
118+
ExpectedErr: sql.ErrInvalidColTypeDefinition,
119+
},
120+
{
121+
Query: `CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(17000))`,
122+
ExpectedErr: sql.ErrInvalidColTypeDefinition,
123+
},
124+
},
125+
},
126+
{
127+
Name: "VECTOR type with different data formats",
128+
SetUpScript: []string{
129+
`CREATE TABLE format_vectors (id INT PRIMARY KEY, vec2 VECTOR(2))`,
130+
},
131+
Assertions: []ScriptTestAssertion{
132+
{
133+
Query: `INSERT INTO format_vectors VALUES
134+
(1, STRING_TO_VECTOR('[1.0, 2.0]')),
135+
(2, STRING_TO_VECTOR('[3, 4]')),
136+
(3, STRING_TO_VECTOR('[55e-1, 67e2]'))`,
137+
Expected: []sql.Row{{types.NewOkResult(3)}},
138+
},
139+
{
140+
Query: `SELECT id, vec2 FROM format_vectors ORDER BY id`,
141+
Expected: []sql.Row{
142+
{1, []float32{1.0, 2.0}},
143+
{2, []float32{3.0, 4.0}},
144+
{3, []float32{5.5, 6700}},
145+
},
146+
},
147+
},
148+
},
149+
}

0 commit comments

Comments
 (0)