Skip to content

Commit ec65ed9

Browse files
authored
fix nil binding parameter formatcode (#1729)
1 parent d439afb commit ec65ed9

File tree

6 files changed

+231
-40
lines changed

6 files changed

+231
-40
lines changed

server/connection_data.go

Lines changed: 38 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,10 @@
1515
package server
1616

1717
import (
18-
"github.com/cockroachdb/errors"
18+
"strconv"
19+
"strings"
1920

21+
"github.com/cockroachdb/errors"
2022
"github.com/dolthub/go-mysql-server/sql"
2123
"github.com/dolthub/go-mysql-server/sql/expression"
2224
"github.com/dolthub/go-mysql-server/sql/plan"
@@ -111,7 +113,7 @@ func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) {
111113
inspectNode = queryPlan.Source
112114
}
113115

114-
types := make([]uint32, 0)
116+
types := make(map[string]uint32)
115117
var err error
116118
var extractBindVars func(n sql.Node, expr sql.Expression) bool
117119
extractBindVars = func(n sql.Node, expr sql.Expression) bool {
@@ -137,19 +139,23 @@ func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) {
137139
default:
138140
typOid, err = VitessTypeToObjectID(e.Type().Type())
139141
if err != nil {
140-
err = errors.Errorf("could not determine OID for placeholder %s: %w", e.Name, err)
142+
err = errors.Errorf("could not determine OID for placeholder %s: %e", e.Name, err)
141143
return false
142144
}
143145
}
144146
} else {
145147
// TODO: should remove usage non doltgres type
146148
typOid, err = VitessTypeToObjectID(e.Type().Type())
147149
if err != nil {
148-
err = errors.Errorf("could not determine OID for placeholder %s: %w", e.Name, err)
150+
err = errors.Errorf("could not determine OID for placeholder %s: %e", e.Name, err)
149151
return false
150152
}
151153
}
152-
types = append(types, typOid)
154+
if _, ok := types[e.Name]; ok {
155+
// sanity check
156+
err = errors.Errorf("double placeholder given for %s", e.Name)
157+
}
158+
types[e.Name] = typOid
153159
case *pgexprs.ExplicitCast:
154160
if bindVar, ok := e.Child().(*expression.BindVar); ok {
155161
var typOid uint32
@@ -158,11 +164,15 @@ func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) {
158164
} else {
159165
typOid, err = VitessTypeToObjectID(e.Type().Type())
160166
if err != nil {
161-
err = errors.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err)
167+
err = errors.Errorf("could not determine OID for placeholder %s: %e", bindVar.Name, err)
162168
return false
163169
}
164170
}
165-
types = append(types, typOid)
171+
if _, ok = types[bindVar.Name]; ok {
172+
// sanity check
173+
err = errors.Errorf("double placeholder given for %s", bindVar.Name)
174+
}
175+
types[bindVar.Name] = typOid
166176
return false
167177
}
168178
// $1::text and similar get converted to a Convert expression wrapping the bindvar
@@ -171,18 +181,36 @@ func extractBindVarTypes(queryPlan sql.Node) ([]uint32, error) {
171181
var typOid uint32
172182
typOid, err = VitessTypeToObjectID(e.Type().Type())
173183
if err != nil {
174-
err = errors.Errorf("could not determine OID for placeholder %s: %w", bindVar.Name, err)
184+
err = errors.Errorf("could not determine OID for placeholder %s: %e", bindVar.Name, err)
175185
return false
176186
}
177-
types = append(types, typOid)
187+
if _, ok = types[bindVar.Name]; ok {
188+
// sanity check
189+
err = errors.Errorf("double placeholder given for %s", bindVar.Name)
190+
}
191+
types[bindVar.Name] = typOid
178192
return false
179193
}
180194
}
181195
return true
182196
}
183197

184198
transform.InspectExpressionsWithNode(inspectNode, extractBindVars)
185-
return types, err
199+
200+
// above finds types of bindvars in unordered form.
201+
// the list of types needs to be ordered as v1, v2, v3, etc.
202+
typesArr := make([]uint32, len(types))
203+
for i, t := range types {
204+
idx, err := strconv.ParseInt(strings.TrimPrefix(i, "v"), 10, 32)
205+
if err != nil {
206+
return nil, errors.Errorf("could not determine the index of placeholder %s: %e", i, err)
207+
}
208+
if int(idx-1) >= len(types) {
209+
return nil, errors.Errorf("could not determine the index of placeholder %s: %e", i, err)
210+
}
211+
typesArr[idx-1] = t
212+
}
213+
return typesArr, err
186214
}
187215

