Skip to content

Commit 828dbe8

Browse files
committed
Add SaslAuthV0 handlers
1 parent 06a7da1 commit 828dbe8

File tree

3 files changed

+228
-223
lines changed

3 files changed

+228
-223
lines changed

proxy/processor.go

Lines changed: 10 additions & 223 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,8 @@ package proxy
22

33
import (
44
"errors"
5-
"fmt"
65
"github.com/grepplabs/kafka-proxy/config"
76
"github.com/grepplabs/kafka-proxy/proxy/protocol"
8-
"io"
9-
"strconv"
107
"time"
118
)
129

@@ -19,17 +16,17 @@ const (
1916
defaultReadTimeout = 30 * time.Second
2017
minOpenRequests = 16
2118

22-
apiKeyUnset = int16(-1) // not in protocol
23-
apiKeySaslAuth = int16(-2) // not in protocol
2419
apiKeySaslHandshake = int16(17)
2520

2621
minRequestApiKey = int16(0) // 0 - Produce
2722
maxRequestApiKey = int16(100) // so far 42 is the last (reserve some for the future)
2823
)
2924

3025
var (
31-
defaultRequestHandler = &DefaultRequestHandler{}
32-
defaultResponseHandler = &DefaultResponseHandler{}
26+
defaultRequestHandler = &DefaultRequestHandler{}
27+
defaultResponseHandler = &DefaultResponseHandler{}
28+
saslAuthV0RequestHandler = &SaslAuthV0RequestHandler{}
29+
saslAuthV0ResponseHandler = &SaslAuthV0ResponseHandler{}
3330
)
3431

3532
type ProcessorConfig struct {
@@ -116,6 +113,9 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
116113
}
117114
}
118115
if p.localSasl.enabled {
116+
//TODO: when localSasl is enabled SASL is mandatory - we need authorization e.g. in RequestsLoopContext (mutex ?)
117+
//TODO: before SASL only ApiVersions is required
118+
//TODO: SASL can be done only once
119119
if err = p.localSasl.receiveAndSendSASLPlainAuth(src); err != nil {
120120
return true, err
121121
}
@@ -129,9 +129,7 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
129129
timeout: p.writeTimeout,
130130
brokerAddress: p.brokerAddress,
131131
forbiddenApiKeys: p.forbiddenApiKeys,
132-
keyVersionBuf: make([]byte, 8), // Size => int32 + ApiKey => int16 + ApiVersion => int16
133132
buf: make([]byte, p.requestBufferSize),
134-
lastApiKey: apiKeyUnset,
135133
}
136134

137135
return ctx.requestsLoop(dst, src)
@@ -145,11 +143,7 @@ type RequestsLoopContext struct {
145143
timeout time.Duration
146144
brokerAddress string
147145
forbiddenApiKeys map[int16]struct{}
148-
149-
// bufSize int
150-
keyVersionBuf []byte // 8 Size => int32 + ApiKey => int16 + ApiVersion => int16
151-
buf []byte // bufSize
152-
lastApiKey int16
146+
buf []byte // bufSize
153147
}
154148

155149
func (ctx *RequestsLoopContext) nextHandlers(nextRequestHandler RequestHandler, nextResponseHandler ResponseHandler) error {
@@ -193,93 +187,14 @@ func (r *RequestsLoopContext) requestsLoop(dst DeadlineWriter, src DeadlineReade
193187
}
194188
}
195189

