@@ -27,7 +27,7 @@ import (
2727 "github.com/dolthub/go-mysql-server/sql"
2828)
2929
30- var vectorValueType = reflect .TypeOf ([]float64 {})
30+ var vectorValueType = reflect .TypeOf ([]float32 {})
3131
3232// VectorType represents the VECTOR(N) type.
3333// It stores a fixed-length array of N floating point numbers.
@@ -64,8 +64,8 @@ func (t VectorType) Compare(ctx context.Context, a interface{}, b interface{}) (
6464 return 0 , err
6565 }
6666
67- avec := av .([]float64 )
68- bvec := bv .([]float64 )
67+ avec := av .([]float32 )
68+ bvec := bv .([]float32 )
6969
7070 aBytes := make ([]byte , 4 )
7171 bBytes := make ([]byte , 4 )
@@ -88,7 +88,16 @@ func (t VectorType) Convert(ctx context.Context, v interface{}) (interface{}, sq
8888 }
8989
9090 switch val := v .(type ) {
91- case []float64 :
91+ case []byte :
92+ if len (val ) != 4 * t .Dimensions {
93+ return nil , sql .OutOfRange , fmt .Errorf ("cannot convert BINARY(%d) to VECTOR(%d), need BINARY(%d)" , len (val ), t .Dimensions , 4 * t .Dimensions )
94+ }
95+ result := make ([]float32 , t .Dimensions )
96+ for i := range result {
97+ binary .Decode (val [4 * i :4 * (i + 1 )], binary .LittleEndian , & result [i ])
98+ }
99+ return result , sql .InRange , nil
100+ case []float32 :
92101 if len (val ) != t .Dimensions {
93102 return nil , sql .OutOfRange , fmt .Errorf ("VECTOR dimension mismatch: expected %d, got %d" , t .Dimensions , len (val ))
94103 }
@@ -97,26 +106,26 @@ func (t VectorType) Convert(ctx context.Context, v interface{}) (interface{}, sq
97106 if len (val ) != t .Dimensions {
98107 return nil , sql .OutOfRange , fmt .Errorf ("VECTOR dimension mismatch: expected %d, got %d" , t .Dimensions , len (val ))
99108 }
100- result := make ([]float64 , t .Dimensions )
109+ result := make ([]float32 , t .Dimensions )
101110 for i , elem := range val {
102111 switch e := elem .(type ) {
103112 case float64 :
104- result [i ] = e
113+ result [i ] = float32 ( e )
105114 case float32 :
106- result [i ] = float64 ( e )
115+ result [i ] = e
107116 case int :
108- result [i ] = float64 (e )
117+ result [i ] = float32 (e )
109118 case int64 :
110- result [i ] = float64 (e )
119+ result [i ] = float32 (e )
111120 case int32 :
112- result [i ] = float64 (e )
121+ result [i ] = float32 (e )
113122 default :
114123 if str , ok := elem .(string ); ok {
115124 f , err := strconv .ParseFloat (str , 64 )
116125 if err != nil {
117126 return nil , sql .OutOfRange , fmt .Errorf ("invalid vector element: %v" , elem )
118127 }
119- result [i ] = f
128+ result [i ] = float32 ( f )
120129 } else {
121130 return nil , sql .OutOfRange , fmt .Errorf ("invalid vector element: %v" , elem )
122131 }
@@ -182,7 +191,7 @@ func (t VectorType) ValueType() reflect.Type {
182191
183192// Zero implements Type interface.
184193func (t VectorType ) Zero () interface {} {
185- return make ([]float64 , t .Dimensions )
194+ return make ([]float32 , t .Dimensions )
186195}
187196
188197// CollationCoercibility implements sql.CollationCoercible interface.
0 commit comments