Skip to content

Commit e6156ef

Browse files
committed
Implement missing functions for vector types.
1 parent d8161a3 commit e6156ef

File tree

10 files changed

+375
-74
lines changed

10 files changed

+375
-74
lines changed

enginetest/queries/type_wire_queries.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ var TypeWireTests = []TypeWireTest{
820820
Name: "VECTOR",
821821
SetUpScript: []string{
822822
`CREATE TABLE test (pk INT PRIMARY KEY, v1 VECTOR(2), v2 VECTOR(3));`,
823-
`INSERT INTO test VALUES (1, VEC_FROMTEXT('[1.0, 2.0]', '[1.0, 2.0, 3.0]'), (2, '[4.0, 5.0]', '[4.0, 5.0, 6.0]'));`,
823+
`INSERT INTO test VALUES (1, VEC_FROMTEXT('[1.0, 2.0]'), VEC_FROMTEXT('[1.0, 2.0, 3.0]')), (2, VEC_FROMTEXT('[4.0, 5.0]'), VEC_FROMTEXT('[4.0, 5.0, 6.0]'));`,
824824
},
825825
Queries: []string{
826826
`SELECT * FROM test ORDER BY pk;`,

enginetest/queries/vector_ddl_queries.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,15 @@ var VectorDDLQueries = []ScriptTest{
3838
},
3939
Assertions: []ScriptTestAssertion{
4040
{
41-
Query: `SHOW CREATE TABLE vectors`,
42-
Expected: []sql.Row{{"vectors", "CREATE TABLE `vectors` (\n `id` int NOT NULL,\n `small_vec` VECTOR(2),\n `medium_vec` VECTOR(10),\n `large_vec` VECTOR(1000), \n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
41+
Query: `SHOW CREATE TABLE test_vectors`,
42+
Expected: []sql.Row{{"test_vectors", "CREATE TABLE `test_vectors` (\n `id` int NOT NULL,\n `small_vec` VECTOR(2),\n `medium_vec` VECTOR(10),\n `large_vec` VECTOR(1000),\n PRIMARY KEY (`id`)\n) ENGINE=InnoDB DEFAULT CHARSET=utf8mb4 COLLATE=utf8mb4_0900_bin"}},
4343
},
4444
{
45-
Query: `SELECT id, small_vec, medium_vec FROM vectors WHERE id = 2`,
45+
Query: `SELECT id, small_vec, medium_vec FROM test_vectors WHERE id = 2`,
4646
Expected: []sql.Row{{2, []float32{3.5, 4.5}, nil}},
4747
},
4848
{
49-
Query: `SELECT id, small_vec, medium_vec FROM vectors WHERE id = 1`,
49+
Query: `SELECT id, small_vec, medium_vec FROM test_vectors WHERE id = 1`,
5050
Expected: []sql.Row{{1, []float32{1.0, 2.0}, []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}}},
5151
},
5252
{
@@ -57,21 +57,20 @@ var VectorDDLQueries = []ScriptTest{
5757
},
5858
},
5959
{
60-
Query: `UPDATE test_vectors SET small_vec = '[10.0, 20.0]' WHERE id = 1`,
60+
Query: `UPDATE test_vectors SET small_vec = STRING_TO_VECTOR('[10.0, 20.0]') WHERE id = 1`,
6161
Expected: []sql.Row{{types.OkResult{RowsAffected: 1, Info: plan.UpdateInfo{Matched: 1, Updated: 1}}}},
6262
},
6363
{
6464
Query: `SELECT small_vec FROM test_vectors WHERE id = 1`,
6565
Expected: []sql.Row{{[]float32{10.0, 20.0}}},
6666
},
6767
{
68-
Query: `INSERT INTO test_vectors VALUES
69-
(3, 0x00F0803F00004040, NULL, NULL)`,
68+
Query: `INSERT INTO test_vectors VALUES (3, 0x0000204100002041, NULL, NULL)`, // [10.0, 10.0]
7069
Expected: []sql.Row{{types.NewOkResult(1)}},
7170
},
7271
{
7372
Query: `SELECT small_vec FROM test_vectors WHERE id = 3`,
74-
Expected: []sql.Row{{[]float32{1.0, 2.0}}},
73+
Expected: []sql.Row{{[]float32{10.0, 10.0}}},
7574
},
7675
},
7776
},
@@ -83,7 +82,7 @@ var VectorDDLQueries = []ScriptTest{
8382
Assertions: []ScriptTestAssertion{
8483
{
8584
Query: `INSERT INTO error_vectors VALUES (1, '[1.0, 2.0]')`,
86-
ExpectedErrStr: "",
85+
ExpectedErrStr: "value of type string cannot be converted to 'vector' type",
8786
},
8887
{
8988
Query: `INSERT INTO error_vectors VALUES (1, STRING_TO_VECTOR('[1.0, 2.0]'))`,
@@ -95,23 +94,23 @@ var VectorDDLQueries = []ScriptTest{
9594
},
9695
{
9796
Query: `INSERT INTO error_vectors VALUES (3, STRING_TO_VECTOR('[1.0, invalid, 3.0]'))`,
98-
ExpectedErrStr: "invalid VECTOR JSON format: invalid character 'i' looking for beginning of value",
97+
ExpectedErrStr: "can't convert JSON to vector: invalid character 'i' looking for beginning of value",
9998
},
10099
{
101100
Query: `INSERT INTO error_vectors VALUES (4, STRING_TO_VECTOR('invalid_json'))`,
102-
ExpectedErrStr: "VECTOR must be in JSON array format",
101+
ExpectedErrStr: "can't convert JSON to vector: invalid character 'i' looking for beginning of value",
103102
},
104103
{
105104
Query: `INSERT INTO error_vectors VALUES (5, STRING_TO_VECTOR('[1.0, "not an array"]'))`,
106-
ExpectedErrStr: "VECTOR must be in JSON array format",
105+
ExpectedErrStr: "can't convert JSON to vector; expected array of floats, but array contained string",
107106
},
108107
{
109108
Query: `INSERT INTO error_vectors VALUES (5, STRING_TO_VECTOR('"not an array"'))`,
110-
ExpectedErrStr: "VECTOR must be in JSON array format",
109+
ExpectedErrStr: "can't convert JSON to vector; expected array, got string",
111110
},
112111
{
113-
Query: `CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(-3))`,
114-
ExpectedErr: sql.ErrInvalidColTypeDefinition,
112+
Query: `CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(-3))`,
113+
ExpectedErrStr: "syntax error at position 62 near 'VECTOR'",
115114
},
116115
{
117116
Query: `CREATE TABLE error_vectors (id INT PRIMARY KEY, vec3 VECTOR(0))`,

enginetest/queries/vector_function_queries.go

Lines changed: 10 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ var VectorFunctionTestCases = []VectorFunctionTestCase{
4242
hex: "0000803F",
4343
length: 4,
4444
base64: "AACAPw==",
45-
md5: "84b7e6bf8e3cd674011866c606a8011d",
45+
md5: "429d81ed2795e3c586906c6c335aa136",
4646
sha1: "5bb96baed2a67ef718989bf7de91433ca9b9f8cf",
4747
sha2: "e00e5eb9444182f352323374ef4e08ebcb784725fdd4fd612d7730540b3e0c8c",
4848
},
@@ -53,7 +53,7 @@ var VectorFunctionTestCases = []VectorFunctionTestCase{
5353
charLength: 8,
5454
hex: "0000004000004040",
5555
length: 8,
56-
base64: "AAAAQAAAQEA",
56+
base64: "AAAAQAAAQEA=",
5757
md5: "f37b6e459e9e2d49261fe42d3a7bff07",
5858
sha1: "fd3352c0e141970e5b1c45d1755760d018cfe32d",
5959
sha2: "2fd848aa90e817e10e20985de4e8ac6a09b0fe70623d6b952e46800be6b025b9",
@@ -122,15 +122,15 @@ var VectorFunctionQueries = []ScriptTest{
122122
},
123123
{
124124
Query: `SELECT VECTOR_TO_STRING(STRING_TO_VECTOR("[1.0, 2.0]"));`,
125-
Expected: []sql.Row{{"[1.0, 2.0]"}},
125+
Expected: []sql.Row{{"[1, 2]"}},
126126
},
127127
{
128128
Query: `SELECT FROM_VECTOR(TO_VECTOR("[1.0, 2.0]"));`,
129-
Expected: []sql.Row{{"[1.0, 2.0]"}},
129+
Expected: []sql.Row{{"[1, 2]"}},
130130
},
131131
{
132132
Query: `SELECT VEC_ToText(VEC_FromText("[1.0, 2.0]"));`,
133-
Expected: []sql.Row{{"[1.0, 2.0]"}},
133+
Expected: []sql.Row{{"[1, 2]"}},
134134
},
135135
},
136136
},
@@ -151,58 +151,19 @@ var VectorFunctionQueries = []ScriptTest{
151151
},
152152
{
153153
Query: "select VEC_DISTANCE_EUCLIDEAN('[1.0, 2.0]', '[5.0, 5.0]');",
154-
Expected: []sql.Row{{float32(5.0)}},
154+
Expected: []sql.Row{{5.0}},
155155
},
156156
{
157157
Query: `SELECT DISTANCE(STRING_TO_VECTOR("[0.0, 0.0]"), STRING_TO_VECTOR("[3.0, 4.0]"), "EUCLIDEAN");`,
158-
Expected: []sql.Row{{float32(5.0)}},
158+
Expected: []sql.Row{{5.0}},
159159
},
160160
{
161161
Query: "select VEC_DISTANCE_COSINE(STRING_TO_VECTOR('[0.0, 3.0]'), '[5.0, 5.0]');",
162-
Expected: []sql.Row{{float32(15.0)}},
162+
Expected: []sql.Row{{0.29289321881345254}},
163163
},
164164
{
165-
Query: `SELECT DISTANCE("[0.0, 3.0]", STRING_TO_VECTOR("[5.0, 5.0]"), "COSINE");`,
166-
Expected: []sql.Row{{float32(5.0)}},
167-
},
168-
{
169-
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",
170-
Expected: []sql.Row{
171-
{2, types.MustJSON(`[0.0, 0.0]`)},
172-
{3, types.MustJSON(`[1.0, -1.0]`)},
173-
{4, types.MustJSON(`[-2.0, 0.0]`)},
174-
{1, types.MustJSON(`[3.0, 4.0]`)},
175-
},
176-
},
177-
{
178-
Query: "select * from vectors order by VEC_DISTANCE_L2_SQUARED('[-2.0,0.0]', v)",
179-
Expected: []sql.Row{
180-
{4, types.MustJSON(`[-2.0, 0.0]`)},
181-
{2, types.MustJSON(`[0.0, 0.0]`)},
182-
{3, types.MustJSON(`[1.0, -1.0]`)},
183-
{1, types.MustJSON(`[3.0, 4.0]`)},
184-
},
185-
},
186-
},
187-
},
188-
{
189-
Name: "test that existing functions accept vectors",
190-
Assertions: []ScriptTestAssertion{
191-
{
192-
Query: "select VEC_DISTANCE_EUCLIDEAN('[1.0, 2.0]', '[5.0, 5.0]');",
193-
Expected: []sql.Row{{float32(5.0)}},
194-
},
195-
{
196-
Query: `SELECT DISTANCE(STRING_TO_VECTOR("[0.0, 0.0]"), STRING_TO_VECTOR("[3.0, 4.0]"), "EUCLIDEAN");`,
197-
Expected: []sql.Row{{float32(5.0)}},
198-
},
199-
{
200-
Query: "select VEC_DISTANCE_COSINE(STRING_TO_VECTOR('[0.0, 3.0]'), '[5.0, 5.0]');",
201-
Expected: []sql.Row{{float32(15.0)}},
202-
},
203-
{
204-
Query: `SELECT DISTANCE("[0.0, 3.0]", STRING_TO_VECTOR("[5.0, 5.0]"), "COSINE");`,
205-
Expected: []sql.Row{{float32(5.0)}},
165+
Query: `SELECT DISTANCE("[1.0, 1.0]", STRING_TO_VECTOR("[-1.0, 1.0]"), "COSINE");`,
166+
Expected: []sql.Row{{1.0}},
206167
},
207168
{
208169
Query: "select * from vectors order by VEC_DISTANCE('[0.0,0.0]', v)",

sql/expression/function/length.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,10 @@ func (l *Length) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
123123
return int32(wrapper.MaxByteLength()), nil
124124
}
125125

126+
// Getting the length of a vector doesn't require converting it to a string, it just returns the length in bytes
127+
if vec, ok := val.([]float32); ok {
128+
return int32(4 * len(vec)), nil
129+
}
126130
content, collation, err := types.ConvertToCollatedString(ctx, val, l.Child.Type())
127131
if err != nil {
128132
return nil, err

sql/expression/function/registry.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -327,6 +327,15 @@ var BuiltIns = []sql.Function{
327327
sql.Function1{Name: "var_samp", Fn: func(e sql.Expression) sql.Expression { return aggregation.NewVarSamp(e) }},
328328
sql.Function2{Name: "vec_distance", Fn: vector.NewL2SquaredDistance},
329329
sql.Function2{Name: "vec_distance_l2_squared", Fn: vector.NewL2SquaredDistance},
330+
sql.Function2{Name: "vec_distance_euclidean", Fn: vector.NewEuclideanDistance},
331+
sql.Function2{Name: "vec_distance_cosine", Fn: vector.NewCosineDistance},
332+
sql.FunctionN{Name: "distance", Fn: vector.NewGenericDistance},
333+
sql.Function1{Name: "string_to_vector", Fn: vector.NewStringToVector},
334+
sql.Function1{Name: "to_vector", Fn: vector.NewStringToVector},
335+
sql.Function1{Name: "vec_fromtext", Fn: vector.NewStringToVector},
336+
sql.Function1{Name: "vector_to_string", Fn: vector.NewVectorToString},
337+
sql.Function1{Name: "from_vector", Fn: vector.NewVectorToString},
338+
sql.Function1{Name: "vec_totext", Fn: vector.NewVectorToString},
330339
sql.Function1{Name: "weekday", Fn: NewWeekday},
331340
sql.Function1{Name: "weekofyear", Fn: NewWeekOfYear},
332341
sql.Function1{Name: "year", Fn: NewYear},

sql/expression/function/string.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ package function
1616

1717
import (
1818
"bytes"
19+
"encoding/binary"
1920
"encoding/hex"
2021
"fmt"
2122
"math"
@@ -256,6 +257,13 @@ func (h *Hex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
256257
}
257258
return hexForString(string(b)), nil
258259

260+
case []float32:
261+
buf := make([]byte, 4*len(val))
262+
for i, v := range val {
263+
binary.Encode(buf[4*i:], binary.LittleEndian, v)
264+
}
265+
return hexForString(string(buf)), nil
266+
259267
case types.GeometryValue:
260268
return hexForString(string(val.Serialize())), nil
261269

@@ -601,6 +609,8 @@ func (h *Bitlength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
601609
return 64, nil
602610
case string:
603611
return 8 * len([]byte(val)), nil
612+
case []float32:
613+
return 32 * len(val), nil
604614
case time.Time:
605615
return 128, nil
606616
}
Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright 2024 Dolthub, Inc.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
package vector
16+
17+
import (
18+
"fmt"
19+
20+
"github.com/dolthub/go-mysql-server/sql"
21+
"github.com/dolthub/go-mysql-server/sql/expression"
22+
"github.com/dolthub/go-mysql-server/sql/types"
23+
)
24+
25+
// StringToVector converts a JSON string representation to a vector
26+
type StringToVector struct {
27+
expression.UnaryExpression
28+
}
29+
30+
var _ sql.Expression = (*StringToVector)(nil)
31+
var _ sql.FunctionExpression = (*StringToVector)(nil)
32+
var _ sql.CollationCoercible = (*StringToVector)(nil)
33+
34+
func NewStringToVector(e sql.Expression) sql.Expression {
35+
return &StringToVector{UnaryExpression: expression.UnaryExpression{Child: e}}
36+
}
37+
38+
func (s *StringToVector) FunctionName() string {
39+
return "string_to_vector"
40+
}
41+
42+
func (s *StringToVector) Description() string {
43+
return "converts a JSON array string to a vector"
44+
}
45+
46+
func (s *StringToVector) Type() sql.Type {
47+
return types.VectorType{}
48+
}
49+
50+
func (s *StringToVector) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) {
51+
return sql.Collation_binary, 5
52+
}
53+
54+
func (s *StringToVector) String() string {
55+
return fmt.Sprintf("STRING_TO_VECTOR(%s)", s.Child)
56+
}
57+
58+
func (s *StringToVector) WithChildren(children ...sql.Expression) (sql.Expression, error) {
59+
if len(children) != 1 {
60+
return nil, sql.ErrInvalidChildrenNumber.New(s, len(children), 1)
61+
}
62+
return NewStringToVector(children[0]), nil
63+
}
64+
65+
func (s *StringToVector) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
66+
val, err := s.Child.Eval(ctx, row)
67+
if err != nil {
68+
return nil, err
69+
}
70+
if val == nil {
71+
return nil, nil
72+
}
73+
74+
return sql.ConvertToVector(ctx, val)
75+
}
76+
77+
// VectorToString converts a vector to a JSON string representation
78+
type VectorToString struct {
79+
expression.UnaryExpression
80+
}
81+
82+
var _ sql.Expression = (*VectorToString)(nil)
83+
var _ sql.FunctionExpression = (*VectorToString)(nil)
84+
var _ sql.CollationCoercible = (*VectorToString)(nil)
85+
86+
func NewVectorToString(e sql.Expression) sql.Expression {
87+
return &VectorToString{UnaryExpression: expression.UnaryExpression{Child: e}}
88+
}
89+
90+
func (v *VectorToString) FunctionName() string {
91+
return "vector_to_string"
92+
}
93+
94+
func (v *VectorToString) Description() string {
95+
return "converts a vector to a JSON array string"
96+
}
97+
98+
func (v *VectorToString) Type() sql.Type {
99+
return types.LongText
100+
}
101+
102+
func (v *VectorToString) CollationCoercibility(_ *sql.Context) (collation sql.CollationID, coercibility byte) {
103+
return sql.Collation_binary, 5
104+
}
105+
106+
func (v *VectorToString) String() string {
107+
return fmt.Sprintf("VECTOR_TO_STRING(%s)", v.Child)
108+
}
109+
110+
func (v *VectorToString) WithChildren(children ...sql.Expression) (sql.Expression, error) {
111+
if len(children) != 1 {
112+
return nil, sql.ErrInvalidChildrenNumber.New(v, len(children), 1)
113+
}
114+
return NewVectorToString(children[0]), nil
115+
}
116+
117+
func (v *VectorToString) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
118+
val, err := v.Child.Eval(ctx, row)
119+
if err != nil {
120+
return nil, err
121+
}
122+
if val == nil {
123+
return nil, nil
124+
}
125+
return types.JSONDocument{Val: val}.JSONString()
126+
}

0 commit comments

Comments
 (0)