188216
// VitessTypeToObjectID returns a type, as defined by Vitess, into a type as defined by Postgres.

server/doltgres_handler.go

Lines changed: 93 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,15 @@ package server
1717
import (
1818
"context"
1919
"encoding/base64"
20+
"encoding/binary"
21+
"encoding/hex"
2022
goerrors "errors"
2123
"fmt"
2224
"io"
2325
"os"
2426
"regexp"
2527
"runtime/trace"
28+
"strconv"
2629
"sync"
2730
"time"
2831

@@ -40,6 +43,7 @@ import (
4043
"github.com/sirupsen/logrus"
4144

4245
"github.com/dolthub/doltgresql/core/id"
46+
"github.com/dolthub/doltgresql/postgres/parser/uuid"
4347
pgexprs "github.com/dolthub/doltgresql/server/expression"
4448
pgtypes "github.com/dolthub/doltgresql/server/types"
4549
)
@@ -119,7 +123,7 @@ func (h *DoltgresHandler) ComBind(ctx context.Context, c *mysql.Conn, query stri
119123
return nil, nil, err
120124
}
121125

122-
return queryPlan, schemaToFieldDescriptions(sqlCtx, queryPlan.Schema()), nil
126+
return queryPlan, schemaToFieldDescriptions(sqlCtx, queryPlan.Schema(), true), nil
123127
}
124128

125129
// ComExecuteBound implements the Handler interface.
@@ -135,7 +139,7 @@ func (h *DoltgresHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn,
135139
h.sel.QueryStarted()
136140
}
137141

138-
err := h.doQuery(ctx, conn, query, nil, analyzedPlan, h.executeBoundPlan, callback)
142+
err := h.doQuery(ctx, conn, query, nil, analyzedPlan, h.executeBoundPlan, callback, true)
139143
if err != nil {
140144
err = sql.CastSQLError(err)
141145
}
@@ -175,7 +179,7 @@ func (h *DoltgresHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, q
175179
},
176180
}
177181
} else {
178-
fields = schemaToFieldDescriptions(sqlCtx, analyzed.Schema())
182+
fields = schemaToFieldDescriptions(sqlCtx, analyzed.Schema(), true)
179183
}
180184
return analyzed, fields, nil
181185
}
@@ -188,7 +192,7 @@ func (h *DoltgresHandler) ComQuery(ctx context.Context, c *mysql.Conn, query str
188192
h.sel.QueryStarted()
189193
}
190194

