Skip to content

Commit 6ba7124

Browse files
committed
Store vectors in the engine as a byte array instead of a float array. This simplifies the logic a lot.
1 parent b9de299 commit 6ba7124

File tree

12 files changed

+111
-91
lines changed

12 files changed

+111
-91
lines changed

enginetest/memory_engine_test.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,9 +1001,6 @@ func TestDatabaseCollationWire(t *testing.T) {
10011001
}
10021002

10031003
// TestTypesOverWire runs the shared enginetest.TestTypesOverWire suite against the harness returned by
// enginetest.NewDefaultMemoryHarness, passing that harness's own session builder. The lines marked as removed in this
// diff are the CI_TEST environment guard, so after this change the test no longer skips itself when CI_TEST is unset.
func TestTypesOverWire(t *testing.T) {
1004-
if _, ok := os.LookupEnv("CI_TEST"); !ok {
1005-
t.Skip("Skipping test that requires CI_TEST=true")
1006-
}
10071004
harness := enginetest.NewDefaultMemoryHarness()
10081005
enginetest.TestTypesOverWire(t, harness, harness.SessionBuilder())
10091006
}

enginetest/queries/type_wire_queries.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
package queries
1616

1717
import (
18-
"encoding/binary"
19-
2018
"github.com/dolthub/go-mysql-server/sql"
2119
)
2220

@@ -30,9 +28,11 @@ type TypeWireTest struct {
3028
}
3129

3230
// floatsToString returns the encoded form of fs as a string. The removed lines built the bytes locally with
// binary.Encode in little-endian order (4 bytes per element); the added line delegates to sql.EncodeVector and
// converts its result to a string.
func floatsToString(fs ...float32) string {
33-
result := make([]byte, 4*len(fs))
34-
binary.Encode(result, binary.LittleEndian, fs)
35-
return string(result)
31+
return string(sql.EncodeVector(fs))
32+
}
33+
34+
func floatsToBytes(fs ...float32) []byte {
35+
return sql.EncodeVector(fs)
3636
}
3737

3838
// TypeWireTests are used to ensure that types are properly represented over the wire (vs being directly returned from

enginetest/queries/vector_ddl_queries.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,17 +43,17 @@ var VectorDDLQueries = []ScriptTest{
4343
},
4444
{
4545
Query: `SELECT id, small_vec, medium_vec FROM test_vectors WHERE id = 2`,
46-
Expected: []sql.Row{{2, []float32{3.5, 4.5}, nil}},
46+
Expected: []sql.Row{{2, floatsToBytes(3.5, 4.5), nil}},
4747
},
4848
{
4949
Query: `SELECT id, small_vec, medium_vec FROM test_vectors WHERE id = 1`,
50-
Expected: []sql.Row{{1, []float32{1.0, 2.0}, []float32{1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0}}},
50+
Expected: []sql.Row{{1, floatsToBytes(1.0, 2.0), floatsToBytes(1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0)}},
5151
},
5252
{
5353
Query: `SELECT id, small_vec FROM test_vectors ORDER BY id`,
5454
Expected: []sql.Row{
55-
{1, []float32{1.0, 2.0}},
56-
{2, []float32{3.5, 4.5}},
55+
{1, floatsToBytes(1.0, 2.0)},
56+
{2, floatsToBytes(3.5, 4.5)},
5757
},
5858
},
5959
{
@@ -62,15 +62,15 @@ var VectorDDLQueries = []ScriptTest{
6262
},
6363
{
6464
Query: `SELECT small_vec FROM test_vectors WHERE id = 1`,
65-
Expected: []sql.Row{{[]float32{10.0, 20.0}}},
65+
Expected: []sql.Row{{floatsToBytes(10.0, 20.0)}},
6666
},
6767
{
6868
Query: `INSERT INTO test_vectors VALUES (3, 0x0000204100002041, NULL, NULL)`, // [10.0, 10.0]
6969
Expected: []sql.Row{{types.NewOkResult(1)}},
7070
},
7171
{
7272
Query: `SELECT small_vec FROM test_vectors WHERE id = 3`,
73-
Expected: []sql.Row{{[]float32{10.0, 10.0}}},
73+
Expected: []sql.Row{{floatsToBytes(10.0, 10.0)}},
7474
},
7575
},
7676
},
@@ -138,9 +138,9 @@ var VectorDDLQueries = []ScriptTest{
138138
{
139139
Query: `SELECT id, vec2 FROM format_vectors ORDER BY id`,
140140
Expected: []sql.Row{
141-
{1, []float32{1.0, 2.0}},
142-
{2, []float32{3.0, 4.0}},
143-
{3, []float32{5.5, 6700}},
141+
{1, floatsToBytes(1.0, 2.0)},
142+
{2, floatsToBytes(3.0, 4.0)},
143+
{3, floatsToBytes(5.5, 6700)},
144144
},
145145
},
146146
},

