Skip to content

Commit 997c868

Browse files
committed
merge main
2 parents 793f52d + ec78760 commit 997c868

File tree

328 files changed

+8884
-3270
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

328 files changed

+8884
-3270
lines changed

driver/rows.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ func (r *Rows) convert(col int, v driver.Value) interface{} {
114114
}
115115
}
116116

117-
sqlValue, _, err := r.cols[col].Type.Convert(v)
117+
sqlValue, _, err := r.cols[col].Type.Convert(r.ctx, v)
118118
if err != nil {
119119
break
120120
}

driver/value.go

Lines changed: 0 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -16,56 +16,13 @@ package driver
1616

1717
import (
1818
"database/sql/driver"
19-
"errors"
20-
"fmt"
2119
"strconv"
2220
"time"
2321

2422
"github.com/dolthub/vitess/go/sqltypes"
2523
"github.com/dolthub/vitess/go/vt/sqlparser"
26-
27-
"github.com/dolthub/go-mysql-server/sql"
28-
"github.com/dolthub/go-mysql-server/sql/expression"
29-
"github.com/dolthub/go-mysql-server/sql/types"
3024
)
3125

32-
// ErrUnsupportedType is returned when a query argument of an unsupported type is passed to a statement
33-
var ErrUnsupportedType = errors.New("unsupported type")
34-
35-
func valueToExpr(v driver.Value) (sql.Expression, error) {
36-
if v == nil {
37-
return expression.NewLiteral(nil, types.Null), nil
38-
}
39-
40-
var typ sql.Type
41-
var err error
42-
switch v := v.(type) {
43-
case int64:
44-
typ = types.Int64
45-
case float64:
46-
typ = types.Float64
47-
case bool:
48-
typ = types.Boolean
49-
case []byte:
50-
typ, err = types.CreateBinary(sqltypes.Blob, int64(len(v)))
51-
case string:
52-
typ, err = types.CreateStringWithDefaults(sqltypes.Text, int64(len(v)))
53-
case time.Time:
54-
typ = types.Datetime
55-
default:
56-
return nil, fmt.Errorf("%w: %T", ErrUnsupportedType, v)
57-
}
58-
if err != nil {
59-
return nil, err
60-
}
61-
62-
c, _, err := typ.Convert(v)
63-
if err != nil {
64-
return nil, err
65-
}
66-
return expression.NewLiteral(c, typ), nil
67-
}
68-
6926
func valuesToBindings(vals []driver.Value) (map[string]sqlparser.Expr, error) {
7027
if len(vals) == 0 {
7128
return nil, nil

engine.go

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ type Engine struct {
150150
Parser sql.Parser
151151
}
152152

153-
var _ analyzer.StatementRunner = (*Engine)(nil)
153+
var _ sql.StatementRunner = (*Engine)(nil)
154154

155155
type ColumnWithRawDefault struct {
156156
SqlColumn *sql.Column
@@ -273,7 +273,7 @@ func clearWarnings(ctx *sql.Context, node sql.Node) {
273273
}
274274
}
275275

276-
func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.Expression, error) {
276+
func bindingsToExprs(ctx *sql.Context, bindings map[string]*querypb.BindVariable) (map[string]sql.Expression, error) {
277277
res := make(map[string]sql.Expression, len(bindings))
278278
for k, v := range bindings {
279279
v, err := sqltypes.NewValue(v.Type, v.Value)
@@ -282,7 +282,7 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
282282
}
283283
switch {
284284
case v.Type() == sqltypes.Year:
285-
v, _, err := types.Year.Convert(string(v.ToBytes()))
285+
v, _, err := types.Year.Convert(ctx, string(v.ToBytes()))
286286
if err != nil {
287287
return nil, err
288288
}
@@ -293,7 +293,7 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
293293
return nil, err
294294
}
295295
t := types.Int64
296-
c, _, err := t.Convert(v)
296+
c, _, err := t.Convert(ctx, v)
297297
if err != nil {
298298
return nil, err
299299
}
@@ -304,7 +304,7 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
304304
return nil, err
305305
}
306306
t := types.Uint64
307-
c, _, err := t.Convert(v)
307+
c, _, err := t.Convert(ctx, v)
308308
if err != nil {
309309
return nil, err
310310
}
@@ -315,20 +315,20 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
315315
return nil, err
316316
}
317317
t := types.Float64
318-
c, _, err := t.Convert(v)
318+
c, _, err := t.Convert(ctx, v)
319319
if err != nil {
320320
return nil, err
321321
}
322322
res[k] = expression.NewLiteral(c, t)
323323
case v.Type() == sqltypes.Decimal:
324-
v, _, err := types.InternalDecimalType.Convert(string(v.ToBytes()))
324+
v, _, err := types.InternalDecimalType.Convert(ctx, string(v.ToBytes()))
325325
if err != nil {
326326
return nil, err
327327
}
328328
res[k] = expression.NewLiteral(v, types.InternalDecimalType)
329329
case v.Type() == sqltypes.Bit:
330330
t := types.MustCreateBitType(types.BitTypeMaxBits)
331-
v, _, err := t.Convert(v.ToBytes())
331+
v, _, err := t.Convert(ctx, v.ToBytes())
332332
if err != nil {
333333
return nil, err
334334
}
@@ -340,7 +340,7 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
340340
if err != nil {
341341
return nil, err
342342
}
343-
v, _, err := t.Convert(v.ToBytes())
343+
v, _, err := t.Convert(ctx, v.ToBytes())
344344
if err != nil {
345345
return nil, err
346346
}
@@ -350,7 +350,7 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
350350
if err != nil {
351351
return nil, err
352352
}
353-
v, _, err := t.Convert(v.ToBytes())
353+
v, _, err := t.Convert(ctx, v.ToBytes())
354354
if err != nil {
355355
return nil, err
356356
}
@@ -364,14 +364,14 @@ func bindingsToExprs(bindings map[string]*querypb.BindVariable) (map[string]sql.
364364
if err != nil {
365365
return nil, err
366366
}
367-
v, _, err := t.Convert(string(v.ToBytes()))
367+
v, _, err := t.Convert(ctx, string(v.ToBytes()))
368368
if err != nil {
369369
return nil, err
370370
}
371371
res[k] = expression.NewLiteral(v, t)
372372
case v.Type() == sqltypes.Time:
373373
t := types.Time
374-
v, _, err := t.Convert(string(v.ToBytes()))
374+
v, _, err := t.Convert(ctx, string(v.ToBytes()))
375375
if err != nil {
376376
return nil, err
377377
}
@@ -686,7 +686,7 @@ func (e *Engine) bindExecuteQueryNode(ctx *sql.Context, query string, eq *plan.E
686686
t = types.Null
687687
}
688688
if val != nil {
689-
val, _, err = t.Promote().Convert(val)
689+
val, _, err = t.Promote().Convert(ctx, val)
690690
if err != nil {
691691
return nil, nil
692692
}

engine_test.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,9 +139,10 @@ func TestBindingsToExprs(t *testing.T) {
139139
},
140140
}
141141

142+
ctx := sql.NewEmptyContext()
142143
for _, c := range cases {
143144
t.Run(c.Name, func(t *testing.T) {
144-
res, err := bindingsToExprs(c.Bindings)
145+
res, err := bindingsToExprs(ctx, c.Bindings)
145146
if !c.Err {
146147
require.NoError(t, err)
147148
require.Equal(t, c.Result, res)

enginetest/engine_only_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -701,11 +701,11 @@ func TestCollationCoercion(t *testing.T) {
701701
require.Equal(t, 1, len(rows))
702702
require.Equal(t, 1, len(rows[0]))
703703
if i == 0 {
704-
num, _, err := types.Int64.Convert(rows[0][0])
704+
num, _, err := types.Int64.Convert(ctx, rows[0][0])
705705
require.NoError(t, err)
706706
require.Equal(t, test.Coercibility, num.(int64))
707707
} else {
708-
str, _, err := types.LongText.Convert(rows[0][0])
708+
str, _, err := types.LongText.Convert(ctx, rows[0][0])
709709
require.NoError(t, err)
710710
require.Equal(t, test.Collation.Name(), str.(string))
711711
}

enginetest/enginetests.go

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ func TestBrokenQueries(t *testing.T, harness Harness) {
263263
// queries during debugging.
264264
func RunQueryTests(t *testing.T, harness Harness, queries []queries.QueryTest) {
265265
for _, tt := range queries {
266-
TestQuery(t, harness, tt.Query, tt.Expected, tt.ExpectedColumns, nil)
266+
testQuery(t, harness, tt.Query, tt.Expected, tt.ExpectedColumns, nil, tt.WrapBehavior)
267267
}
268268
}
269269

@@ -810,7 +810,9 @@ func TestOrderByGroupBy(t *testing.T, harness Harness) {
810810
panic(fmt.Sprintf("unexpected type %T", v))
811811
}
812812

813-
team := row[1].(string)
813+
team, ok, err := sql.Unwrap[string](ctx, row[1])
814+
require.NoError(t, err)
815+
require.True(t, ok)
814816
switch team {
815817
case "red":
816818
require.True(t, val == 3 || val == 4)
@@ -846,7 +848,9 @@ func TestOrderByGroupBy(t *testing.T, harness Harness) {
846848
panic(fmt.Sprintf("unexpected type %T", v))
847849
}
848850

849-
team := row[1].(string)
851+
team, ok, err := sql.Unwrap[string](ctx, row[1])
852+
require.True(t, ok)
853+
require.NoError(t, err)
850854
switch team {
851855
case "red":
852856
require.True(t, val == 3 || val == 4)
@@ -2078,7 +2082,7 @@ func TestUserPrivileges(t *testing.T, harness ClientHarness) {
20782082
// See the comment on QuickPrivilegeTest for a more in-depth explanation, but essentially we treat
20792083
// nil in script.Expected as matching "any" non-error result.
20802084
if script.Expected != nil && (rows != nil || len(script.Expected) != 0) {
2081-
CheckResults(t, harness, script.Expected, nil, sch, rows, lastQuery, engine)
2085+
CheckResults(ctx, t, harness, script.Expected, nil, sch, rows, lastQuery, engine)
20822086
}
20832087
})
20842088
}
@@ -6039,15 +6043,16 @@ func findRole(toUser string, roles []*mysql_db.RoleEdge) *mysql_db.RoleEdge {
60396043
}
60406044

60416045
func TestBlobs(t *testing.T, h Harness) {
6046+
ctx := sql.NewEmptyContext()
60426047
h.Setup(setup.MydbData, setup.BlobData, setup.MytableData)
60436048

60446049
// By default, strict_mysql_compatibility is disabled, but these tests require it to be enabled.
6045-
err := sql.SystemVariables.SetGlobal("strict_mysql_compatibility", int8(1))
6050+
err := sql.SystemVariables.SetGlobal(ctx, "strict_mysql_compatibility", int8(1))
60466051
require.NoError(t, err)
60476052
for _, tt := range queries.BlobErrors {
60486053
runQueryErrorTest(t, h, tt)
60496054
}
6050-
err = sql.SystemVariables.SetGlobal("strict_mysql_compatibility", int8(0))
6055+
err = sql.SystemVariables.SetGlobal(ctx, "strict_mysql_compatibility", int8(0))
60516056
require.NoError(t, err)
60526057

60536058
e := mustNewEngine(t, h)

0 commit comments

Comments
 (0)