191-
err := h.doQuery(ctx, c, query, parsed, nil, h.executeQuery, callback)
195+
err := h.doQuery(ctx, c, query, parsed, nil, h.executeQuery, callback, false)
192196
if err != nil {
193197
err = sql.CastSQLError(err)
194198
}
@@ -259,7 +263,11 @@ func (h *DoltgresHandler) NewContext(ctx context.Context, c *mysql.Conn, query s
259263
func (h *DoltgresHandler) convertBindParameters(ctx *sql.Context, types []uint32, formatCodes []int16, values [][]byte) (map[string]sqlparser.Expr, error) {
260264
bindings := make(map[string]sqlparser.Expr, len(values))
261265
for i := range values {
262-
bindVarString, err := h.convertBindParameterToString(types[i], values[i], formatCodes[i])
266+
formatCode := int16(0)
267+
if formatCodes != nil {
268+
formatCode = formatCodes[i]
269+
}
270+
bindVarString, err := h.convertBindParameterToString(types[i], values[i], formatCode)
263271
if err != nil {
264272
return nil, err
265273
}
@@ -312,6 +320,20 @@ func (h *DoltgresHandler) convertBindParameterToString(typ uint32, value []byte,
312320
} else {
313321
bindVarString = "false"
314322
}
323+
case typ == pgtype.ByteaOID && isBinaryFormat:
324+
bindVarString = `\x` + hex.EncodeToString(value)
325+
case typ == pgtype.Int2OID && isBinaryFormat:
326+
bindVarString = strconv.FormatInt(int64(binary.BigEndian.Uint16(value)), 10)
327+
case typ == pgtype.Int4OID && isBinaryFormat:
328+
bindVarString = strconv.FormatInt(int64(binary.BigEndian.Uint32(value)), 10)
329+
case typ == pgtype.Int8OID && isBinaryFormat:
330+
bindVarString = strconv.FormatInt(int64(binary.BigEndian.Uint64(value)), 10)
331+
case typ == pgtype.UUIDOID && isBinaryFormat:
332+
u, err := uuid.FromBytes(value)
333+
if err != nil {
334+
return "", err
335+
}
336+
bindVarString = u.String()
315337
default:
316338
// For text format or types that can handle binary-to-string conversion
317339
if err := h.pgTypeMap.Scan(typ, formatCode, value, &bindVarString); err != nil {
@@ -324,7 +346,7 @@ func (h *DoltgresHandler) convertBindParameterToString(typ uint32, value []byte,
324346

325347
var queryLoggingRegex = regexp.MustCompile(`[\r\n\t ]+`)
326348

327-
func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*sql.Context, *Result) error) error {
349+
func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query string, parsed sqlparser.Statement, analyzedPlan sql.Node, queryExec QueryExecutor, callback func(*sql.Context, *Result) error, isExecute bool) error {
328350
sqlCtx, err := h.sm.NewContextWithQuery(ctx, c, query)
329351
if err != nil {
330352
return err
@@ -374,11 +396,11 @@ func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query stri
374396
} else if schema == nil {
375397
r, err = resultForEmptyIter(sqlCtx, rowIter)
376398
} else if analyzer.FlagIsSet(qFlags, sql.QFlagMax1Row) {
377-
resultFields := schemaToFieldDescriptions(sqlCtx, schema)
378-
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields)
399+
resultFields := schemaToFieldDescriptions(sqlCtx, schema, isExecute)
400+
r, err = resultForMax1RowIter(sqlCtx, schema, rowIter, resultFields, isExecute)
379401
} else {
380-
resultFields := schemaToFieldDescriptions(sqlCtx, schema)
381-
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields)
402+
resultFields := schemaToFieldDescriptions(sqlCtx, schema, isExecute)
403+
r, processedAtLeastOneBatch, err = h.resultForDefaultIter(sqlCtx, schema, rowIter, callback, resultFields, isExecute)
382404
}
383405
if err != nil {
384406
return err
@@ -432,18 +454,29 @@ func (h *DoltgresHandler) maybeReleaseAllLocks(c *mysql.Conn) {
432454
// These nodes will eventually return an OK result, but their intermediate forms here return a different schema
433455
// than they will at execution time.
434456
func nodeReturnsOkResultSchema(node sql.Node) bool {
435-
switch node.(type) {
436-
case *plan.InsertInto, *plan.Update, *plan.UpdateJoin, *plan.DeleteFrom:
457+
switch n := node.(type) {
458+
case *plan.InsertInto:
459+
return len(n.Returning) == 0
460+
case *plan.Update:
461+
return len(n.Returning) == 0
462+
case *plan.DeleteFrom, *plan.UpdateJoin:
437463
return true
438464
}
439465
return types.IsOkResultSchema(node.Schema())
440466
}
441467

442-
func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldDescription {
468+
func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema, isPrepared bool) []pgproto3.FieldDescription {
443469
fields := make([]pgproto3.FieldDescription, len(s))
444470
for i, c := range s {
445471
var oid uint32
446472
var typmod = int32(-1)
473+
474+
// "Format" field: The format code being used for the field.
475+
// Currently, will be zero (text) or one (binary).
476+
// In a RowDescription returned from the statement variant of Describe,
477+
// the format code is not yet known and will always be zero.
478+
var formatCode = int16(0)
479+
447480
var err error
448481
if doltgresType, ok := c.Type.(*pgtypes.DoltgresType); ok {
449482
if doltgresType.TypType == pgtypes.TypeType_Domain {
@@ -452,26 +485,27 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD
452485
oid = id.Cache().ToOID(doltgresType.ID.AsId())
453486
}
454487
typmod = doltgresType.GetAttTypMod() // pg_attribute.atttypmod
488+
if isPrepared {
489+
switch doltgresType.ID {
490+
case pgtypes.Bytea.ID, pgtypes.Int16.ID, pgtypes.Int32.ID, pgtypes.Int64.ID, pgtypes.Uuid.ID:
491+
formatCode = 1
492+
}
493+
}
455494
} else {
456495
oid, err = VitessTypeToObjectID(c.Type.Type())
457496
if err != nil {
458497
panic(err)
459498
}
460499
}
461500

462-
// "Format" field: The format code being used for the field.
463-
// Currently, will be zero (text) or one (binary).
464-
// In a RowDescription returned from the statement variant of Describe,
465-
// the format code is not yet known and will always be zero.
466-
467501
fields[i] = pgproto3.FieldDescription{
468502
Name: []byte(c.Name),
469503
TableOID: uint32(0),
470504
TableAttributeNumber: uint16(0),
471505
DataTypeOID: oid,
472506
DataTypeSize: int16(c.Type.MaxTextResponseByteLength(ctx)),
473507
TypeModifier: typmod,
474-
Format: int16(0),
508+
Format: formatCode,
475509
}
476510
}
477511

@@ -515,7 +549,7 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) {
515549
}
516550

517551
// resultForMax1RowIter ensures that an empty iterator returns at most one row
518-
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []pgproto3.FieldDescription) (*Result, error) {
552+
func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, resultFields []pgproto3.FieldDescription, isExecute bool) (*Result, error) {
519553
defer trace.StartRegion(ctx, "DoltgresHandler.resultForMax1RowIter").End()
520554
row, err := iter.Next(ctx)
521555
if err == io.EOF {
@@ -531,7 +565,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
531565
return nil, err
532566
}
533567

534-
outputRow, err := rowToBytes(ctx, schema, row)
568+
outputRow, err := rowToBytes(ctx, schema, row, isExecute)
535569
if err != nil {
536570
return nil, err
537571
}
@@ -543,7 +577,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
543577

544578
// resultForDefaultIter reads batches of rows from the iterator
545579
// and writes results into the callback function.
546-
func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*sql.Context, *Result) error, resultFields []pgproto3.FieldDescription) (*Result, bool, error) {
580+
func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter, callback func(*sql.Context, *Result) error, resultFields []pgproto3.FieldDescription, isExecute bool) (*Result, bool, error) {
547581
defer trace.StartRegion(ctx, "DoltgresHandler.resultForDefaultIter").End()
548582

549583
var r *Result
@@ -636,7 +670,7 @@ func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Sche
636670
continue
637671
}
638672

639-
outputRow, err := rowToBytes(ctx, schema, row)
673+
outputRow, err := rowToBytes(ctx, schema, row, isExecute)
640674
if err != nil {
641675
return err
642676
}
@@ -678,7 +712,7 @@ func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Sche
678712
return r, processedAtLeastOneBatch, nil
679713
}
680714

681-
func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) {
715+
func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row, isExecute bool) ([][]byte, error) {
682716
if len(row) == 0 {
683717
return nil, nil
684718
}
@@ -691,6 +725,41 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) {
691725
if v == nil {
692726
o[i] = nil
693727
} else {
728+
if isExecute {
729+
switch d := s[i].Type.(type) {
730+
case *pgtypes.DoltgresType:
731+
switch d.ID {
732+
// This is the list of types to use binary mode for when receiving them
733+
// through a prepared statement. If a type appears in this list, it
734+
// must also be implemented in binaryDecode in encode.go.
735+
case pgtypes.Bytea.ID:
736+
o[i] = v.([]byte)
737+
continue
738+
case pgtypes.Int64.ID:
739+
buf := make([]byte, 8)
740+
binary.BigEndian.PutUint64(buf, uint64(v.(int64)))
741+
o[i] = buf
742+
continue
743+
case pgtypes.Int32.ID:
744+
buf := make([]byte, 4)
745+
binary.BigEndian.PutUint32(buf, uint32(v.(int32)))
746+
o[i] = buf
747+
continue
748+
case pgtypes.Int16.ID:
749+
buf := make([]byte, 2)
750+
binary.BigEndian.PutUint16(buf, uint16(v.(int16)))
751+
o[i] = buf
752+
continue
753+
case pgtypes.Uuid.ID:
754+
buf, err := v.(uuid.UUID).MarshalBinary()
755+
if err != nil {
756+
return nil, err
757+
}
758+
o[i] = buf
759+
continue
760+
}
761+
}
762+
}
694763
val, err := s[i].Type.SQL(ctx, []byte{}, v)
695764
if err != nil {
696765
return nil, err

0 commit comments

Comments
 (0)