1818package internal
1919
2020import (
21+ "context"
2122 "crypto/tls"
2223 "crypto/x509"
2324 "errors"
@@ -78,7 +79,7 @@ type ConnectionListener interface {
7879type Connection interface {
7980 SendRequest (requestID uint64 , req * pb.BaseCommand , callback func (* pb.BaseCommand , error ))
8081 SendRequestNoWait (req * pb.BaseCommand ) error
81- WriteData (data Buffer )
82+ WriteData (ctx context. Context , data Buffer )
8283 RegisterListener (id uint64 , listener ConnectionListener ) error
8384 UnregisterListener (id uint64 )
8485 AddConsumeHandler (id uint64 , handler ConsumerHandler ) error
@@ -129,6 +130,11 @@ type request struct {
129130 callback func (command * pb.BaseCommand , err error )
130131}
131132
133+ type dataRequest struct {
134+ ctx context.Context
135+ data Buffer
136+ }
137+
132138type connection struct {
133139 started int32
134140 connectionTimeout time.Duration
@@ -157,7 +163,7 @@ type connection struct {
157163 incomingRequestsCh chan * request
158164 closeCh chan struct {}
159165 readyCh chan struct {}
160- writeRequestsCh chan Buffer
166+ writeRequestsCh chan * dataRequest
161167
162168 pendingLock sync.Mutex
163169 pendingReqs map [uint64 ]* request
@@ -209,7 +215,7 @@ func newConnection(opts connectionOptions) *connection {
209215 // partition produces writing on a single connection. In general it's
210216 // good to keep this above the number of partition producers assigned
211217 // to a single connection.
212- writeRequestsCh : make (chan Buffer , 256 ),
218+ writeRequestsCh : make (chan * dataRequest , 256 ),
213219 listeners : make (map [uint64 ]ConnectionListener ),
214220 consumerHandlers : make (map [uint64 ]ConsumerHandler ),
215221 metrics : opts .metrics ,
@@ -421,11 +427,11 @@ func (c *connection) run() {
421427 return // TODO: this never gonna be happen
422428 }
423429 c .internalSendRequest (req )
424- case data := <- c .writeRequestsCh :
425- if data == nil {
430+ case req := <- c .writeRequestsCh :
431+ if req == nil {
426432 return
427433 }
428- c .internalWriteData (data )
434+ c .internalWriteData (req . ctx , req . data )
429435
430436 case <- pingSendTicker .C :
431437 c .sendPing ()
@@ -450,22 +456,26 @@ func (c *connection) runPingCheck(pingCheckTicker *time.Ticker) {
450456 }
451457}
452458
453- func (c * connection ) WriteData (data Buffer ) {
459+ func (c * connection ) WriteData (ctx context. Context , data Buffer ) {
454460 select {
455- case c .writeRequestsCh <- data :
461+ case c .writeRequestsCh <- & dataRequest { ctx : ctx , data : data } :
456462 // Channel is not full
457463 return
458-
464+ case <- ctx .Done ():
465+ c .log .Debug ("Write data context cancelled" )
466+ return
459467 default :
460468 // Channel full, fallback to probe if connection is closed
461469 }
462470
463471 for {
464472 select {
465- case c .writeRequestsCh <- data :
473+ case c .writeRequestsCh <- & dataRequest { ctx : ctx , data : data } :
466474 // Successfully wrote on the channel
467475 return
468-
476+ case <- ctx .Done ():
477+ c .log .Debug ("Write data context cancelled" )
478+ return
469479 case <- time .After (100 * time .Millisecond ):
470480 // The channel is either:
471481 // 1. blocked, in which case we need to wait until we have space
@@ -481,11 +491,17 @@ func (c *connection) WriteData(data Buffer) {
481491
482492}
483493
484- func (c * connection ) internalWriteData (data Buffer ) {
494+ func (c * connection ) internalWriteData (ctx context. Context , data Buffer ) {
485495 c .log .Debug ("Write data: " , data .ReadableBytes ())
486- if _ , err := c .cnx .Write (data .ReadableSlice ()); err != nil {
487- c .log .WithError (err ).Warn ("Failed to write on connection" )
488- c .Close ()
496+
497+ select {
498+ case <- ctx .Done ():
499+ return
500+ default :
501+ if _ , err := c .cnx .Write (data .ReadableSlice ()); err != nil {
502+ c .log .WithError (err ).Warn ("Failed to write on connection" )
503+ c .Close ()
504+ }
489505 }
490506}
491507
@@ -510,7 +526,7 @@ func (c *connection) writeCommand(cmd *pb.BaseCommand) {
510526 }
511527
512528 c .writeBuffer .WrittenBytes (cmdSize )
513- c .internalWriteData (c .writeBuffer )
529+ c .internalWriteData (context . Background (), c .writeBuffer )
514530}
515531
516532func (c * connection ) receivedCommand (cmd * pb.BaseCommand , headersAndPayload Buffer ) {
0 commit comments