196-
type DefaultRequestHandler struct {
197-
}
198-
199-
func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src DeadlineReader, ctx *RequestsLoopContext) (readErr bool, err error) {
200-
if ctx.lastApiKey == apiKeySaslHandshake {
201-
ctx.lastApiKey = apiKeySaslAuth
202-
if readErr, err = copySaslAuthRequest(dst, src, ctx.timeout, ctx.buf); err != nil {
203-
return readErr, err
204-
}
205-
if err = ctx.nextHandlers(defaultRequestHandler, defaultResponseHandler); err != nil {
206-
return false, err
207-
}
208-
return false, nil
209-
}
210-
if len(ctx.keyVersionBuf) != 8 {
211-
return false, errors.New("key version buf should have size 8")
212-
}
213-
// logrus.Println("Await Kafka request")
214-
215-
// waiting for first bytes or EOF - reset deadlines
216-
src.SetReadDeadline(time.Time{})
217-
dst.SetWriteDeadline(time.Time{})
218-
219-
if _, err = io.ReadFull(src, ctx.keyVersionBuf); err != nil {
220-
return true, err
221-
}
222-
223-
requestKeyVersion := &protocol.RequestKeyVersion{}
224-
if err = protocol.Decode(ctx.keyVersionBuf, requestKeyVersion); err != nil {
225-
return true, err
226-
}
227-
//logrus.Printf("Kafka request length %v, key %v, version %v", requestKeyVersion.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion)
228-
229-
if requestKeyVersion.ApiKey < minRequestApiKey || requestKeyVersion.ApiKey > maxRequestApiKey {
230-
return true, fmt.Errorf("api key %d is invalid", requestKeyVersion.ApiKey)
231-
}
232-
233-
proxyRequestsTotal.WithLabelValues(ctx.brokerAddress, strconv.Itoa(int(requestKeyVersion.ApiKey)), strconv.Itoa(int(requestKeyVersion.ApiVersion))).Inc()
234-
proxyRequestsBytes.WithLabelValues(ctx.brokerAddress).Add(float64(requestKeyVersion.Length + 4))
235-
236-
if _, ok := ctx.forbiddenApiKeys[requestKeyVersion.ApiKey]; ok {
237-
return true, fmt.Errorf("api key %d is forbidden", requestKeyVersion.ApiKey)
238-
}
239-
240-
// send inFlightRequest to channel before myCopyN to prevent race condition in proxyResponses
241-
if err = sendRequestKeyVersion(ctx.openRequestsChannel, openRequestSendTimeout, requestKeyVersion); err != nil {
242-
return true, err
243-
}
244-
245-
requestDeadline := time.Now().Add(ctx.timeout)
246-
err = dst.SetWriteDeadline(requestDeadline)
247-
if err != nil {
248-
return false, err
249-
}
250-
err = src.SetReadDeadline(requestDeadline)
251-
if err != nil {
252-
return true, err
253-
}
254-
255-
// write - send to broker
256-
if _, err = dst.Write(ctx.keyVersionBuf); err != nil {
257-
return false, err
258-
}
259-
// 4 bytes were written as keyVersionBuf (ApiKey, ApiVersion)
260-
if readErr, err = myCopyN(dst, src, int64(requestKeyVersion.Length-4), ctx.buf); err != nil {
261-
return readErr, err
262-
}
263-
264-
ctx.lastApiKey = requestKeyVersion.ApiKey
265-
266-
if err = ctx.nextHandlers(defaultRequestHandler, defaultResponseHandler); err != nil {
267-
return false, err
268-
}
269-
return false, nil
270-
}
271-
272190
func (p *processor) ResponsesLoop(dst DeadlineWriter, src DeadlineReader) (readErr bool, err error) {
273191
ctx := &ResponsesLoopContext{
274192
openRequestsChannel: p.openRequestsChannel,
275193
nextResponseHandlerChannel: p.nextResponseHandlerChannel,
276194
netAddressMappingFunc: p.netAddressMappingFunc,
277195
timeout: p.readTimeout,
278196
brokerAddress: p.brokerAddress,
279-
280-
responseHeaderBuf: make([]byte, 8), // Size => int32, CorrelationId => int32
281-
buf: make([]byte, p.responseBufferSize),
282-
lastApiKey: apiKeyUnset,
197+
buf: make([]byte, p.responseBufferSize),
283198
}
284199
return ctx.responsesLoop(dst, src)
285200
}
@@ -290,10 +205,7 @@ type ResponsesLoopContext struct {
290205
netAddressMappingFunc config.NetAddressMappingFunc
291206
timeout time.Duration
292207
brokerAddress string
293-
294-
responseHeaderBuf []byte // 8 - Size => int32, CorrelationId => int32
295-
buf []byte // bufSize
296-
lastApiKey int16
208+
buf []byte // bufSize
297209
}
298210

