@@ -2,11 +2,8 @@ package proxy
22
33import (
44 "errors"
5- "fmt"
65 "github.com/grepplabs/kafka-proxy/config"
76 "github.com/grepplabs/kafka-proxy/proxy/protocol"
8- "io"
9- "strconv"
107 "time"
118)
129
@@ -19,17 +16,17 @@ const (
1916 defaultReadTimeout = 30 * time .Second
2017 minOpenRequests = 16
2118
22- apiKeyUnset = int16 (- 1 ) // not in protocol
23- apiKeySaslAuth = int16 (- 2 ) // not in protocol
2419 apiKeySaslHandshake = int16 (17 )
2520
2621 minRequestApiKey = int16 (0 ) // 0 - Produce
2722 maxRequestApiKey = int16 (100 ) // so far 42 is the last (reserve some for the feature)
2823)
2924
3025var (
31- defaultRequestHandler = & DefaultRequestHandler {}
32- defaultResponseHandler = & DefaultResponseHandler {}
26+ defaultRequestHandler = & DefaultRequestHandler {}
27+ defaultResponseHandler = & DefaultResponseHandler {}
28+ saslAuthV0RequestHandler = & SaslAuthV0RequestHandler {}
29+ saslAuthV0ResponseHandler = & SaslAuthV0ResponseHandler {}
3330)
3431
3532type ProcessorConfig struct {
@@ -116,6 +113,9 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
116113 }
117114 }
118115 if p .localSasl .enabled {
116+ //TODO: when localSasl is enabled SASL is mandadory - we need authorized e.g. in RequestsLoopContext (mutex ?)
117+ //TODO: before SASL only ApiVersions is required
118+ //TODO: SASL can be done only once
119119 if err = p .localSasl .receiveAndSendSASLPlainAuth (src ); err != nil {
120120 return true , err
121121 }
@@ -129,9 +129,7 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
129129 timeout : p .writeTimeout ,
130130 brokerAddress : p .brokerAddress ,
131131 forbiddenApiKeys : p .forbiddenApiKeys ,
132- keyVersionBuf : make ([]byte , 8 ), // Size => int32 + ApiKey => int16 + ApiVersion => int16
133132 buf : make ([]byte , p .requestBufferSize ),
134- lastApiKey : apiKeyUnset ,
135133 }
136134
137135 return ctx .requestsLoop (dst , src )
@@ -145,11 +143,7 @@ type RequestsLoopContext struct {
145143 timeout time.Duration
146144 brokerAddress string
147145 forbiddenApiKeys map [int16 ]struct {}
148-
149- // bufSize int
150- keyVersionBuf []byte // 8 Size => int32 + ApiKey => int16 + ApiVersion => int16
151- buf []byte // bufSize
152- lastApiKey int16
146+ buf []byte // bufSize
153147}
154148
155149func (ctx * RequestsLoopContext ) nextHandlers (nextRequestHandler RequestHandler , nextResponseHandler ResponseHandler ) error {
@@ -193,93 +187,14 @@ func (r *RequestsLoopContext) requestsLoop(dst DeadlineWriter, src DeadlineReade
193187 }
194188}
195189
196- type DefaultRequestHandler struct {
197- }
198-
199- func (handler * DefaultRequestHandler ) handleRequest (dst DeadlineWriter , src DeadlineReader , ctx * RequestsLoopContext ) (readErr bool , err error ) {
200- if ctx .lastApiKey == apiKeySaslHandshake {
201- ctx .lastApiKey = apiKeySaslAuth
202- if readErr , err = copySaslAuthRequest (dst , src , ctx .timeout , ctx .buf ); err != nil {
203- return readErr , err
204- }
205- if err = ctx .nextHandlers (defaultRequestHandler , defaultResponseHandler ); err != nil {
206- return false , err
207- }
208- return false , nil
209- }
210- if len (ctx .keyVersionBuf ) != 8 {
211- return false , errors .New ("key version buf should have size 8" )
212- }
213- // logrus.Println("Await Kafka request")
214-
215- // waiting for first bytes or EOF - reset deadlines
216- src .SetReadDeadline (time.Time {})
217- dst .SetWriteDeadline (time.Time {})
218-
219- if _ , err = io .ReadFull (src , ctx .keyVersionBuf ); err != nil {
220- return true , err
221- }
222-
223- requestKeyVersion := & protocol.RequestKeyVersion {}
224- if err = protocol .Decode (ctx .keyVersionBuf , requestKeyVersion ); err != nil {
225- return true , err
226- }
227- //logrus.Printf("Kafka request length %v, key %v, version %v", requestKeyVersion.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion)
228-
229- if requestKeyVersion .ApiKey < minRequestApiKey || requestKeyVersion .ApiKey > maxRequestApiKey {
230- return true , fmt .Errorf ("api key %d is invalid" , requestKeyVersion .ApiKey )
231- }
232-
233- proxyRequestsTotal .WithLabelValues (ctx .brokerAddress , strconv .Itoa (int (requestKeyVersion .ApiKey )), strconv .Itoa (int (requestKeyVersion .ApiVersion ))).Inc ()
234- proxyRequestsBytes .WithLabelValues (ctx .brokerAddress ).Add (float64 (requestKeyVersion .Length + 4 ))
235-
236- if _ , ok := ctx .forbiddenApiKeys [requestKeyVersion .ApiKey ]; ok {
237- return true , fmt .Errorf ("api key %d is forbidden" , requestKeyVersion .ApiKey )
238- }
239-
240- // send inFlightRequest to channel before myCopyN to prevent race condition in proxyResponses
241- if err = sendRequestKeyVersion (ctx .openRequestsChannel , openRequestSendTimeout , requestKeyVersion ); err != nil {
242- return true , err
243- }
244-
245- requestDeadline := time .Now ().Add (ctx .timeout )
246- err = dst .SetWriteDeadline (requestDeadline )
247- if err != nil {
248- return false , err
249- }
250- err = src .SetReadDeadline (requestDeadline )
251- if err != nil {
252- return true , err
253- }
254-
255- // write - send to broker
256- if _ , err = dst .Write (ctx .keyVersionBuf ); err != nil {
257- return false , err
258- }
259- // 4 bytes were written as keyVersionBuf (ApiKey, ApiVersion)
260- if readErr , err = myCopyN (dst , src , int64 (requestKeyVersion .Length - 4 ), ctx .buf ); err != nil {
261- return readErr , err
262- }
263-
264- ctx .lastApiKey = requestKeyVersion .ApiKey
265-
266- if err = ctx .nextHandlers (defaultRequestHandler , defaultResponseHandler ); err != nil {
267- return false , err
268- }
269- return false , nil
270- }
271-
272190func (p * processor ) ResponsesLoop (dst DeadlineWriter , src DeadlineReader ) (readErr bool , err error ) {
273191 ctx := & ResponsesLoopContext {
274192 openRequestsChannel : p .openRequestsChannel ,
275193 nextResponseHandlerChannel : p .nextResponseHandlerChannel ,
276194 netAddressMappingFunc : p .netAddressMappingFunc ,
277195 timeout : p .readTimeout ,
278196 brokerAddress : p .brokerAddress ,
279-
280- responseHeaderBuf : make ([]byte , 8 ), // Size => int32, CorrelationId => int32
281- buf : make ([]byte , p .responseBufferSize ),
282- lastApiKey : apiKeyUnset ,
197+ buf : make ([]byte , p .responseBufferSize ),
283198 }
284199 return ctx .responsesLoop (dst , src )
285200}
@@ -290,10 +205,7 @@ type ResponsesLoopContext struct {
290205 netAddressMappingFunc config.NetAddressMappingFunc
291206 timeout time.Duration
292207 brokerAddress string
293-
294- responseHeaderBuf []byte // 8 - Size => int32, CorrelationId => int32
295- buf []byte // bufSize
296- lastApiKey int16
208+ buf []byte // bufSize
297209}
298210
299211type ResponseHandler interface {
@@ -309,128 +221,3 @@ func (r *ResponsesLoopContext) responsesLoop(dst DeadlineWriter, src DeadlineRea
309221 }
310222 }
311223}
312-
313- type DefaultResponseHandler struct {
314- }
315-
316- func (handler * DefaultResponseHandler ) handleResponse (dst DeadlineWriter , src DeadlineReader , ctx * ResponsesLoopContext ) (readErr bool , err error ) {
317- if ctx .lastApiKey == apiKeySaslHandshake {
318- ctx .lastApiKey = apiKeySaslAuth
319- if readErr , err = copySaslAuthResponse (dst , src , ctx .timeout ); err != nil {
320- return readErr , err
321- }
322- return false , nil // nextResponse
323- }
324- //logrus.Println("Await Kafka response")
325-
326- // waiting for first bytes or EOF - reset deadlines
327- src .SetReadDeadline (time.Time {})
328- dst .SetWriteDeadline (time.Time {})
329-
330- if len (ctx .responseHeaderBuf ) != 8 {
331- return false , errors .New ("response header buf should have size 8" )
332- }
333-
334- if _ , err = io .ReadFull (src , ctx .responseHeaderBuf ); err != nil {
335- return true , err
336- }
337-
338- var responseHeader protocol.ResponseHeader
339- if err = protocol .Decode (ctx .responseHeaderBuf , & responseHeader ); err != nil {
340- return true , err
341- }
342-
343- // Read the inFlightRequests channel after header is read. Otherwise the channel would block and socket EOF from remote would not be received.
344- requestKeyVersion , err := receiveRequestKeyVersion (ctx .openRequestsChannel , openRequestReceiveTimeout )
345- if err != nil {
346- return true , err
347- }
348- proxyResponsesBytes .WithLabelValues (ctx .brokerAddress ).Add (float64 (responseHeader .Length + 4 ))
349- //logrus.Printf("Kafka response lenght %v for key %v, version %v", responseHeader.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion)
350-
351- responseDeadline := time .Now ().Add (ctx .timeout )
352- err = dst .SetWriteDeadline (responseDeadline )
353- if err != nil {
354- return false , err
355- }
356- err = src .SetReadDeadline (responseDeadline )
357- if err != nil {
358- return true , err
359- }
360-
361- responseModifier , err := protocol .GetResponseModifier (requestKeyVersion .ApiKey , requestKeyVersion .ApiVersion , ctx .netAddressMappingFunc )
362- if err != nil {
363- return true , err
364- }
365- if responseModifier != nil {
366- if int32 (responseHeader .Length ) > protocol .MaxResponseSize {
367- return true , protocol.PacketDecodingError {Info : fmt .Sprintf ("message of length %d too large" , responseHeader .Length )}
368- }
369- resp := make ([]byte , int (responseHeader .Length - 4 ))
370- if _ , err = io .ReadFull (src , resp ); err != nil {
371- return true , err
372- }
373- newResponseBuf , err := responseModifier .Apply (resp )
374- if err != nil {
375- return true , err
376- }
377- // add 4 bytes (CorrelationId) to the length
378- newHeaderBuf , err := protocol .Encode (& protocol.ResponseHeader {Length : int32 (len (newResponseBuf ) + 4 ), CorrelationID : responseHeader .CorrelationID })
379- if err != nil {
380- return true , err
381- }
382- if _ , err := dst .Write (newHeaderBuf ); err != nil {
383- return false , err
384- }
385- if _ , err := dst .Write (newResponseBuf ); err != nil {
386- return false , err
387- }
388- } else {
389- // write - send to local
390- if _ , err := dst .Write (ctx .responseHeaderBuf ); err != nil {
391- return false , err
392- }
393- // 4 bytes were written as responseHeaderBuf (CorrelationId)
394- if readErr , err = myCopyN (dst , src , int64 (responseHeader .Length - 4 ), ctx .buf ); err != nil {
395- return readErr , err
396- }
397- }
398- ctx .lastApiKey = requestKeyVersion .ApiKey
399-
400- return false , nil // continue nextResponse
401- }
402-
403- func sendRequestKeyVersion (openRequestsChannel chan <- protocol.RequestKeyVersion , timeout time.Duration , request * protocol.RequestKeyVersion ) error {
404- select {
405- case openRequestsChannel <- * request :
406- default :
407- // timer.Stop() will be invoked only after sendRequestKeyVersion is finished (not after select default) !
408- timer := time .NewTimer (timeout )
409- defer timer .Stop ()
410-
411- select {
412- case openRequestsChannel <- * request :
413- case <- timer .C :
414- return errors .New ("open requests buffer is full" )
415- }
416- }
417- return nil
418- }
419-
420- func receiveRequestKeyVersion (openRequestsChannel <- chan protocol.RequestKeyVersion , timeout time.Duration ) (* protocol.RequestKeyVersion , error ) {
421- var request protocol.RequestKeyVersion
422- select {
423- case request = <- openRequestsChannel :
424- default :
425- // timer.Stop() will be invoked only after receiveRequestKeyVersion is finished (not after select default) !
426- timer := time .NewTimer (timeout )
427- defer timer .Stop ()
428-
429- select {
430- case request = <- openRequestsChannel :
431- case <- timer .C :
432- return nil , errors .New ("open request is missing" )
433- }
434- }
435- return & request , nil
436- }
0 commit comments