Skip to content

Commit f1e31a5

Browse files
worryg0djoao-r-reis
authored andcommitted
Protocol version negotiation doesn't work if server replies with stream id different than 0
Previously, protocol negotiation didn't work properly when C* was responding with stream id different from 0. This patch changes the way protocol negotiation works. Instead of parsing a supported protocol version from C* error response, the driver tries to connect with each supported protocol starting from the latest. Patch by Bohdan Siryk; Reviewed by João Reis for CASSGO-98
1 parent 0089073 commit f1e31a5

File tree

9 files changed

+442
-115
lines changed

9 files changed

+442
-115
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1717
- Prevent panic with queries during session init (CASSGO-92)
1818
- Return correct values from RowData (CASSGO-95)
1919
- Prevent setting a compression flag in a frame header when native proto v5 is being used (CASSGO-98)
20+
- Use protocol downgrading approach during protocol negotiation (CASSGO-97)
2021

2122
## [2.0.0]
2223

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ test-integration-auth: .prepare-cassandra-cluster
100100
test-unit:
101101
@echo "Run unit tests"
102102
@go clean -testcache
103-
go test -tags unit -timeout=5m -race ./...
103+
go test -v -tags unit -timeout=5m -race ./...
104104

105105
check: .prepare-golangci
106106
@echo "Build"

