Skip to content

Commit a3fcad9

Browse files
committed
made writeExecutePacket independent of number of arguments
1 parent 9247ef8 commit a3fcad9

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

driver_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,7 @@ func TestStmtMultiRows(t *testing.T) {
12111211
}
12121212

12131213
func TestPreparedManyCols(t *testing.T) {
1214-
const repetitions = 1024
1214+
const repetitions = 32 // defaultBufSize
12151215
runTests(t, dsn, func(dbt *DBTest) {
12161216
query := "SELECT ?" + strings.Repeat(",?", repetitions-1)
12171217
values := make([]sql.NullString, repetitions)

packets.go

Lines changed: 28 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -750,6 +750,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
750750
)
751751
}
752752

753+
const minPktLen = 4 + 1 + 4 + 1 + 4
753754
mc := stmt.mc
754755

755756
// Reset packet-sequence
@@ -758,7 +759,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
758759
var data []byte
759760

760761
if len(args) == 0 {
761-
data = mc.buf.takeBuffer(4 + 1 + 4 + 1 + 4)
762+
data = mc.buf.takeBuffer(minPktLen)
762763
} else {
763764
data = mc.buf.takeCompleteBuffer()
764765
}
@@ -787,34 +788,50 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
787788
data[13] = 0x00
788789

789790
if len(args) > 0 {
791+
pos := minPktLen
790792
// NULL-bitmap [(len(args)+7)/8 bytes]
791-
nullMask := uint64(0)
792-
793-
pos := 4 + 1 + 4 + 1 + 4 + ((len(args) + 7) >> 3)
793+
var nullMask []byte
794+
if maskLen, typesLen := (len(args)+7)/8, 1+2*len(args); pos+maskLen+typesLen >= len(data) {
795+
// buffer has to be extended but we don't know by how much
796+
// so we depend on append after nullMask fits.
797+
// The default size didn't suffice and we have to deal with a lot of columns,
798+
// so allocation size is hard to guess.
799+
tmp := make([]byte, pos+maskLen+typesLen)
800+
copy(tmp[:pos], data[:pos])
801+
data = tmp
802+
nullMask = data[pos : pos+maskLen]
803+
pos += maskLen
804+
} else {
805+
nullMask = data[pos : pos+maskLen]
806+
for i := 0; i < maskLen; i++ {
807+
nullMask[i] = 0
808+
}
809+
pos += maskLen
810+
}
794811

795812
// newParameterBoundFlag 1 [1 byte]
796813
data[pos] = 0x01
797814
pos++
798815

799816
// type of each parameter [len(args)*2 bytes]
800817
paramTypes := data[pos:]
801-
pos += (len(args) << 1)
818+
pos += len(args) * 2
802819

803820
// value of each parameter [n bytes]
804821
paramValues := data[pos:pos]
805822
valuesCap := cap(paramValues)
806823

807-
for i := range args {
824+
for i, arg := range args {
808825
// build NULL-bitmap
809-
if args[i] == nil {
810-
nullMask |= 1 << uint(i)
826+
if arg == nil {
827+
nullMask[i/8] |= 1 << (uint(i) & 7) // |= 1 << uint(i)
811828
paramTypes[i+i] = fieldTypeNULL
812829
paramTypes[i+i+1] = 0x00
813830
continue
814831
}
815832

816833
// cache types and values
817-
switch v := args[i].(type) {
834+
switch v := arg.(type) {
818835
case int64:
819836
paramTypes[i+i] = fieldTypeLongLong
820837
paramTypes[i+i+1] = 0x00
@@ -877,7 +894,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
877894
}
878895

879896
// Handle []byte(nil) as a NULL value
880-
nullMask |= 1 << uint(i)
897+
nullMask[i/8] |= 1 << (uint(i) & 7) // |= 1 << uint(i)
881898
paramTypes[i+i] = fieldTypeNULL
882899
paramTypes[i+i+1] = 0x00
883900

@@ -913,7 +930,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
913930
paramValues = append(paramValues, val...)
914931

915932
default:
916-
return fmt.Errorf("Can't convert type: %T", args[i])
933+
return fmt.Errorf("Can't convert type: %T", arg)
917934
}
918935
}
919936

@@ -926,11 +943,6 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error {
926943

927944
pos += len(paramValues)
928945
data = data[:pos]
929-
930-
// Convert nullMask to bytes
931-
for i, max := 0, (stmt.paramCount+7)>>3; i < max; i++ {
932-
data[i+14] = byte(nullMask >> uint(i<<3))
933-
}
934946
}
935947

936948
return mc.writePacket(data)

0 commit comments

Comments
 (0)