@@ -56,18 +56,22 @@ func (s *Stmt) Close() error {
5656 return nil
5757}
5858
59+ // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
5960func (s * Stmt ) write (args ... interface {}) error {
6061 paramsNum := s .params
6162
6263 if len (args ) != paramsNum {
6364 return fmt .Errorf ("argument mismatch, need %d but got %d" , s .params , len (args ))
6465 }
6566
66- paramTypes := make ([]byte , paramsNum << 1 )
67- paramValues := make ([][]byte , paramsNum )
67+ qaLen := len (s .conn .queryAttributes )
68+ paramTypes := make ([][]byte , paramsNum + qaLen )
69+ paramFlags := make ([][]byte , paramsNum + qaLen )
70+ paramValues := make ([][]byte , paramsNum + qaLen )
71+ paramNames := make ([][]byte , paramsNum + qaLen )
6872
6973 //NULL-bitmap, length: (num-params+7)
70- nullBitmap := make ([]byte , (paramsNum + 7 )>> 3 )
74+ nullBitmap := make ([]byte , (paramsNum + qaLen + 7 )>> 3 )
7175
7276 length := 1 + 4 + 1 + 4 + ((paramsNum + 7 ) >> 3 ) + 1 + (paramsNum << 1 )
7377
@@ -76,76 +80,87 @@ func (s *Stmt) write(args ...interface{}) error {
7680 for i := range args {
7781 if args [i ] == nil {
7882 nullBitmap [i / 8 ] |= 1 << (uint (i ) % 8 )
79- paramTypes [i << 1 ] = MYSQL_TYPE_NULL
83+ paramTypes [i ] = [] byte { MYSQL_TYPE_NULL }
8084 continue
8185 }
8286
8387 newParamBoundFlag = 1
8488
8589 switch v := args [i ].(type ) {
8690 case int8 :
87- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
91+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
8892 paramValues [i ] = []byte {byte (v )}
8993 case int16 :
90- paramTypes [i << 1 ] = MYSQL_TYPE_SHORT
94+ paramTypes [i ] = [] byte { MYSQL_TYPE_SHORT }
9195 paramValues [i ] = Uint16ToBytes (uint16 (v ))
9296 case int32 :
93- paramTypes [i << 1 ] = MYSQL_TYPE_LONG
97+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONG }
9498 paramValues [i ] = Uint32ToBytes (uint32 (v ))
9599 case int :
96- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
100+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
97101 paramValues [i ] = Uint64ToBytes (uint64 (v ))
98102 case int64 :
99- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
103+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
100104 paramValues [i ] = Uint64ToBytes (uint64 (v ))
101105 case uint8 :
102- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
103- paramTypes [( i << 1 ) + 1 ] = 0x80
106+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
107+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
104108 paramValues [i ] = []byte {v }
105109 case uint16 :
106- paramTypes [i << 1 ] = MYSQL_TYPE_SHORT
107- paramTypes [( i << 1 ) + 1 ] = 0x80
110+ paramTypes [i ] = [] byte { MYSQL_TYPE_SHORT }
111+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
108112 paramValues [i ] = Uint16ToBytes (v )
109113 case uint32 :
110- paramTypes [i << 1 ] = MYSQL_TYPE_LONG
111- paramTypes [( i << 1 ) + 1 ] = 0x80
114+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONG }
115+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
112116 paramValues [i ] = Uint32ToBytes (v )
113117 case uint :
114- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
115- paramTypes [( i << 1 ) + 1 ] = 0x80
118+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
119+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
116120 paramValues [i ] = Uint64ToBytes (uint64 (v ))
117121 case uint64 :
118- paramTypes [i << 1 ] = MYSQL_TYPE_LONGLONG
119- paramTypes [( i << 1 ) + 1 ] = 0x80
122+ paramTypes [i ] = [] byte { MYSQL_TYPE_LONGLONG }
123+ paramFlags [ i ] = [] byte { UNSIGNED_FLAG }
120124 paramValues [i ] = Uint64ToBytes (v )
121125 case bool :
122- paramTypes [i << 1 ] = MYSQL_TYPE_TINY
126+ paramTypes [i ] = [] byte { MYSQL_TYPE_TINY }
123127 if v {
124128 paramValues [i ] = []byte {1 }
125129 } else {
126130 paramValues [i ] = []byte {0 }
127131 }
128132 case float32 :
129- paramTypes [i << 1 ] = MYSQL_TYPE_FLOAT
133+ paramTypes [i ] = [] byte { MYSQL_TYPE_FLOAT }
130134 paramValues [i ] = Uint32ToBytes (math .Float32bits (v ))
131135 case float64 :
132- paramTypes [i << 1 ] = MYSQL_TYPE_DOUBLE
136+ paramTypes [i ] = [] byte { MYSQL_TYPE_DOUBLE }
133137 paramValues [i ] = Uint64ToBytes (math .Float64bits (v ))
134138 case string :
135- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
139+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
136140 paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
137141 case []byte :
138- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
142+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
139143 paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
140144 case json.RawMessage :
141- paramTypes [i << 1 ] = MYSQL_TYPE_STRING
145+ paramTypes [i ] = [] byte { MYSQL_TYPE_STRING }
142146 paramValues [i ] = append (PutLengthEncodedInt (uint64 (len (v ))), v ... )
143147 default :
144148 return fmt .Errorf ("invalid argument type %T" , args [i ])
145149 }
150+ paramNames [i ] = []byte {0 } // lenght encoded, no name
151+ if paramFlags [i ] == nil {
152+ paramFlags [i ] = []byte {0 }
153+ }
146154
147155 length += len (paramValues [i ])
148156 }
157+ for i , qa := range s .conn .queryAttributes {
158+ tf := qa .TypeAndFlag ()
159+ paramTypes [(i + paramsNum )] = []byte {tf [0 ]}
160+ paramFlags [i + paramsNum ] = []byte {tf [1 ]}
161+ paramValues [i + paramsNum ] = qa .ValueBytes ()
162+ paramNames [i + paramsNum ] = PutLengthEncodedString ([]byte (qa .Name ))
163+ }
149164
150165 data := utils .BytesBufferGet ()
151166 defer func () {
@@ -159,30 +174,46 @@ func (s *Stmt) write(args ...interface{}) error {
159174 data .WriteByte (COM_STMT_EXECUTE )
160175 data .Write ([]byte {byte (s .id ), byte (s .id >> 8 ), byte (s .id >> 16 ), byte (s .id >> 24 )})
161176
162- //flag: CURSOR_TYPE_NO_CURSOR
163- data .WriteByte (0x00 )
177+ flags := CURSOR_TYPE_NO_CURSOR
178+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 && len (s .conn .queryAttributes ) > 0 {
179+ flags |= PARAMETER_COUNT_AVAILABLE
180+ }
181+ data .WriteByte (flags )
164182
165183 //iteration-count, always 1
166184 data .Write ([]byte {1 , 0 , 0 , 0 })
167185
168- if s .params > 0 {
169- data .Write (nullBitmap )
170-
171- //new-params-bound-flag
172- data .WriteByte (newParamBoundFlag )
173-
174- if newParamBoundFlag == 1 {
175- //type of each parameter, length: num-params * 2
176- data .Write (paramTypes )
177-
178- //value of each parameter
179- for _ , v := range paramValues {
180- data .Write (v )
186+ if paramsNum > 0 || (s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 && (flags & PARAMETER_COUNT_AVAILABLE > 0 )) {
187+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 {
188+ paramsNum += len (s .conn .queryAttributes )
189+ data .Write (PutLengthEncodedInt (uint64 (paramsNum )))
190+ }
191+ if paramsNum > 0 {
192+ data .Write (nullBitmap )
193+
194+ //new-params-bound-flag
195+ data .WriteByte (newParamBoundFlag )
196+
197+ if newParamBoundFlag == 1 {
198+ for i := 0 ; i < paramsNum ; i ++ {
199+ data .Write (paramTypes [i ])
200+ data .Write (paramFlags [i ])
201+
202+ if s .conn .capability & CLIENT_QUERY_ATTRIBUTES > 0 {
203+ data .Write (paramNames [i ])
204+ }
205+ }
206+
207+ //value of each parameter
208+ for _ , v := range paramValues {
209+ data .Write (v )
210+ }
181211 }
182212 }
183213 }
184214
185215 s .conn .ResetSequence ()
216+ s .conn .queryAttributes = nil
186217
187218 return s .conn .WritePacket (data .Bytes ())
188219}
0 commit comments