Skip to content

Commit 5975ca9

Browse files
committed
more refactoring
Try to remove unnecessary indirections and initialisations with zero. Also update links to the MySQL doc
1 parent 33d6df2 commit 5975ca9

File tree

4 files changed

+78
-79
lines changed

4 files changed

+78
-79
lines changed

connection.go

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -136,14 +136,14 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
136136
columnCount, err := stmt.readPrepareResultPacket()
137137
if err == nil {
138138
if stmt.paramCount > 0 {
139-
stmt.params, err = stmt.mc.readColumns(stmt.paramCount)
139+
stmt.params, err = mc.readColumns(stmt.paramCount)
140140
if err != nil {
141141
return nil, err
142142
}
143143
}
144144

145145
if columnCount > 0 {
146-
err = stmt.mc.readUntilEOF()
146+
err = mc.readUntilEOF()
147147
}
148148
}
149149

@@ -171,26 +171,24 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err
171171
}
172172

173173
// Internal function to execute commands
174-
func (mc *mysqlConn) exec(query string) (err error) {
174+
func (mc *mysqlConn) exec(query string) error {
175175
// Send command
176-
err = mc.writeCommandPacketStr(comQuery, query)
176+
err := mc.writeCommandPacketStr(comQuery, query)
177177
if err != nil {
178-
return
178+
return err
179179
}
180180

181181
// Read Result
182-
var resLen int
183-
resLen, err = mc.readResultSetHeaderPacket()
182+
resLen, err := mc.readResultSetHeaderPacket()
184183
if err == nil && resLen > 0 {
185-
err = mc.readUntilEOF()
186-
if err != nil {
187-
return
184+
if err = mc.readUntilEOF(); err != nil {
185+
return err
188186
}
189187

190188
err = mc.readUntilEOF()
191189
}
192190

193-
return
191+
return err
194192
}
195193

196194
func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) {
@@ -211,7 +209,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
211209
return rows, err
212210
}
213211
}
214-
215212
return nil, err
216213
}
217214

@@ -221,29 +218,29 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro
221218

