11package proxy
22
33import (
4+ "bytes"
45 "errors"
56 "fmt"
67 "github.com/grepplabs/kafka-proxy/proxy/protocol"
@@ -20,8 +21,12 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
2021 // logrus.Println("Await Kafka request")
2122
2223 // waiting for first bytes or EOF - reset deadlines
23- src .SetReadDeadline (time.Time {})
24- dst .SetWriteDeadline (time.Time {})
24+ if err = src .SetReadDeadline (time.Time {}); err != nil {
25+ return true , err
26+ }
27+ if err = dst .SetWriteDeadline (time.Time {}); err != nil {
28+ return true , err
29+ }
2530
2631 keyVersionBuf := make ([]byte , 8 ) // Size => int32 + ApiKey => int16 + ApiVersion => int16
2732
@@ -67,8 +72,9 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
6772 return true , fmt .Errorf ("only saslHandshake version 0 and 1 are supported, got version %d" , requestKeyVersion .ApiVersion )
6873 }
6974 ctx .localSaslDone = true
70- src .SetDeadline (time.Time {})
71-
75+ if err = src .SetDeadline (time.Time {}); err != nil {
76+ return false , err
77+ }
7278 // defaultRequestHandler was consumed but due to local handling enqueued defaultResponseHandler will not be.
7379 return false , ctx .putNextRequestHandler (defaultRequestHandler )
7480 case apiKeyApiApiVersions :
@@ -79,11 +85,18 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
7985 }
8086 }
8187
82- // send inFlightRequest to channel before myCopyN to prevent race condition in proxyResponses
83- if err = sendRequestKeyVersion ( ctx . openRequestsChannel , openRequestSendTimeout , requestKeyVersion ); err != nil {
88+ mustReply , readBytes , err := handler . mustReply ( requestKeyVersion , src , ctx )
89+ if err != nil {
8490 return true , err
8591 }
8692
93+ // send inFlightRequest to channel before myCopyN to prevent race condition in proxyResponses
94+ if mustReply {
95+ if err = sendRequestKeyVersion (ctx .openRequestsChannel , openRequestSendTimeout , requestKeyVersion ); err != nil {
96+ return true , err
97+ }
98+ }
99+
87100 requestDeadline := time .Now ().Add (ctx .timeout )
88101 err = dst .SetWriteDeadline (requestDeadline )
89102 if err != nil {
@@ -98,24 +111,82 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
98111 if _ , err = dst .Write (keyVersionBuf ); err != nil {
99112 return false , err
100113 }
114+ // write - send to broker
115+ if len (readBytes ) > 0 {
116+ if _ , err = dst .Write (readBytes ); err != nil {
117+ return false , err
118+ }
119+ }
101120 // 4 bytes were written as keyVersionBuf (ApiKey, ApiVersion)
102- if readErr , err = myCopyN (dst , src , int64 (requestKeyVersion .Length - 4 ), ctx .buf ); err != nil {
121+ if readErr , err = myCopyN (dst , src , int64 (requestKeyVersion .Length - int32 ( 4 + len ( readBytes )) ), ctx .buf ); err != nil {
103122 return readErr , err
104123 }
105124 if requestKeyVersion .ApiKey == apiKeySaslHandshake {
106125 if requestKeyVersion .ApiVersion == 0 {
107126 return false , ctx .putNextHandlers (saslAuthV0RequestHandler , saslAuthV0ResponseHandler )
108127 }
109128 }
110- return false , ctx .putNextHandlers (defaultRequestHandler , defaultResponseHandler )
129+ if mustReply {
130+ return false , ctx .putNextHandlers (defaultRequestHandler , defaultResponseHandler )
131+ } else {
132+ return false , ctx .putNextRequestHandler (defaultRequestHandler )
133+ }
134+ }
135+
136+ func (handler * DefaultRequestHandler ) mustReply (requestKeyVersion * protocol.RequestKeyVersion , src io.Reader , ctx * RequestsLoopContext ) (bool , []byte , error ) {
137+ if requestKeyVersion .ApiKey == apiKeyProduce {
138+ if ctx .producerAcks0Disabled {
139+ return true , nil , nil
140+ }
141+ // header version for produce [0..8] is 1 (request_api_key,request_api_version,correlation_id (INT32),client_id, NULLABLE_STRING )
142+ acksReader := protocol.RequestAcksReader {}
143+
144+ var (
145+ acks int16
146+ err error
147+ )
148+ var bufferRead bytes.Buffer
149+ reader := io .TeeReader (src , & bufferRead )
150+ switch requestKeyVersion .ApiVersion {
151+ case 0 , 1 , 2 :
152+ // CorrelationID + ClientID
153+ if err = acksReader .ReadAndDiscardHeaderV1Part (reader ); err != nil {
154+ return false , nil , err
155+ }
156+ // acks (INT16)
157+ acks , err = acksReader .ReadAndDiscardProduceAcks (reader )
158+ if err != nil {
159+ return false , nil , err
160+ }
161+
162+ case 3 , 4 , 5 , 6 , 7 , 8 :
163+ // CorrelationID + ClientID
164+ if err = acksReader .ReadAndDiscardHeaderV1Part (reader ); err != nil {
165+ return false , nil , err
166+ }
167+ // transactional_id (NULLABLE_STRING),acks (INT16)
168+ acks , err = acksReader .ReadAndDiscardProduceTxnAcks (reader )
169+ if err != nil {
170+ return false , nil , err
171+ }
172+ default :
173+ return false , nil , fmt .Errorf ("produce version %d is not supported" , requestKeyVersion .ApiVersion )
174+ }
175+ return acks != 0 , bufferRead .Bytes (), nil
176+ }
177+ return true , nil , nil
111178}
112179
113180func (handler * DefaultResponseHandler ) handleResponse (dst DeadlineWriter , src DeadlineReader , ctx * ResponsesLoopContext ) (readErr bool , err error ) {
114181 //logrus.Println("Await Kafka response")
115182
116183 // waiting for first bytes or EOF - reset deadlines
117- src .SetReadDeadline (time.Time {})
118- dst .SetWriteDeadline (time.Time {})
184+ if err = src .SetReadDeadline (time.Time {}); err != nil {
185+ return true , err
186+ }
187+ if err = dst .SetWriteDeadline (time.Time {}); err != nil {
188+ return true , err
189+ }
119190
120191 responseHeaderBuf := make ([]byte , 8 ) // Size => int32, CorrelationId => int32
121192 if _ , err = io .ReadFull (src , responseHeaderBuf ); err != nil {
0 commit comments