enginetest/queries/vector_function_queries.go

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,20 +111,25 @@ var VectorFunctionQueries = []ScriptTest{
111111
Assertions: []ScriptTestAssertion{
112112
{
113113
Query: `SELECT STRING_TO_VECTOR("[1.0, 2.0]");`,
114-
Expected: []sql.Row{{[]float32{1.0, 2.0}}},
114+
Expected: []sql.Row{{floatsToBytes(1.0, 2.0)}},
115115
},
116116
{
117117
Query: `SELECT TO_VECTOR("[1.0, 2.0]");`,
118-
Expected: []sql.Row{{[]float32{1.0, 2.0}}},
118+
Expected: []sql.Row{{floatsToBytes(1.0, 2.0)}},
119119
},
120120
{
121121
Query: `SELECT VEC_FromText("[1.0, 2.0]");`,
122-
Expected: []sql.Row{{[]float32{1.0, 2.0}}},
122+
Expected: []sql.Row{{floatsToBytes(1.0, 2.0)}},
123123
},
124124
{
125125
Query: `SELECT VECTOR_TO_STRING(STRING_TO_VECTOR("[1.0, 2.0]"));`,
126126
Expected: []sql.Row{{"[1, 2]"}},
127127
},
128+
{
129+
Query: `select VECTOR_TO_STRING(0x0000803F);`,
130+
Expected: []sql.Row{{"[1]"}},
131+
},
132+
128133
{
129134
Query: `SELECT FROM_VECTOR(TO_VECTOR("[1.0, 2.0]"));`,
130135
Expected: []sql.Row{{"[1, 2]"}},

enginetest/server_engine.go

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package enginetest
1616

1717
import (
1818
gosql "database/sql"
19-
"encoding/binary"
2019
"encoding/json"
2120
"errors"
2221
"fmt"
@@ -398,13 +397,6 @@ func convertValue(ctx *sql.Context, sch sql.Schema, row sql.Row) sql.Row {
398397
row[i] = r
399398
}
400399
}
401-
case query.Type_VECTOR:
402-
if row[i] != nil {
403-
r := row[i].([]byte)
404-
dimensions := len(r) / 4
405-
row[i] = make([]float32, dimensions)
406-
binary.Decode(r, binary.LittleEndian, row[i])
407-
}
408400
case query.Type_TIME:
409401
if row[i] != nil {
410402
r, _, err := types.TimespanType_{}.Convert(ctx, string(row[i].([]byte)))

sql/core.go

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,13 @@ import (
1818
"context"
1919
"encoding/json"
2020
"fmt"
21+
"github.com/dolthub/go-mysql-server/sql/values"
2122
trace2 "runtime/trace"
2223
"strconv"
2324
"strings"
2425
"sync/atomic"
2526
"time"
27+
"unsafe"
2628

2729
"github.com/shopspring/decimal"
2830
)
@@ -325,10 +327,22 @@ func ConvertToBool(ctx *Context, v interface{}) (bool, error) {
325327
}
326328
}
327329

330+
// DecodeVector decodes a byte slice that represents a vector. This is needed for distance functions.
331+
func DecodeVector(buf []byte) []float32 {
332+
return unsafe.Slice((*float32)(unsafe.Pointer(&buf[0])), len(buf)/int(values.Float32Size))
333+
}
334+
335+
// EncodeVector encodes a byte slice that represents a vector.
336+
func EncodeVector(floats []float32) []byte {
337+
return unsafe.Slice((*byte)(unsafe.Pointer(&floats[0])), len(floats)*int(values.Float32Size))
338+
}
339+
328340
func ConvertToVector(ctx context.Context, v interface{}) ([]float32, error) {
329341
switch b := v.(type) {
330342
case []float32:
331343
return b, nil
344+
case []byte:
345+
return DecodeVector(b), nil
332346
case string:
333347
var val interface{}
334348
err := json.Unmarshal([]byte(b), &val)

sql/expression/function/length.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,10 +123,6 @@ func (l *Length) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
123123
return int32(wrapper.MaxByteLength()), nil
124124
}
125125

126-
// Getting the length of a vector doesn't require converting it to a string, it just returns the length in bytes
127-
if vec, ok := val.([]float32); ok {
128-
return int32(4 * len(vec)), nil
129-
}
130126
content, collation, err := types.ConvertToCollatedString(ctx, val, l.Child.Type())
131127
if err != nil {
132128
return nil, err

sql/expression/function/string.go

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package function
1616

1717
import (
1818
"bytes"
19-
"encoding/binary"
2019
"encoding/hex"
2120
"fmt"
2221
"math"
@@ -257,11 +256,6 @@ func (h *Hex) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
257256
}
258257
return hexForString(string(b)), nil
259258

260-
case []float32:
261-
buf := make([]byte, 4*len(val))
262-
binary.Encode(buf, binary.LittleEndian, val)
263-
return hexForString(string(buf)), nil
264-
265259
case types.GeometryValue:
266260
return hexForString(string(val.Serialize())), nil
267261

@@ -607,8 +601,8 @@ func (h *Bitlength) Eval(ctx *sql.Context, row sql.Row) (interface{}, error) {
607601
return 64, nil
608602
case string:
609603
return 8 * len([]byte(val)), nil
610-
case []float32:
611-
return 32 * len(val), nil
604+
case []byte:
605+
return 8 * len(val), nil
612606
case time.Time:
613607
return 128, nil
614608
}

sql/expression/function/vector/conversion.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,13 @@ func (s *StringToVector) Eval(ctx *sql.Context, row sql.Row) (interface{}, error
7171
return nil, nil
7272
}
7373

74-
return sql.ConvertToVector(ctx, val)
74+
// TODO: Instead of using the JSON parser and then encoding, it would be more efficient to parse and encode
75+
// in a single step
76+
floats, err := sql.ConvertToVector(ctx, val)
77+
if err != nil {
78+
return nil, err
79+
}
80+
return sql.EncodeVector(floats), nil
7581
}
7682

7783
// VectorToString converts a vector to a JSON string representation
@@ -122,5 +128,9 @@ func (v *VectorToString) Eval(ctx *sql.Context, row sql.Row) (interface{}, error
122128
if val == nil {
123129
return nil, nil
124130
}
125-
return types.JSONDocument{Val: val}.JSONString()
131+
b, ok := val.([]byte)
132+
if !ok {
133+
return nil, fmt.Errorf("incorrect argument to VECTOR_TO_STRING: expected a vector, got %T", val)
134+
}
135+
return types.JSONDocument{Val: sql.DecodeVector(b)}.JSONString()
126136
}

sql/types/strings.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ package types
1616

1717
import (
1818
"context"
19-
"encoding/binary"
2019
"fmt"
2120
"reflect"
2221
"strconv"
@@ -414,9 +413,6 @@ func ConvertToBytes(ctx context.Context, v interface{}, t sql.StringType, dest [
414413
// We'll check for that below, immediately before extending the slice.
415414
val = s
416415
start = 0
417-
case []float32:
418-
val = make([]byte, 4*len(s))
419-
binary.Encode(val, binary.LittleEndian, s)
420416
case time.Time:
421417
val = s.AppendFormat(dest, sql.TimestampDatetimeLayout)
422418
case decimal.Decimal:

0 commit comments

Comments
 (0)