@@ -6,6 +6,7 @@ package net
66import (
77 "bytes"
88 "encoding/binary"
9+ "encoding/json"
910 "math"
1011 "net"
1112
@@ -583,14 +584,44 @@ func MakeQueryPacket(stmt string) []byte {
583584 return request
584585}
585586
586- func MakePrepareStmtPacket (stmt string ) []byte {
587+ func MakePrepareStmtRequest (stmt string ) []byte {
587588 request := make ([]byte , len (stmt )+ 1 )
588589 request [0 ] = ComStmtPrepare .Byte ()
589590 copy (request [1 :], hack .Slice (stmt ))
590591 return request
591592}
592593
593- func MakeExecuteStmtPacket (stmtID uint32 , args []any ) ([]byte , error ) {
594+ // MakePrepareStmtResp creates a prepared statement response.
595+ // The packet is incomplete and it's only used for testing.
596+ func MakePrepareStmtResp (stmtID uint32 , paramNum int ) []byte {
597+ // header
598+ response := make ([]byte , 1 + 4 + 2 + 2 )
599+ pos := 0
600+ response [pos ] = ComStmtPrepare .Byte ()
601+ pos += 1
602+ binary .LittleEndian .PutUint32 (response [pos :], stmtID )
603+ pos += 4
604+ // column count
605+ pos += 2
606+ // param count
607+ binary .LittleEndian .PutUint16 (response [pos :], uint16 (paramNum ))
608+ // ignore rest part
609+ return response
610+ }
611+
612+ func ParsePrepareStmtResp (resp []byte ) (stmtID uint32 , paramNum int ) {
613+ // header
614+ pos := 1
615+ stmtID = binary .LittleEndian .Uint32 (resp [pos :])
616+ pos += 4
617+ // column count
618+ pos += 2
619+ paramNum = int (binary .LittleEndian .Uint16 (resp [pos :]))
620+ // ignore rest part
621+ return
622+ }
623+
624+ func MakeExecuteStmtRequest (stmtID uint32 , args []any ) ([]byte , error ) {
594625 paramNum := len (args )
595626 paramTypes := make ([]byte , paramNum * 2 )
596627 paramValues := make ([][]byte , paramNum )
@@ -607,21 +638,64 @@ func MakeExecuteStmtPacket(stmtID uint32, args []any) ([]byte, error) {
607638
608639 newParamBoundFlag = 1
609640 switch v := args [i ].(type ) {
641+ case int8 :
642+ paramTypes [i << 1 ] = fieldTypeTiny
643+ paramValues [i ] = []byte {byte (v )}
644+ case int16 :
645+ paramTypes [i << 1 ] = fieldTypeShort
646+ paramValues [i ] = Uint16ToBytes (uint16 (v ))
647+ case int32 :
648+ paramTypes [i << 1 ] = fieldTypeLong
649+ paramValues [i ] = Uint32ToBytes (uint32 (v ))
650+ case int :
651+ paramTypes [i << 1 ] = fieldTypeLongLong
652+ paramValues [i ] = Uint64ToBytes (uint64 (v ))
610653 case int64 :
611654 paramTypes [i << 1 ] = fieldTypeLongLong
612655 paramValues [i ] = Uint64ToBytes (uint64 (v ))
656+ case uint8 :
657+ paramTypes [i << 1 ] = fieldTypeTiny
658+ paramTypes [(i << 1 )+ 1 ] = 0x80
659+ paramValues [i ] = []byte {v }
660+ case uint16 :
661+ paramTypes [i << 1 ] = fieldTypeShort
662+ paramTypes [(i << 1 )+ 1 ] = 0x80
663+ paramValues [i ] = Uint16ToBytes (v )
664+ case uint32 :
665+ paramTypes [i << 1 ] = fieldTypeLong
666+ paramTypes [(i << 1 )+ 1 ] = 0x80
667+ paramValues [i ] = Uint32ToBytes (v )
668+ case uint :
669+ paramTypes [i << 1 ] = fieldTypeLongLong
670+ paramTypes [(i << 1 )+ 1 ] = 0x80
671+ paramValues [i ] = Uint64ToBytes (uint64 (v ))
613672 case uint64 :
614673 paramTypes [i << 1 ] = fieldTypeLongLong
615674 paramTypes [(i << 1 )+ 1 ] = 0x80
616675 paramValues [i ] = Uint64ToBytes (v )
676+ case bool :
677+ paramTypes [i << 1 ] = fieldTypeTiny
678+ if v {
679+ paramValues [i ] = []byte {1 }
680+ } else {
681+ paramValues [i ] = []byte {0 }
682+ }
683+ case float32 :
684+ paramTypes [i << 1 ] = fieldTypeFloat
685+ paramValues [i ] = Uint32ToBytes (math .Float32bits (v ))
617686 case float64 :
618687 paramTypes [i << 1 ] = fieldTypeDouble
619688 paramValues [i ] = Uint64ToBytes (math .Float64bits (v ))
620689 case string :
621690 paramTypes [i << 1 ] = fieldTypeString
622691 paramValues [i ] = DumpLengthEncodedString (nil , hack .Slice (v ))
692+ case []byte :
693+ paramTypes [i << 1 ] = fieldTypeString
694+ paramValues [i ] = DumpLengthEncodedString (nil , v )
695+ case json.RawMessage :
696+ paramTypes [i << 1 ] = fieldTypeString
697+ paramValues [i ] = DumpLengthEncodedString (nil , v )
623698 default :
624- // we don't need other types currently
625699 return nil , errors .WithStack (errors .Errorf ("unsupported type %T" , v ))
626700 }
627701
@@ -656,3 +730,98 @@ func MakeExecuteStmtPacket(stmtID uint32, args []any) ([]byte, error) {
656730 }
657731 return request , nil
658732}
733+
734+ // ParseExecuteStmtRequest parses ComStmtExecute request.
735+ // NOTICE: the type of returned args may be wrong because it doesn't have the knowledge of real param types.
736+ // E.g. []byte is returned as string, and int is returned as int32.
737+ func ParseExecuteStmtRequest (data []byte , paramNum int ) (stmtID uint32 , args []any , err error ) {
738+ if len (data ) < 1 + 4 + 1 + 4 + 1 {
739+ return 0 , nil , errors .WithStack (gomysql .ErrMalformPacket )
740+ }
741+
742+ pos := 1
743+ stmtID = binary .LittleEndian .Uint32 (data [pos : pos + 4 ])
744+ // cursor flag and iteration count
745+ pos += 4 + 1 + 4
746+ if len (data ) < pos + ((paramNum + 7 )>> 3 )+ 1 {
747+ return 0 , nil , errors .WithStack (gomysql .ErrMalformPacket )
748+ }
749+ nullBitmap := data [pos : pos + ((paramNum + 7 )>> 3 )]
750+ pos += len (nullBitmap )
751+ newParamBoundFlag := data [pos ]
752+ pos += 1
753+ args = make ([]any , paramNum )
754+ if newParamBoundFlag == 0 {
755+ return stmtID , args , nil
756+ }
757+
758+ if len (data ) < pos + paramNum * 2 {
759+ return 0 , nil , errors .WithStack (gomysql .ErrMalformPacket )
760+ }
761+ paramTypes := data [pos : pos + paramNum * 2 ]
762+ pos += paramNum * 2
763+
764+ for i := 0 ; i < paramNum ; i ++ {
765+ if nullBitmap [i / 8 ]& (1 << (uint (i )% 8 )) > 0 {
766+ args [i ] = nil
767+ continue
768+ }
769+ switch paramTypes [i << 1 ] {
770+ case fieldTypeTiny :
771+ if paramTypes [(i << 1 )+ 1 ] == 0x80 {
772+ args [i ] = uint8 (data [pos ])
773+ } else {
774+ args [i ] = int8 (data [pos ])
775+ }
776+ pos += 1
777+ case fieldTypeShort :
778+ v := binary .LittleEndian .Uint16 (data [pos : pos + 2 ])
779+ if paramTypes [(i << 1 )+ 1 ] == 0x80 {
780+ args [i ] = v
781+ } else {
782+ args [i ] = int16 (v )
783+ }
784+ pos += 2
785+ case fieldTypeLong :
786+ v := binary .LittleEndian .Uint32 (data [pos : pos + 4 ])
787+ if paramTypes [(i << 1 )+ 1 ] == 0x80 {
788+ args [i ] = v
789+ } else {
790+ args [i ] = int32 (v )
791+ }
792+ pos += 4
793+ case fieldTypeLongLong :
794+ v := binary .LittleEndian .Uint64 (data [pos : pos + 8 ])
795+ if paramTypes [(i << 1 )+ 1 ] == 0x80 {
796+ args [i ] = v
797+ } else {
798+ args [i ] = int64 (v )
799+ }
800+ pos += 8
801+ case fieldTypeFloat :
802+ args [i ] = math .Float32frombits (binary .LittleEndian .Uint32 (data [pos : pos + 4 ]))
803+ pos += 4
804+ case fieldTypeDouble :
805+ args [i ] = math .Float64frombits (binary .LittleEndian .Uint64 (data [pos : pos + 8 ]))
806+ pos += 8
807+ case fieldTypeString :
808+ v , _ , off , err := ParseLengthEncodedBytes (data [pos :])
809+ if err != nil {
810+ return 0 , nil , errors .Wrapf (err , "parse param %d err" , i )
811+ }
812+ args [i ] = hack .String (v )
813+ pos += off
814+ default :
815+ return 0 , nil , errors .Errorf ("unsupported type %d" , paramTypes [i << 1 ])
816+ }
817+ }
818+
819+ return stmtID , args , nil
820+ }
821+
822+ func MakeCloseStmtRequest (stmtID uint32 ) []byte {
823+ request := make ([]byte , 1 + 4 )
824+ request [0 ] = ComStmtClose .Byte ()
825+ binary .LittleEndian .PutUint32 (request [1 :], stmtID )
826+ return request
827+ }
0 commit comments