222219
// Gets the value of the given MySQL System Variable
223220
// The returned byte slice is only valid until the next read
224-
func (mc *mysqlConn) getSystemVar(name string) (val []byte, err error) {
221+
func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) {
225222
// Send command
226-
err = mc.writeCommandPacketStr(comQuery, "SELECT @@"+name)
223+
if err := mc.writeCommandPacketStr(comQuery, "SELECT @@"+name); err != nil {
224+
return nil, err
225+
}
226+
227+
// Read Result
228+
resLen, err := mc.readResultSetHeaderPacket()
227229
if err == nil {
228-
// Read Result
229-
var resLen int
230-
resLen, err = mc.readResultSetHeaderPacket()
231-
if err == nil {
232-
rows := &mysqlRows{mc, false, nil, false}
230+
rows := &mysqlRows{mc, false, nil, false}
233231

234-
if resLen > 0 {
235-
// Columns
236-
rows.columns, err = mc.readColumns(resLen)
232+
if resLen > 0 {
233+
// Columns
234+
rows.columns, err = mc.readColumns(resLen)
235+
if err != nil {
236+
return nil, err
237237
}
238+
}
238239

239-
dest := make([]driver.Value, resLen)
240-
err = rows.readRow(dest)
241-
if err == nil {
242-
val = dest[0].([]byte)
243-
err = mc.readUntilEOF()
244-
}
240+
dest := make([]driver.Value, resLen)
241+
if err = rows.readRow(dest); err == nil {
242+
return dest[0].([]byte), mc.readUntilEOF()
245243
}
246244
}
247-
248-
return
245+
return nil, err
249246
}

packets.go

Lines changed: 30 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ func (mc *mysqlConn) splitPacket(data []byte) error {
140140
******************************************************************************/
141141

142142
// Handshake Initialization Packet
143-
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
143+
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
144144
func (mc *mysqlConn) readInitPacket() ([]byte, error) {
145145
data, err := mc.readPacket()
146146
if err != nil {
@@ -197,7 +197,6 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
197197
// TODO: Verify string termination
198198
// EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2)
199199
// \NUL otherwise
200-
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::Handshake
201200
//
202201
//if data[len(data)-1] == 0 {
203202
// return
@@ -209,7 +208,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) {
209208
}
210209

211210
// Client Authentication Packet
212-
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::HandshakeResponse
211+
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
213212
func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
214213
// Adjust client flags based on server support
215214
clientFlags := clientProtocol41 |
@@ -263,7 +262,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
263262
data[12] = collation_utf8_general_ci
264263

265264
// SSL Connection Request Packet
266-
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::SSLRequest
265+
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
267266
if mc.cfg.tls != nil {
268267
// Packet header [24bit length + 1 byte sequence]
269268
data[0] = byte((4 + 4 + 1 + 23))
@@ -316,7 +315,7 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error {
316315
}
317316

318317
// Client old authentication packet
319-
// http://dev.mysql.com/doc/internals/en/connection-phase.html#packet-Protocol::AuthSwitchResponse
318+
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse
320319
func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error {
321320
// User password
322321
scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.passwd))
@@ -454,7 +453,7 @@ func (mc *mysqlConn) readResultOK() error {
454453
}
455454

456455
// Result Set Header Packet
457-
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::Resultset
456+
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::Resultset
458457
func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
459458
data, err := mc.readPacket()
460459
if err == nil {
@@ -482,7 +481,7 @@ func (mc *mysqlConn) readResultSetHeaderPacket() (int, error) {
482481
}
483482

484483
// Error Packet
485-
// http://dev.mysql.com/doc/internals/en/overview.html#packet-ERR_Packet
484+
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-ERR_Packet
486485
func (mc *mysqlConn) handleErrorPacket(data []byte) error {
487486
if data[0] != iERR {
488487
return errMalformPkt
@@ -509,7 +508,7 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error {
509508
}
510509

511510
// Ok Packet
512-
// http://dev.mysql.com/doc/internals/en/overview.html#packet-OK_Packet
511+
// http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-OK_Packet
513512
func (mc *mysqlConn) handleOkPacket(data []byte) error {
514513
var n, m int
515514

@@ -536,7 +535,7 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error {
536535
}
537536

538537
// Read Packets as Field Packets until EOF-Packet or an Error appears
539-
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-Protocol::ColumnDefinition41
538+
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnDefinition41
540539
func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
541540
columns := make([]mysqlField, count)
542541

@@ -619,9 +618,11 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) {
619618
}
620619

621620
// Read Packets as Field Packets until EOF-Packet or an Error appears
622-
// http://dev.mysql.com/doc/internals/en/text-protocol.html#packet-ProtocolText::ResultsetRow
621+
// http://dev.mysql.com/doc/internals/en/com-query-response.html#packet-ProtocolText::ResultsetRow
623622
func (rows *mysqlRows) readRow(dest []driver.Value) error {
624-
data, err := rows.mc.readPacket()
623+
mc := rows.mc
624+
625+
data, err := mc.readPacket()
625626
if err != nil {
626627
return err
627628
}
@@ -642,15 +643,15 @@ func (rows *mysqlRows) readRow(dest []driver.Value) error {
642643
pos += n
643644
if err == nil {
644645
if !isNull {
645-
if !rows.mc.parseTime {
646+
if !mc.parseTime {
646647
continue
647648
} else {
648649
switch rows.columns[i].fieldType {
649650
case fieldTypeTimestamp, fieldTypeDateTime,
650651
fieldTypeDate, fieldTypeNewDate:
651652
dest[i], err = parseDateTime(
652653
string(dest[i].([]byte)),
653-
rows.mc.cfg.loc,
654+
mc.cfg.loc,
654655
)
655656
if err == nil {
656657
continue
@@ -689,7 +690,7 @@ func (mc *mysqlConn) readUntilEOF() error {
689690
******************************************************************************/
690691

691692
// Prepare Result Packets
692-
// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-prepare-response
693+
// http://dev.mysql.com/doc/internals/en/com-stmt-prepare-response.html
693694
func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
694695
data, err := stmt.mc.readPacket()
695696
if err == nil {
@@ -723,7 +724,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) {
723724
return 0, err
724725
}
725726

726-
// http://dev.mysql.com/doc/internals/en/prepared-statements.html#com-stmt-send-long-data
727+
// http://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
727728
func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error {
728729
maxLen := stmt.mc.maxPacketAllowed - 1
729730
pktLen := maxLen
@@ -785,14 +786,16 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
785786
)
786787
}
787788

789+
mc := stmt.mc
790+
788791
// Reset packet-sequence
789-
stmt.mc.sequence = 0
792+
mc.sequence = 0
790793

791794
var data []byte
792795

793796
if len(args) == 0 {
794797
const pktLen = 1 + 4 + 1 + 4
795-
data = stmt.mc.buf.writeBuffer(4 + pktLen)
798+
data = mc.buf.writeBuffer(4 + pktLen)
796799
if data == nil {
797800
// can not take the buffer. Something must be wrong with the connection
798801
errLog.Print("Busy buffer")
@@ -805,7 +808,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
805808
data[2] = byte(pktLen >> 16)
806809
data[3] = 0x00 // sequence is always 0
807810
} else {
808-
data = stmt.mc.buf.takeCompleteBuffer()
811+
data = mc.buf.takeCompleteBuffer()
809812
if data == nil {
810813
// can not take the buffer. Something must be wrong with the connection
811814
errLog.Print("Busy buffer")
@@ -902,7 +905,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
902905
paramTypes[i+i] = fieldTypeString
903906
paramTypes[i+i+1] = 0x00
904907

905-
if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
908+
if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
906909
paramValues = append(paramValues,
907910
lengthEncodedIntegerToBytes(uint64(len(v)))...,
908911
)
@@ -917,7 +920,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
917920
paramTypes[i+i] = fieldTypeString
918921
paramTypes[i+i+1] = 0x00
919922

920-
if len(v) < stmt.mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
923+
if len(v) < mc.maxPacketAllowed-pos-len(paramValues)-(len(args)-(i+1))*64 {
921924
paramValues = append(paramValues,
922925
lengthEncodedIntegerToBytes(uint64(len(v)))...,
923926
)
@@ -936,7 +939,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
936939
if v.IsZero() {
937940
val = []byte("0000-00-00")
938941
} else {
939-
val = []byte(v.In(stmt.mc.cfg.loc).Format(timeFormat))
942+
val = []byte(v.In(mc.cfg.loc).Format(timeFormat))
940943
}
941944

942945
paramValues = append(paramValues,
@@ -953,7 +956,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
953956
// In that case we must build the data packet with the new values buffer
954957
if valuesCap != cap(paramValues) {
955958
data = append(data[:pos], paramValues...)
956-
stmt.mc.buf.buf = data
959+
mc.buf.buf = data
957960
}
958961

959962
pos += len(paramValues)
@@ -965,18 +968,18 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
965968
data[0] = byte(pktLen)
966969
data[1] = byte(pktLen >> 8)
967970
data[2] = byte(pktLen >> 16)
968-
data[3] = stmt.mc.sequence
971+
data[3] = mc.sequence
969972

970973
// Convert nullMask to bytes
971-
for i, max := 14, 14+((stmt.paramCount+7)>>3); i < max; i++ {
972-
data[i] = byte(nullMask >> uint((i-14)<<3))
974+
for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
975+
data[i+14] = byte(nullMask >> uint(i<<3))
973976
}
974977
}
975978

976-
return stmt.mc.writePacket(data)
979+
return mc.writePacket(data)
977980
}
978981

979-
// http://dev.mysql.com/doc/internals/en/prepared-statements.html#packet-ProtocolBinary::ResultsetRow
982+
// http://dev.mysql.com/doc/internals/en/binary-protocol-resultset-row.html
980983
func (rows *mysqlRows) readBinaryRow(dest []driver.Value) error {
981984
data, err := rows.mc.readPacket()
982985
if err != nil {

rows.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,12 @@ type mysqlRows struct {
2626
eof bool
2727
}
2828

29-
func (rows *mysqlRows) Columns() (columns []string) {
30-
columns = make([]string, len(rows.columns))
29+
func (rows *mysqlRows) Columns() []string {
30+
columns := make([]string, len(rows.columns))
3131
for i := range columns {
3232
columns[i] = rows.columns[i].name
3333
}
34-
return
34+
return columns
3535
}
3636

3737
func (rows *mysqlRows) Close() (err error) {

0 commit comments

Comments
 (0)