Skip to content

Commit d8161a3

Browse files
committed
Make vectors arrays of 32-bit floats, not 64-bit floats.
1 parent 7a21abe commit d8161a3

File tree

3 files changed

+41
-25
lines changed

3 files changed

+41
-25
lines changed

sql/core.go

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -325,15 +325,15 @@ func ConvertToBool(ctx *Context, v interface{}) (bool, error) {
325325
}
326326
}
327327

328-
func ConvertToVector(ctx context.Context, v interface{}) ([]float64, error) {
328+
func ConvertToVector(ctx context.Context, v interface{}) ([]float32, error) {
329329
switch b := v.(type) {
330-
case []float64:
330+
case []float32:
331331
return b, nil
332332
case string:
333333
var val interface{}
334334
err := json.Unmarshal([]byte(b), &val)
335335
if err != nil {
336-
return nil, err
336+
return nil, fmt.Errorf("can't convert JSON to vector: %w", err)
337337
}
338338
return convertJsonInterfaceToVector(val)
339339
case JSONWrapper:
@@ -347,18 +347,25 @@ func ConvertToVector(ctx context.Context, v interface{}) ([]float64, error) {
347347
}
348348
}
349349

350-
func convertJsonInterfaceToVector(val interface{}) ([]float64, error) {
350+
func convertJsonInterfaceToVector(val interface{}) ([]float32, error) {
351351
array, ok := val.([]interface{})
352352
if !ok {
353-
return nil, fmt.Errorf("can't convert JSON to vector; expected array, got %v", val)
353+
return nil, fmt.Errorf("can't convert JSON to vector; expected array, got %T", val)
354354
}
355-
res := make([]float64, len(array))
355+
res := make([]float32, len(array))
356356
for i, elem := range array {
357-
floatElem, ok := elem.(float64)
358-
if !ok {
359-
return nil, fmt.Errorf("can't convert JSON to vector; expected array of floats, got %v", elem)
357+
switch v := elem.(type) {
358+
case float32:
359+
res[i] = v
360+
case float64:
361+
res[i] = float32(v)
362+
case int64:
363+
res[i] = float32(v)
364+
case int32:
365+
res[i] = float32(v)
366+
default:
367+
return nil, fmt.Errorf("can't convert JSON to vector; expected array of floats, but array contained %T", elem)
360368
}
361-
res[i] = floatElem
362369
}
363370
return res, nil
364371
}

sql/expression/function/vector/distance.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import (
2525

2626
type DistanceType interface {
2727
String() string
28-
Eval(left []float64, right []float64) (float64, error)
28+
Eval(left []float32, right []float32) (float64, error)
2929
CanEval(distanceType DistanceType) bool
3030
FunctionName() string
3131
Description() string
@@ -40,14 +40,14 @@ func (d DistanceL2Squared) String() string {
4040
return "VEC_DISTANCE_L2_SQUARED"
4141
}
4242

43-
func (d DistanceL2Squared) Eval(left []float64, right []float64) (float64, error) {
43+
func (d DistanceL2Squared) Eval(left []float32, right []float32) (float64, error) {
4444
if len(left) != len(right) {
4545
return 0, fmt.Errorf("attempting to find distance between vectors of different lengths: %d vs %d", len(left), len(right))
4646
}
4747
var total float64 = 0
4848
for i, l := range left {
4949
r := right[i]
50-
total += (l - r) * (l - r)
50+
total += float64(l-r) * float64(l-r)
5151
}
5252
return total, nil
5353
}

sql/types/vector.go

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ import (
2727
"github.com/dolthub/go-mysql-server/sql"
2828
)
2929

30-
var vectorValueType = reflect.TypeOf([]float64{})
30+
var vectorValueType = reflect.TypeOf([]float32{})
3131

3232
// VectorType represents the VECTOR(N) type.
3333
// It stores a fixed-length array of N floating point numbers.
@@ -64,8 +64,8 @@ func (t VectorType) Compare(ctx context.Context, a interface{}, b interface{}) (
6464
return 0, err
6565
}
6666

67-
avec := av.([]float64)
68-
bvec := bv.([]float64)
67+
avec := av.([]float32)
68+
bvec := bv.([]float32)
6969

7070
aBytes := make([]byte, 4)
7171
bBytes := make([]byte, 4)
@@ -88,7 +88,16 @@ func (t VectorType) Convert(ctx context.Context, v interface{}) (interface{}, sq
8888
}
8989

9090
switch val := v.(type) {
91-
case []float64:
91+
case []byte:
92+
if len(val) != 4*t.Dimensions {
93+
return nil, sql.OutOfRange, fmt.Errorf("cannot convert BINARY(%d) to VECTOR(%d), need BINARY(%d)", len(val), t.Dimensions, 4*t.Dimensions)
94+
}
95+
result := make([]float32, t.Dimensions)
96+
for i := range result {
97+
binary.Decode(val[4*i:4*(i+1)], binary.LittleEndian, &result[i])
98+
}
99+
return result, sql.InRange, nil
100+
case []float32:
92101
if len(val) != t.Dimensions {
93102
return nil, sql.OutOfRange, fmt.Errorf("VECTOR dimension mismatch: expected %d, got %d", t.Dimensions, len(val))
94103
}
@@ -97,26 +106,26 @@ func (t VectorType) Convert(ctx context.Context, v interface{}) (interface{}, sq
97106
if len(val) != t.Dimensions {
98107
return nil, sql.OutOfRange, fmt.Errorf("VECTOR dimension mismatch: expected %d, got %d", t.Dimensions, len(val))
99108
}
100-
result := make([]float64, t.Dimensions)
109+
result := make([]float32, t.Dimensions)
101110
for i, elem := range val {
102111
switch e := elem.(type) {
103112
case float64:
104-
result[i] = e
113+
result[i] = float32(e)
105114
case float32:
106-
result[i] = float64(e)
115+
result[i] = e
107116
case int:
108-
result[i] = float64(e)
117+
result[i] = float32(e)
109118
case int64:
110-
result[i] = float64(e)
119+
result[i] = float32(e)
111120
case int32:
112-
result[i] = float64(e)
121+
result[i] = float32(e)
113122
default:
114123
if str, ok := elem.(string); ok {
115124
f, err := strconv.ParseFloat(str, 64)
116125
if err != nil {
117126
return nil, sql.OutOfRange, fmt.Errorf("invalid vector element: %v", elem)
118127
}
119-
result[i] = f
128+
result[i] = float32(f)
120129
} else {
121130
return nil, sql.OutOfRange, fmt.Errorf("invalid vector element: %v", elem)
122131
}
@@ -182,7 +191,7 @@ func (t VectorType) ValueType() reflect.Type {
182191

183192
// Zero implements Type interface.
184193
func (t VectorType) Zero() interface{} {
185-
return make([]float64, t.Dimensions)
194+
return make([]float32, t.Dimensions)
186195
}
187196

188197
// CollationCoercibility implements sql.CollationCoercible interface.

0 commit comments

Comments
 (0)