@@ -17,12 +17,15 @@ package server
1717import (
1818 "context"
1919 "encoding/base64"
20+ "encoding/binary"
21+ "encoding/hex"
2022 goerrors "errors"
2123 "fmt"
2224 "io"
2325 "os"
2426 "regexp"
2527 "runtime/trace"
28+ "strconv"
2629 "sync"
2730 "time"
2831
@@ -40,6 +43,7 @@ import (
4043 "github.com/sirupsen/logrus"
4144
4245 "github.com/dolthub/doltgresql/core/id"
46+ "github.com/dolthub/doltgresql/postgres/parser/uuid"
4347 pgexprs "github.com/dolthub/doltgresql/server/expression"
4448 pgtypes "github.com/dolthub/doltgresql/server/types"
4549)
@@ -119,7 +123,7 @@ func (h *DoltgresHandler) ComBind(ctx context.Context, c *mysql.Conn, query stri
119123 return nil , nil , err
120124 }
121125
122- return queryPlan , schemaToFieldDescriptions (sqlCtx , queryPlan .Schema ()), nil
126+ return queryPlan , schemaToFieldDescriptions (sqlCtx , queryPlan .Schema (), true ), nil
123127}
124128
125129// ComExecuteBound implements the Handler interface.
@@ -135,7 +139,7 @@ func (h *DoltgresHandler) ComExecuteBound(ctx context.Context, conn *mysql.Conn,
135139 h .sel .QueryStarted ()
136140 }
137141
138- err := h .doQuery (ctx , conn , query , nil , analyzedPlan , h .executeBoundPlan , callback )
142+ err := h .doQuery (ctx , conn , query , nil , analyzedPlan , h .executeBoundPlan , callback , true )
139143 if err != nil {
140144 err = sql .CastSQLError (err )
141145 }
@@ -175,7 +179,7 @@ func (h *DoltgresHandler) ComPrepareParsed(ctx context.Context, c *mysql.Conn, q
175179 },
176180 }
177181 } else {
178- fields = schemaToFieldDescriptions (sqlCtx , analyzed .Schema ())
182+ fields = schemaToFieldDescriptions (sqlCtx , analyzed .Schema (), true )
179183 }
180184 return analyzed , fields , nil
181185}
@@ -188,7 +192,7 @@ func (h *DoltgresHandler) ComQuery(ctx context.Context, c *mysql.Conn, query str
188192 h .sel .QueryStarted ()
189193 }
190194
191- err := h .doQuery (ctx , c , query , parsed , nil , h .executeQuery , callback )
195+ err := h .doQuery (ctx , c , query , parsed , nil , h .executeQuery , callback , false )
192196 if err != nil {
193197 err = sql .CastSQLError (err )
194198 }
@@ -259,7 +263,11 @@ func (h *DoltgresHandler) NewContext(ctx context.Context, c *mysql.Conn, query s
259263func (h * DoltgresHandler ) convertBindParameters (ctx * sql.Context , types []uint32 , formatCodes []int16 , values [][]byte ) (map [string ]sqlparser.Expr , error ) {
260264 bindings := make (map [string ]sqlparser.Expr , len (values ))
261265 for i := range values {
262- bindVarString , err := h .convertBindParameterToString (types [i ], values [i ], formatCodes [i ])
266+ formatCode := int16 (0 )
267+ if formatCodes != nil {
268+ formatCode = formatCodes [i ]
269+ }
270+ bindVarString , err := h .convertBindParameterToString (types [i ], values [i ], formatCode )
263271 if err != nil {
264272 return nil , err
265273 }
@@ -312,6 +320,20 @@ func (h *DoltgresHandler) convertBindParameterToString(typ uint32, value []byte,
312320 } else {
313321 bindVarString = "false"
314322 }
323+ case typ == pgtype .ByteaOID && isBinaryFormat :
324+ bindVarString = `\x` + hex .EncodeToString (value )
325+ case typ == pgtype .Int2OID && isBinaryFormat :
326+ bindVarString = strconv .FormatInt (int64 (binary .BigEndian .Uint16 (value )), 10 )
327+ case typ == pgtype .Int4OID && isBinaryFormat :
328+ bindVarString = strconv .FormatInt (int64 (binary .BigEndian .Uint32 (value )), 10 )
329+ case typ == pgtype .Int8OID && isBinaryFormat :
330+ bindVarString = strconv .FormatInt (int64 (binary .BigEndian .Uint64 (value )), 10 )
331+ case typ == pgtype .UUIDOID && isBinaryFormat :
332+ u , err := uuid .FromBytes (value )
333+ if err != nil {
334+ return "" , err
335+ }
336+ bindVarString = u .String ()
315337 default :
316338 // For text format or types that can handle binary-to-string conversion
317339 if err := h .pgTypeMap .Scan (typ , formatCode , value , & bindVarString ); err != nil {
@@ -324,7 +346,7 @@ func (h *DoltgresHandler) convertBindParameterToString(typ uint32, value []byte,
324346
325347var queryLoggingRegex = regexp .MustCompile (`[\r\n\t ]+` )
326348
327- func (h * DoltgresHandler ) doQuery (ctx context.Context , c * mysql.Conn , query string , parsed sqlparser.Statement , analyzedPlan sql.Node , queryExec QueryExecutor , callback func (* sql.Context , * Result ) error ) error {
349+ func (h * DoltgresHandler ) doQuery (ctx context.Context , c * mysql.Conn , query string , parsed sqlparser.Statement , analyzedPlan sql.Node , queryExec QueryExecutor , callback func (* sql.Context , * Result ) error , isExecute bool ) error {
328350 sqlCtx , err := h .sm .NewContextWithQuery (ctx , c , query )
329351 if err != nil {
330352 return err
@@ -374,11 +396,11 @@ func (h *DoltgresHandler) doQuery(ctx context.Context, c *mysql.Conn, query stri
374396 } else if schema == nil {
375397 r , err = resultForEmptyIter (sqlCtx , rowIter )
376398 } else if analyzer .FlagIsSet (qFlags , sql .QFlagMax1Row ) {
377- resultFields := schemaToFieldDescriptions (sqlCtx , schema )
378- r , err = resultForMax1RowIter (sqlCtx , schema , rowIter , resultFields )
399+ resultFields := schemaToFieldDescriptions (sqlCtx , schema , isExecute )
400+ r , err = resultForMax1RowIter (sqlCtx , schema , rowIter , resultFields , isExecute )
379401 } else {
380- resultFields := schemaToFieldDescriptions (sqlCtx , schema )
381- r , processedAtLeastOneBatch , err = h .resultForDefaultIter (sqlCtx , schema , rowIter , callback , resultFields )
402+ resultFields := schemaToFieldDescriptions (sqlCtx , schema , isExecute )
403+ r , processedAtLeastOneBatch , err = h .resultForDefaultIter (sqlCtx , schema , rowIter , callback , resultFields , isExecute )
382404 }
383405 if err != nil {
384406 return err
@@ -432,18 +454,29 @@ func (h *DoltgresHandler) maybeReleaseAllLocks(c *mysql.Conn) {
432454// These nodes will eventually return an OK result, but their intermediate forms here return a different schema
433455// than they will at execution time.
434456func nodeReturnsOkResultSchema (node sql.Node ) bool {
435- switch node .(type ) {
436- case * plan.InsertInto , * plan.Update , * plan.UpdateJoin , * plan.DeleteFrom :
457+ switch n := node .(type ) {
458+ case * plan.InsertInto :
459+ return len (n .Returning ) == 0
460+ case * plan.Update :
461+ return len (n .Returning ) == 0
462+ case * plan.DeleteFrom , * plan.UpdateJoin :
437463 return true
438464 }
439465 return types .IsOkResultSchema (node .Schema ())
440466}
441467
442- func schemaToFieldDescriptions (ctx * sql.Context , s sql.Schema ) []pgproto3.FieldDescription {
468+ func schemaToFieldDescriptions (ctx * sql.Context , s sql.Schema , isPrepared bool ) []pgproto3.FieldDescription {
443469 fields := make ([]pgproto3.FieldDescription , len (s ))
444470 for i , c := range s {
445471 var oid uint32
446472 var typmod = int32 (- 1 )
473+
474+ // "Format" field: The format code being used for the field.
475+ // Currently, will be zero (text) or one (binary).
476+ // In a RowDescription returned from the statement variant of Describe,
477+ // the format code is not yet known and will always be zero.
478+ var formatCode = int16 (0 )
479+
447480 var err error
448481 if doltgresType , ok := c .Type .(* pgtypes.DoltgresType ); ok {
449482 if doltgresType .TypType == pgtypes .TypeType_Domain {
@@ -452,26 +485,27 @@ func schemaToFieldDescriptions(ctx *sql.Context, s sql.Schema) []pgproto3.FieldD
452485 oid = id .Cache ().ToOID (doltgresType .ID .AsId ())
453486 }
454487 typmod = doltgresType .GetAttTypMod () // pg_attribute.atttypmod
488+ if isPrepared {
489+ switch doltgresType .ID {
490+ case pgtypes .Bytea .ID , pgtypes .Int16 .ID , pgtypes .Int32 .ID , pgtypes .Int64 .ID , pgtypes .Uuid .ID :
491+ formatCode = 1
492+ }
493+ }
455494 } else {
456495 oid , err = VitessTypeToObjectID (c .Type .Type ())
457496 if err != nil {
458497 panic (err )
459498 }
460499 }
461500
462- // "Format" field: The format code being used for the field.
463- // Currently, will be zero (text) or one (binary).
464- // In a RowDescription returned from the statement variant of Describe,
465- // the format code is not yet known and will always be zero.
466-
467501 fields [i ] = pgproto3.FieldDescription {
468502 Name : []byte (c .Name ),
469503 TableOID : uint32 (0 ),
470504 TableAttributeNumber : uint16 (0 ),
471505 DataTypeOID : oid ,
472506 DataTypeSize : int16 (c .Type .MaxTextResponseByteLength (ctx )),
473507 TypeModifier : typmod ,
474- Format : int16 ( 0 ) ,
508+ Format : formatCode ,
475509 }
476510 }
477511
@@ -515,7 +549,7 @@ func resultForEmptyIter(ctx *sql.Context, iter sql.RowIter) (*Result, error) {
515549}
516550
517551// resultForMax1RowIter ensures that an empty iterator returns at most one row
518- func resultForMax1RowIter (ctx * sql.Context , schema sql.Schema , iter sql.RowIter , resultFields []pgproto3.FieldDescription ) (* Result , error ) {
552+ func resultForMax1RowIter (ctx * sql.Context , schema sql.Schema , iter sql.RowIter , resultFields []pgproto3.FieldDescription , isExecute bool ) (* Result , error ) {
519553 defer trace .StartRegion (ctx , "DoltgresHandler.resultForMax1RowIter" ).End ()
520554 row , err := iter .Next (ctx )
521555 if err == io .EOF {
@@ -531,7 +565,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
531565 return nil , err
532566 }
533567
534- outputRow , err := rowToBytes (ctx , schema , row )
568+ outputRow , err := rowToBytes (ctx , schema , row , isExecute )
535569 if err != nil {
536570 return nil , err
537571 }
@@ -543,7 +577,7 @@ func resultForMax1RowIter(ctx *sql.Context, schema sql.Schema, iter sql.RowIter,
543577
544578// resultForDefaultIter reads batches of rows from the iterator
545579// and writes results into the callback function.
546- func (h * DoltgresHandler ) resultForDefaultIter (ctx * sql.Context , schema sql.Schema , iter sql.RowIter , callback func (* sql.Context , * Result ) error , resultFields []pgproto3.FieldDescription ) (* Result , bool , error ) {
580+ func (h * DoltgresHandler ) resultForDefaultIter (ctx * sql.Context , schema sql.Schema , iter sql.RowIter , callback func (* sql.Context , * Result ) error , resultFields []pgproto3.FieldDescription , isExecute bool ) (* Result , bool , error ) {
547581 defer trace .StartRegion (ctx , "DoltgresHandler.resultForDefaultIter" ).End ()
548582
549583 var r * Result
@@ -636,7 +670,7 @@ func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Sche
636670 continue
637671 }
638672
639- outputRow , err := rowToBytes (ctx , schema , row )
673+ outputRow , err := rowToBytes (ctx , schema , row , isExecute )
640674 if err != nil {
641675 return err
642676 }
@@ -678,7 +712,7 @@ func (h *DoltgresHandler) resultForDefaultIter(ctx *sql.Context, schema sql.Sche
678712 return r , processedAtLeastOneBatch , nil
679713}
680714
681- func rowToBytes (ctx * sql.Context , s sql.Schema , row sql.Row ) ([][]byte , error ) {
715+ func rowToBytes (ctx * sql.Context , s sql.Schema , row sql.Row , isExecute bool ) ([][]byte , error ) {
682716 if len (row ) == 0 {
683717 return nil , nil
684718 }
@@ -691,6 +725,41 @@ func rowToBytes(ctx *sql.Context, s sql.Schema, row sql.Row) ([][]byte, error) {
691725 if v == nil {
692726 o [i ] = nil
693727 } else {
728+ if isExecute {
729+ switch d := s [i ].Type .(type ) {
730+ case * pgtypes.DoltgresType :
731+ switch d .ID {
732+ // This is the list of types to use binary mode for when receiving them
733+ // through a prepared statement. If a type appears in this list, it
734+ // must also be implemented in binaryDecode in encode.go.
735+ case pgtypes .Bytea .ID :
736+ o [i ] = v .([]byte )
737+ continue
738+ case pgtypes .Int64 .ID :
739+ buf := make ([]byte , 8 )
740+ binary .BigEndian .PutUint64 (buf , uint64 (v .(int64 )))
741+ o [i ] = buf
742+ continue
743+ case pgtypes .Int32 .ID :
744+ buf := make ([]byte , 4 )
745+ binary .BigEndian .PutUint32 (buf , uint32 (v .(int32 )))
746+ o [i ] = buf
747+ continue
748+ case pgtypes .Int16 .ID :
749+ buf := make ([]byte , 2 )
750+ binary .BigEndian .PutUint16 (buf , uint16 (v .(int16 )))
751+ o [i ] = buf
752+ continue
753+ case pgtypes .Uuid .ID :
754+ buf , err := v .(uuid.UUID ).MarshalBinary ()
755+ if err != nil {
756+ return nil , err
757+ }
758+ o [i ] = buf
759+ continue
760+ }
761+ }
762+ }
694763 val , err := s [i ].Type .SQL (ctx , []byte {}, v )
695764 if err != nil {
696765 return nil , err
0 commit comments