conn.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,13 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error {
378378
select {
379379
case err := <-startupErr:
380380
if err != nil {
381+
if s.checkProtocolRelatedError(err) {
382+
return &unsupportedProtocolVersionError{
383+
err: err,
384+
hostInfo: s.conn.host,
385+
version: protoVersion(s.conn.version),
386+
}
387+
}
381388
return err
382389
}
383390
case <-ctx.Done():
@@ -387,6 +394,38 @@ func (s *startupCoordinator) setupConn(ctx context.Context) error {
387394
return nil
388395
}
389396

397+
// Checks if the error is protocol related and should be retried during startup.
398+
// It returns the frame that caused the error and whether the error should be retried.
399+
func (s *startupCoordinator) checkProtocolRelatedError(err error) bool {
400+
var unwrappedFrame frame
401+
402+
var protocolErr *protocolError
403+
if !errors.As(err, &protocolErr) {
404+
var errFrame errorFrame
405+
if !errors.As(err, &errFrame) {
406+
return false
407+
} else {
408+
unwrappedFrame = errFrame
409+
}
410+
} else {
411+
unwrappedFrame = protocolErr.frame
412+
}
413+
414+
switch frame := unwrappedFrame.(type) {
415+
case *supportedFrame:
416+
// We can receive a supportedFrame wrapped in protocolError from Conn.recv if the host responds to a 0 stream id.
417+
// If we receive a supportedFrame then we know that the host is not compatible with the protocol version, but it is reachable, so we can retry
418+
return true
419+
case errorFrame:
420+
// If we receive an errorFrame with codes ErrCodeProtocol or ErrCodeServer,
421+
// then we should try to downgrade a protocol version, so do not skip the host
422+
return frame.code == ErrCodeProtocol || frame.code == ErrCodeServer
423+
default:
424+
// In any other case we should not retry as it means the host is not reachable or some other error happened
425+
return false
426+
}
427+
}
428+
390429
func (s *startupCoordinator) write(ctx context.Context, frame frameBuilder, startupCompleted *atomic.Bool) (frame, error) {
391430
select {
392431
case s.frameTicker <- struct{}{}:
@@ -408,12 +447,14 @@ func (s *startupCoordinator) options(ctx context.Context, startupCompleted *atom
408447
return err
409448
}
410449

411-
supported, ok := frame.(*supportedFrame)
412-
if !ok {
413-
return NewErrProtocol("Unknown type of response to startup frame: %T", frame)
450+
switch frame := frame.(type) {
451+
case *supportedFrame:
452+
return s.startup(ctx, frame.supported, startupCompleted)
453+
case error:
454+
return frame
455+
default:
456+
return NewErrProtocol("Unknown type of response to startup frame: %T (frame=%s)", frame, frame.String())
414457
}
415-
416-
return s.startup(ctx, supported.supported, startupCompleted)
417458
}
418459

419460
func (s *startupCoordinator) startup(ctx context.Context, supported map[string][]string, startupCompleted *atomic.Bool) error {

conn_test.go

Lines changed: 67 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,6 +1054,9 @@ type newTestServerOpts struct {
10541054
addr string
10551055
protocol uint8
10561056
recvHook func(*framer)
1057+
1058+
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
1059+
dontFailOnProtocolMismatch bool
10571060
}
10581061

10591062
func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestServer {
@@ -1078,6 +1081,9 @@ func (nts newTestServerOpts) newServer(t testing.TB, ctx context.Context) *TestS
10781081
cancel: cancel,
10791082

10801083
onRecv: nts.recvHook,
1084+
1085+
customRequestHandler: nts.customRequestHandler,
1086+
dontFailOnProtocolMismatch: nts.dontFailOnProtocolMismatch,
10811087
}
10821088

10831089
go srv.closeWatch()
@@ -1142,6 +1148,10 @@ type TestServer struct {
11421148

11431149
// onRecv is a hook point for tests, called in receive loop.
11441150
onRecv func(*framer)
1151+
1152+
// customRequestHandler allows overriding the default request handling for testing purposes.
1153+
customRequestHandler func(srv *TestServer, reqFrame, respFrame *framer) error
1154+
dontFailOnProtocolMismatch bool
11451155
}
11461156

11471157
func (srv *TestServer) closeWatch() {
@@ -1162,9 +1172,26 @@ func (srv *TestServer) serve() {
11621172
}
11631173

11641174
go func(conn net.Conn) {
1175+
var startupCompleted bool
1176+
var useProtoV5 bool
1177+
11651178
defer conn.Close()
11661179
for !srv.isClosed() {
1167-
framer, err := srv.readFrame(conn)
1180+
var reader io.Reader = conn
1181+
1182+
if useProtoV5 && startupCompleted {
1183+
frame, _, err := readUncompressedSegment(conn)
1184+
if err != nil {
1185+
if errors.Is(err, io.EOF) {
1186+
return
1187+
}
1188+
srv.errorLocked(err)
1189+
return
1190+
}
1191+
reader = bytes.NewReader(frame)
1192+
}
1193+
1194+
framer, err := srv.readFrame(reader)
11681195
if err != nil {
11691196
if err == io.EOF {
11701197
return
@@ -1177,7 +1204,7 @@ func (srv *TestServer) serve() {
11771204
srv.onRecv(framer)
11781205
}
11791206

1180-
go srv.process(conn, framer)
1207+
srv.process(conn, framer, &useProtoV5, &startupCompleted)
11811208
}
11821209
}(conn)
11831210
}
@@ -1215,13 +1242,22 @@ func (srv *TestServer) errorLocked(err interface{}) {
12151242
srv.t.Error(err)
12161243
}
12171244

1218-
func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
1245+
func (srv *TestServer) process(conn net.Conn, reqFrame *framer, useProtoV5, startupCompleted *bool) {
12191246
head := reqFrame.header
12201247
if head == nil {
12211248
srv.errorLocked("process frame with a nil header")
12221249
return
12231250
}
1224-
respFrame := newFramer(nil, reqFrame.proto, GlobalTypes)
1251+
respFrame := newFramer(nil, byte(head.version), GlobalTypes)
1252+
1253+
if srv.customRequestHandler != nil {
1254+
if err := srv.customRequestHandler(srv, reqFrame, respFrame); err != nil {
1255+
srv.errorLocked(err)
1256+
return
1257+
}
1258+
// Dont like this but...
1259+
goto finish
1260+
}
12251261

12261262
switch head.op {
12271263
case opStartup:
@@ -1412,34 +1448,54 @@ func (srv *TestServer) process(conn net.Conn, reqFrame *framer) {
14121448
respFrame.writeString("not supported")
14131449
}
14141450

1415-
respFrame.buf[0] = srv.protocol | 0x80
1451+
finish:
1452+
1453+
respFrame.buf[0] |= 0x80
14161454

14171455
if err := respFrame.finish(); err != nil {
14181456
srv.errorLocked(err)
14191457
}
14201458

1421-
if err := respFrame.writeTo(conn); err != nil {
1422-
srv.errorLocked(err)
1459+
if *useProtoV5 && *startupCompleted {
1460+
segment, err := newUncompressedSegment(respFrame.buf, true)
1461+
if err == nil {
1462+
_, err = conn.Write(segment)
1463+
}
1464+
if err != nil {
1465+
srv.errorLocked(err)
1466+
return
1467+
}
1468+
} else {
1469+
if err := respFrame.writeTo(conn); err != nil {
1470+
srv.errorLocked(err)
1471+
}
1472+
1473+
if reqFrame.header.op == opStartup {
1474+
*startupCompleted = true
1475+
if head.version == protoVersion5 {
1476+
*useProtoV5 = true
1477+
}
1478+
}
14231479
}
14241480
}
14251481

1426-
func (srv *TestServer) readFrame(conn net.Conn) (*framer, error) {
1482+
func (srv *TestServer) readFrame(reader io.Reader) (*framer, error) {
14271483
buf := make([]byte, srv.headerSize)
1428-
head, err := readHeader(conn, buf)
1484+
head, err := readHeader(reader, buf)
14291485
if err != nil {
14301486
return nil, err
14311487
}
14321488
framer := newFramer(nil, srv.protocol, GlobalTypes)
14331489

1434-
err = framer.readFrame(conn, &head)
1490+
err = framer.readFrame(reader, &head)
14351491
if err != nil {
14361492
return nil, err
14371493
}
14381494

14391495
// should be a request frame
14401496
if head.version.response() {
14411497
return nil, fmt.Errorf("expected to read a request frame got version: %v", head.version)
1442-
} else if head.version.version() != srv.protocol {
1498+
} else if !srv.dontFailOnProtocolMismatch && head.version.version() != srv.protocol {
14431499
return nil, fmt.Errorf("expected to read protocol version 0x%x got 0x%x", srv.protocol, head.version.version())
14441500
}
14451501

control.go

Lines changed: 38 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ import (
3232
"math/rand"
3333
"net"
3434
"os"
35-
"regexp"
3635
"strconv"
3736
"sync"
3837
"sync/atomic"
@@ -202,56 +201,9 @@ func shuffleHosts(hosts []*HostInfo) []*HostInfo {
202201
return shuffled
203202
}
204203

205-
// this is going to be version dependant and a nightmare to maintain :(
206-
var protocolSupportRe = regexp.MustCompile(`the lowest supported version is \d+ and the greatest is (\d+)$`)
207-
var betaProtocolRe = regexp.MustCompile(`Beta version of the protocol used \(.*\), but USE_BETA flag is unset`)
208-
209-
func parseProtocolFromError(err error) int {
210-
errStr := err.Error()
211-
212-
var errProtocol ErrProtocol
213-
if errors.As(err, &errProtocol) {
214-
err = errProtocol.error
215-
}
216-
217-
// I really wish this had the actual info in the error frame...
218-
matches := betaProtocolRe.FindAllStringSubmatch(errStr, -1)
219-
if len(matches) == 1 {
220-
var protoErr *protocolError
221-
if errors.As(err, &protoErr) {
222-
version := protoErr.frame.Header().version.version()
223-
if version > 0 {
224-
return int(version - 1)
225-
}
226-
}
227-
return 0
228-
}
229-
230-
matches = protocolSupportRe.FindAllStringSubmatch(errStr, -1)
231-
if len(matches) != 1 || len(matches[0]) != 2 {
232-
var protoErr *protocolError
233-
if errors.As(err, &protoErr) {
234-
return int(protoErr.frame.Header().version.version())
235-
}
236-
return 0
237-
}
238-
239-
max, err := strconv.Atoi(matches[0][1])
240-
if err != nil {
241-
return 0
242-
}
243-
244-
return max
245-
}
246-
247-
const highestProtocolVersionSupported = 5
248-
249204
func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
250205
hosts = shuffleHosts(hosts)
251206

252-
connCfg := *c.session.connCfg
253-
connCfg.ProtoVersion = highestProtocolVersionSupported
254-
255207
handler := connErrorHandlerFn(func(c *Conn, err error, closed bool) {
256208
// we should never get here, but if we do it means we connected to a
257209
// host successfully which means our attempted protocol version worked
@@ -261,30 +213,56 @@ func (c *controlConn) discoverProtocol(hosts []*HostInfo) (int, error) {
261213
})
262214

263215
var err error
216+
var proto int
264217
for _, host := range hosts {
265-
var conn *Conn
266-
conn, err = c.session.dial(c.session.ctx, host, &connCfg, handler)
218+
proto, err = c.tryProtocolVersionsForHost(host, handler)
219+
if err == nil {
220+
return proto, nil
221+
}
222+
223+
c.session.logger.Debug("Failed to discover protocol version for host.",
224+
NewLogFieldIP("host_addr", host.ConnectAddress()),
225+
NewLogFieldError("err", err))
226+
}
227+
228+
return 0, err
229+
}
230+
231+
func (c *controlConn) tryProtocolVersionsForHost(host *HostInfo, handler ConnErrorHandler) (int, error) {
232+
connCfg := *c.session.connCfg
233+
234+
var triedVersions []int
235+
236+
for proto := highestProtocolVersionSupported; proto >= lowestProtocolVersionSupported; proto-- {
237+
connCfg.ProtoVersion = proto
238+
239+
conn, err := c.session.dial(c.session.ctx, host, &connCfg, handler)
267240
if conn != nil {
268241
conn.Close()
269242
}
270243

271244
if err == nil {
272-
c.session.logger.Debug("Discovered protocol version using host.",
273-
NewLogFieldInt("protocol_version", connCfg.ProtoVersion), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
274-
return connCfg.ProtoVersion, nil
245+
return proto, nil
275246
}
276247

277-
if proto := parseProtocolFromError(err); proto > 0 {
278-
c.session.logger.Debug("Discovered protocol version using host after parsing protocol error.",
279-
NewLogFieldInt("protocol_version", proto), NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()))
280-
return proto, nil
248+
var unsupportedErr *unsupportedProtocolVersionError
249+
if errors.As(err, &unsupportedErr) {
250+
// the host does not support this protocol version, try a lower version
251+
c.session.logger.Debug("Failed to connect to host during protocol negotiation.",
252+
NewLogFieldIP("host_addr", host.ConnectAddress()),
253+
NewLogFieldInt("proto_version", proto),
254+
NewLogFieldError("err", err))
255+
triedVersions = append(triedVersions, connCfg.ProtoVersion)
256+
continue
281257
}
282258

283-
c.session.logger.Debug("Failed to discover protocol version using host.",
284-
NewLogFieldIP("host_addr", host.ConnectAddress()), NewLogFieldString("host_id", host.HostID()), NewLogFieldError("err", err))
259+
c.session.logger.Debug("Error connecting to host during protocol negotiation.",
260+
NewLogFieldIP("host_addr", host.ConnectAddress()),
261+
NewLogFieldError("err", err))
262+
return 0, err
285263
}
286264

287-
return 0, err
265+
return 0, fmt.Errorf("gocql: failed to discover protocol version for host %s, tried versions: %v", host.ConnectAddress(), triedVersions)
288266
}
289267

290268
func (c *controlConn) connect(hosts []*HostInfo, sessionInit bool) error {

0 commit comments

Comments
 (0)