Skip to content

Commit 8d2ac20

Browse files
author
James Cor
committed
opt
1 parent f3068cf commit 8d2ac20

File tree

2 files changed

+23
-20
lines changed

2 files changed

+23
-20
lines changed

enginetest/enginetests.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5851,7 +5851,7 @@ func TestTypesOverWire(t *testing.T, harness ClientHarness, sessionBuilder serve
58515851
break
58525852
}
58535853
expectedEngineRow := make([]*string, len(engineRow))
5854-
row, err := server.RowToSQL(ctx, sch, engineRow, nil, nil)
5854+
row, err := server.RowToSQL(ctx, sch, engineRow, nil, nil, nil)
58555855
if !assert.NoError(t, err) {
58565856
break
58575857
}

server/handler.go

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -491,7 +491,7 @@ func (h *Handler) doQuery(
491491
// zero/single return schema use spooling shortcut
492492
if types.IsOkResultSchema(schema) {
493493
r, err = resultForOkIter(sqlCtx, rowIter)
494-
} else if schema == nil {
494+
} else if len(schema) == 0 {
495495
r, err = resultForEmptyIter(sqlCtx, rowIter, resultFields)
496496
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
497497
r, bm, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, bm)
@@ -615,7 +615,11 @@ func resultForMax1RowIter(
615615
}
616616

617617
bm = sql.NewByteBufferManager()
618-
outputRow, err := RowToSQL(ctx, schema, row, nil, bm)
618+
maxCaps := make([]int, len(schema))
619+
for i, col := range schema {
620+
maxCaps[i] = getMaxTypeCapacity(ctx, col.Type)
621+
}
622+
outputRow, err := RowToSQL(ctx, schema, row, nil, maxCaps, bm)
619623
if err != nil {
620624
// Important to return ByteBufferManager even in error, as we still need to release any allocated memory.
621625
return nil, bm, err
@@ -708,6 +712,12 @@ func (h *Handler) resultForDefaultIter(
708712
var resChan = make(chan managedResult, 4)
709713
var res *sqltypes.Result
710714
var bm *sql.ByteBufferManager
715+
716+
// TODO: find good place to put this
717+
maxCaps := make([]int, len(schema))
718+
for i, col := range schema {
719+
maxCaps[i] = getMaxTypeCapacity(ctx, col.Type)
720+
}
711721
eg.Go(func() (err error) {
712722
defer pan2err(&err)
713723
defer wg.Done()
@@ -747,7 +757,7 @@ func (h *Handler) resultForDefaultIter(
747757
}
748758

749759
var outRow []sqltypes.Value
750-
outRow, err = RowToSQL(ctx, schema, row, projs, bm)
760+
outRow, err = RowToSQL(ctx, schema, row, projs, maxCaps, bm)
751761
if err != nil {
752762
return err
753763
}
@@ -895,6 +905,7 @@ func (h *Handler) resultForValueRowIter(
895905
var resChan = make(chan bufferedResult, 4)
896906
var res *sqltypes.Result
897907
var bm *sql.ByteBufferManager
908+
maxCaps := make([]int, len(schema))
898909
eg.Go(func() (err error) {
899910
defer pan2err(&err)
900911
defer close(resChan)
@@ -926,7 +937,7 @@ func (h *Handler) resultForValueRowIter(
926937
}
927938

928939
var outRow []sqltypes.Value
929-
outRow, err = RowValueToSQLValues(ctx, schema, row, bm)
940+
outRow, err = RowValueToSQLValues(ctx, schema, row, maxCaps, bm)
930941
if err != nil {
931942
return err
932943
}
@@ -1286,8 +1297,7 @@ func getMaxTypeCapacity(ctx *sql.Context, typ sql.Type) (res int) {
12861297
return
12871298
}
12881299

1289-
func toSQL(ctx *sql.Context, typ sql.Type, bm *sql.ByteBufferManager, val any) (sqltypes.Value, error) {
1290-
maxCap := getMaxTypeCapacity(ctx, typ)
1300+
func toSQL(ctx *sql.Context, typ sql.Type, maxCap int, bm *sql.ByteBufferManager, val any) (sqltypes.Value, error) {
12911301
if maxCap == 0 {
12921302
return typ.SQL(ctx, nil, val)
12931303
}
@@ -1299,7 +1309,7 @@ func toSQL(ctx *sql.Context, typ sql.Type, bm *sql.ByteBufferManager, val any) (
12991309
return ret, nil
13001310
}
13011311

1302-
func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression, bm *sql.ByteBufferManager) ([]sqltypes.Value, error) {
1312+
func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Expression, maxCaps []int, bm *sql.ByteBufferManager) ([]sqltypes.Value, error) {
13031313
// need to make sure the schema is not null as some plan schema is defined as null (e.g. IfElseBlock) // TODO: do we really?
13041314
if len(sch) == 0 {
13051315
return []sqltypes.Value{}, nil
@@ -1313,7 +1323,7 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express
13131323
outVals[i] = sqltypes.NULL
13141324
continue
13151325
}
1316-
outVals[i], err = toSQL(ctx, col.Type, bm, row[i])
1326+
outVals[i], err = toSQL(ctx, col.Type, maxCaps[i], bm, row[i])
13171327
if err != nil {
13181328
return nil, err
13191329
}
@@ -1331,20 +1341,15 @@ func RowToSQL(ctx *sql.Context, sch sql.Schema, row sql.Row, projs []sql.Express
13311341
outVals[i] = sqltypes.NULL
13321342
continue
13331343
}
1334-
outVals[i], err = toSQL(ctx, col.Type, bm, field)
1344+
outVals[i], err = toSQL(ctx, col.Type, maxCaps[i], bm, field)
13351345
if err != nil {
13361346
return nil, err
13371347
}
13381348
}
13391349
return outVals, nil
13401350
}
13411351

1342-
func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, bm *sql.ByteBufferManager) ([]sqltypes.Value, error) {
1343-
// TODO: we check for empty schema in doQuery, shouldn't need to check here
1344-
if len(sch) == 0 {
1345-
return []sqltypes.Value{}, nil
1346-
}
1347-
1352+
func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, maxCaps []int, bm *sql.ByteBufferManager) ([]sqltypes.Value, error) {
13481353
var err error
13491354
outVals := make([]sqltypes.Value, len(sch))
13501355
for i, col := range sch {
@@ -1359,17 +1364,15 @@ func RowValueToSQLValues(ctx *sql.Context, sch sql.Schema, row sql.ValueRow, bm
13591364
continue
13601365
}
13611366

1362-
// TODO: schema remains constant throughout this query, so no need to recalc this every time
1363-
maxCap := getMaxTypeCapacity(ctx, valType)
1364-
if maxCap == 0 {
1367+
if maxCaps[i] == 0 {
13651368
outVals[i], err = valType.SQLValue(ctx, row[i], nil)
13661369
if err != nil {
13671370
return nil, err
13681371
}
13691372
continue
13701373
}
13711374

1372-
outVals[i], err = valType.SQLValue(ctx, row[i], bm.Get(maxCap))
1375+
outVals[i], err = valType.SQLValue(ctx, row[i], bm.Get(maxCaps[i]))
13731376
if err != nil {
13741377
return nil, err
13751378
}

0 commit comments

Comments
 (0)