@@ -621,13 +621,20 @@ func ParsePrepareStmtResp(resp []byte) (stmtID uint32, paramNum int) {
621621 return
622622}
623623
624- func MakeExecuteStmtRequest (stmtID uint32 , args []any ) ([]byte , error ) {
624+ func MakeExecuteStmtRequest (stmtID uint32 , args []any , newParamBound bool ) ([]byte , error ) {
625625 paramNum := len (args )
626626 paramTypes := make ([]byte , paramNum * 2 )
627627 paramValues := make ([][]byte , paramNum )
628628 nullBitmap := make ([]byte , (paramNum + 7 )>> 3 )
629- dataLen := 1 + 4 + 1 + 4 + len (nullBitmap ) + 1 + len (paramTypes )
629+ dataLen := 1 + 4 + 1 + 4
630+ if paramNum > 0 {
631+ dataLen += len (nullBitmap ) + 1
632+ }
630633 var newParamBoundFlag byte = 0
634+ if newParamBound {
635+ newParamBoundFlag = 1
636+ dataLen += len (paramTypes )
637+ }
631638
632639 for i := range args {
633640 if args [i ] == nil {
@@ -636,7 +643,6 @@ func MakeExecuteStmtRequest(stmtID uint32, args []any) ([]byte, error) {
636643 continue
637644 }
638645
639- newParamBoundFlag = 1
640646 switch v := args [i ].(type ) {
641647 case int8 :
642648 paramTypes [i << 1 ] = fieldTypeTiny
@@ -719,7 +725,7 @@ func MakeExecuteStmtRequest(stmtID uint32, args []any) ([]byte, error) {
719725 pos += len (nullBitmap )
720726 request [pos ] = newParamBoundFlag
721727 pos ++
722- if newParamBoundFlag == 1 {
728+ if newParamBound {
723729 copy (request [pos :], paramTypes )
724730 pos += len (paramTypes )
725731 }
@@ -734,55 +740,62 @@ func MakeExecuteStmtRequest(stmtID uint32, args []any) ([]byte, error) {
734740// ParseExecuteStmtRequest parses ComStmtExecute request.
735741// NOTICE: the type of returned args may be wrong because it doesn't have the knowledge of real param types.
736742// E.g. []byte is returned as string, and int is returned as int32.
737- func ParseExecuteStmtRequest (data []byte , paramNum int ) (stmtID uint32 , args []any , err error ) {
738- if len (data ) < 1 + 4 + 1 + 4 + 1 {
739- return 0 , nil , errors .WithStack (gomysql .ErrMalformPacket )
743+ func ParseExecuteStmtRequest (data []byte , paramNum int , paramTypes [] byte ) (stmtID uint32 , args []any , newParamTypes [] byte , err error ) {
744+ if len (data ) < 1 + 4 + 1 + 4 {
745+ return 0 , nil , nil , errors .WithStack (gomysql .ErrMalformPacket )
740746 }
741747
742748 pos := 1
743749 stmtID = binary .LittleEndian .Uint32 (data [pos : pos + 4 ])
750+ // paramNum is contained in the ComStmtPrepare but paramTypes is contained in the first ComStmtExecute (with newParamBoundFlag==1).
751+ // If the prepared statement is parsed from the session states, the paramTypes may be empty but the paramNum is not in the session states.
752+ // Just return empty args in this case, which is fine currently.
753+ if paramNum == 0 {
754+ return stmtID , nil , nil , nil
755+ }
744756 // cursor flag and iteration count
745757 pos += 4 + 1 + 4
746758 if len (data ) < pos + ((paramNum + 7 )>> 3 )+ 1 {
747- return 0 , nil , errors .WithStack (gomysql .ErrMalformPacket )
759+ return 0 , nil , nil , errors .WithStack (gomysql .ErrMalformPacket )
748760 }
749761 nullBitmap := data [pos : pos + ((paramNum + 7 )>> 3 )]
750762 pos += len (nullBitmap )
751763 newParamBoundFlag := data [pos ]
752764 pos += 1
753765 args = make ([]any , paramNum )
754- if newParamBoundFlag == 0 {
755- return stmtID , args , nil
756- }
757766
758- if len (data ) < pos + paramNum * 2 {
759- return 0 , nil , errors .WithStack (gomysql .ErrMalformPacket )
767+ if newParamBoundFlag > 0 {
768+ if len (data ) < pos + paramNum << 1 {
769+ return 0 , nil , nil , errors .WithStack (gomysql .ErrMalformPacket )
770+ }
771+ paramTypes = data [pos : pos + paramNum << 1 ]
772+ pos += paramNum << 1
760773 }
761- paramTypes := data [pos : pos + paramNum * 2 ]
762- pos += paramNum * 2
763774
764775 for i := 0 ; i < paramNum ; i ++ {
765776 if nullBitmap [i / 8 ]& (1 << (uint (i )% 8 )) > 0 {
766777 args [i ] = nil
767778 continue
768779 }
769780 switch paramTypes [i << 1 ] {
781+ case fieldTypeNULL :
782+ args [i ] = nil
770783 case fieldTypeTiny :
771784 if paramTypes [(i << 1 )+ 1 ] == 0x80 {
772785 args [i ] = uint8 (data [pos ])
773786 } else {
774787 args [i ] = int8 (data [pos ])
775788 }
776789 pos += 1
777- case fieldTypeShort :
790+ case fieldTypeShort , fieldTypeYear :
778791 v := binary .LittleEndian .Uint16 (data [pos : pos + 2 ])
779792 if paramTypes [(i << 1 )+ 1 ] == 0x80 {
780793 args [i ] = v
781794 } else {
782795 args [i ] = int16 (v )
783796 }
784797 pos += 2
785- case fieldTypeLong :
798+ case fieldTypeLong , fieldTypeInt24 :
786799 v := binary .LittleEndian .Uint32 (data [pos : pos + 4 ])
787800 if paramTypes [(i << 1 )+ 1 ] == 0x80 {
788801 args [i ] = v
@@ -804,19 +817,58 @@ func ParseExecuteStmtRequest(data []byte, paramNum int) (stmtID uint32, args []a
804817 case fieldTypeDouble :
805818 args [i ] = math .Float64frombits (binary .LittleEndian .Uint64 (data [pos : pos + 8 ]))
806819 pos += 8
807- case fieldTypeString :
808- v , _ , off , err := ParseLengthEncodedBytes (data [pos :])
820+ case fieldTypeDate , fieldTypeTimestamp , fieldTypeDateTime :
821+ length := data [pos ]
822+ pos ++
823+ switch length {
824+ case 0 :
825+ args [i ] = "0000-00-00 00:00:00"
826+ case 4 :
827+ pos , args [i ] = BinaryDate (pos , data )
828+ case 7 :
829+ pos , args [i ] = BinaryDateTime (pos , data )
830+ case 11 :
831+ pos , args [i ] = BinaryTimestamp (pos , data )
832+ case 13 :
833+ pos , args [i ] = BinaryTimestampWithTZ (pos , data )
834+ default :
835+ return 0 , nil , nil , errors .WithStack (gomysql .ErrMalformPacket )
836+ }
837+ case fieldTypeTime :
838+ length := data [pos ]
839+ pos ++
840+ switch length {
841+ case 0 :
842+ args [i ] = "0"
843+ case 8 :
844+ isNegative := data [pos ]
845+ pos ++
846+ pos , args [i ] = BinaryDuration (pos , data , isNegative )
847+ case 12 :
848+ isNegative := data [pos ]
849+ pos ++
850+ pos , args [i ] = BinaryDurationWithMS (pos , data , isNegative )
851+ default :
852+ return 0 , nil , nil , errors .WithStack (gomysql .ErrMalformPacket )
853+ }
854+ case fieldTypeDecimal , fieldTypeNewDecimal , fieldTypeVarChar , fieldTypeString , fieldTypeVarString , fieldTypeBLOB , fieldTypeTinyBLOB ,
855+ fieldTypeMediumBLOB , fieldTypeLongBLOB , fieldTypeEnum , fieldTypeSet , fieldTypeGeometry , fieldTypeBit , fieldTypeJSON , fieldTypeVector :
856+ v , isNull , n , err := ParseLengthEncodedBytes (data [pos :])
809857 if err != nil {
810- return 0 , nil , errors .Wrapf (err , "parse param %d err " , i )
858+ return 0 , nil , nil , errors .Wrapf (err , "parse param err, type: %d, idx: %d, pos: %d " , paramTypes [ i << 1 ], i , pos )
811859 }
812- args [i ] = hack .String (v )
813- pos += off
860+ if isNull {
861+ args [i ] = nil
862+ } else {
863+ args [i ] = hack .String (v )
864+ }
865+ pos += n
814866 default :
815- return 0 , nil , errors .Errorf ("unsupported type %d" , paramTypes [i << 1 ])
867+ return 0 , nil , nil , errors .Errorf ("unsupported type %d" , paramTypes [i << 1 ])
816868 }
817869 }
818870
819- return stmtID , args , nil
871+ return stmtID , args , paramTypes , nil
820872}
821873
822874func MakeCloseStmtRequest (stmtID uint32 ) []byte {
0 commit comments