299211
type ResponseHandler interface {
@@ -309,128 +221,3 @@ func (r *ResponsesLoopContext) responsesLoop(dst DeadlineWriter, src DeadlineRea
309221
}
310222
}
311223
}
312-
313-
type DefaultResponseHandler struct {
314-
}
315-
316-
func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src DeadlineReader, ctx *ResponsesLoopContext) (readErr bool, err error) {
317-
if ctx.lastApiKey == apiKeySaslHandshake {
318-
ctx.lastApiKey = apiKeySaslAuth
319-
if readErr, err = copySaslAuthResponse(dst, src, ctx.timeout); err != nil {
320-
return readErr, err
321-
}
322-
return false, nil // nextResponse
323-
}
324-
//logrus.Println("Await Kafka response")
325-
326-
// waiting for first bytes or EOF - reset deadlines
327-
src.SetReadDeadline(time.Time{})
328-
dst.SetWriteDeadline(time.Time{})
329-
330-
if len(ctx.responseHeaderBuf) != 8 {
331-
return false, errors.New("response header buf should have size 8")
332-
}
333-
334-
if _, err = io.ReadFull(src, ctx.responseHeaderBuf); err != nil {
335-
return true, err
336-
}
337-
338-
var responseHeader protocol.ResponseHeader
339-
if err = protocol.Decode(ctx.responseHeaderBuf, &responseHeader); err != nil {
340-
return true, err
341-
}
342-
343-
// Read the inFlightRequests channel after header is read. Otherwise the channel would block and socket EOF from remote would not be received.
344-
requestKeyVersion, err := receiveRequestKeyVersion(ctx.openRequestsChannel, openRequestReceiveTimeout)
345-
if err != nil {
346-
return true, err
347-
}
348-
proxyResponsesBytes.WithLabelValues(ctx.brokerAddress).Add(float64(responseHeader.Length + 4))
349-
//logrus.Printf("Kafka response length %v for key %v, version %v", responseHeader.Length, requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion)
350-
351-
responseDeadline := time.Now().Add(ctx.timeout)
352-
err = dst.SetWriteDeadline(responseDeadline)
353-
if err != nil {
354-
return false, err
355-
}
356-
err = src.SetReadDeadline(responseDeadline)
357-
if err != nil {
358-
return true, err
359-
}
360-
361-
responseModifier, err := protocol.GetResponseModifier(requestKeyVersion.ApiKey, requestKeyVersion.ApiVersion, ctx.netAddressMappingFunc)
362-
if err != nil {
363-
return true, err
364-
}
365-
if responseModifier != nil {
366-
if int32(responseHeader.Length) > protocol.MaxResponseSize {
367-
return true, protocol.PacketDecodingError{Info: fmt.Sprintf("message of length %d too large", responseHeader.Length)}
368-
}
369-
resp := make([]byte, int(responseHeader.Length-4))
370-
if _, err = io.ReadFull(src, resp); err != nil {
371-
return true, err
372-
}
373-
newResponseBuf, err := responseModifier.Apply(resp)
374-
if err != nil {
375-
return true, err
376-
}
377-
// add 4 bytes (CorrelationId) to the length
378-
newHeaderBuf, err := protocol.Encode(&protocol.ResponseHeader{Length: int32(len(newResponseBuf) + 4), CorrelationID: responseHeader.CorrelationID})
379-
if err != nil {
380-
return true, err
381-
}
382-
if _, err := dst.Write(newHeaderBuf); err != nil {
383-
return false, err
384-
}
385-
if _, err := dst.Write(newResponseBuf); err != nil {
386-
return false, err
387-
}
388-
} else {
389-
// write - send to local
390-
if _, err := dst.Write(ctx.responseHeaderBuf); err != nil {
391-
return false, err
392-
}
393-
// 4 bytes were written as responseHeaderBuf (CorrelationId)
394-
if readErr, err = myCopyN(dst, src, int64(responseHeader.Length-4), ctx.buf); err != nil {
395-
return readErr, err
396-
}
397-
}
398-
ctx.lastApiKey = requestKeyVersion.ApiKey
399-
400-
return false, nil // continue nextResponse
401-
}
402-
403-
func sendRequestKeyVersion(openRequestsChannel chan<- protocol.RequestKeyVersion, timeout time.Duration, request *protocol.RequestKeyVersion) error {
404-
select {
405-
case openRequestsChannel <- *request:
406-
default:
407-
// timer.Stop() will be invoked only after sendRequestKeyVersion is finished (not after select default) !
408-
timer := time.NewTimer(timeout)
409-
defer timer.Stop()
410-
411-
select {
412-
case openRequestsChannel <- *request:
413-
case <-timer.C:
414-
return errors.New("open requests buffer is full")
415-
}
416-
}
417-
return nil
418-
}
419-
420-
func receiveRequestKeyVersion(openRequestsChannel <-chan protocol.RequestKeyVersion, timeout time.Duration) (*protocol.RequestKeyVersion, error) {
421-
var request protocol.RequestKeyVersion
422-
select {
423-
case request = <-openRequestsChannel:
424-
default:
425-
// timer.Stop() will be invoked only after receiveRequestKeyVersion is finished (not after select default) !
426-
timer := time.NewTimer(timeout)
427-
defer timer.Stop()
428-
429-
select {
430-
case request = <-openRequestsChannel:
431-
case <-timer.C:
432-
return nil, errors.New("open request is missing")
433-
}
434-
}
435-
return &request, nil
436-
}

0 commit comments

Comments
 (0)