Skip to content

Commit 2e0561e

Browse files
committed
Local SaslAuthV0
1 parent 828dbe8 commit 2e0561e

File tree

4 files changed

+96
-44
lines changed

4 files changed

+96
-44
lines changed

proxy/processor.go

Lines changed: 59 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ const (
1616
defaultReadTimeout = 30 * time.Second
1717
minOpenRequests = 16
1818

19-
apiKeySaslHandshake = int16(17)
19+
apiKeySaslHandshake = int16(17)
20+
apiKeyApiApiVersions = int16(18)
2021

2122
minRequestApiKey = int16(0) // 0 - Produce
2223
maxRequestApiKey = int16(100) // so far 42 is the last (reserve some for the feature)
@@ -81,8 +82,7 @@ func newProcessor(cfg ProcessorConfig, brokerAddress string) *processor {
8182
if readTimeout <= 0 {
8283
readTimeout = defaultReadTimeout
8384
}
84-
// in most use cases there will be only one entry in the nextRequestHandlerChannel channel
85-
nextRequestHandlerChannel := make(chan RequestHandler, minOpenRequests)
85+
nextRequestHandlerChannel := make(chan RequestHandler, 1)
8686
nextResponseHandlerChannel := make(chan ResponseHandler, maxOpenRequests+1)
8787

8888
// initial handlers -> standard kafka message arrives always as first
@@ -112,14 +112,6 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
112112
return true, err
113113
}
114114
}
115-
if p.localSasl.enabled {
116-
//TODO: when localSasl is enabled SASL is mandadory - we need authorized e.g. in RequestsLoopContext (mutex ?)
117-
//TODO: before SASL only ApiVersions is required
118-
//TODO: SASL can be done only once
119-
if err = p.localSasl.receiveAndSendSASLPlainAuth(src); err != nil {
120-
return true, err
121-
}
122-
}
123115
src.SetDeadline(time.Time{})
124116

125117
ctx := &RequestsLoopContext{
@@ -130,6 +122,8 @@ func (p *processor) RequestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (
130122
brokerAddress: p.brokerAddress,
131123
forbiddenApiKeys: p.forbiddenApiKeys,
132124
buf: make([]byte, p.requestBufferSize),
125+
localSasl: p.localSasl,
126+
localSaslDone: false, // sequential processing - mutex is required
133127
}
134128

135129
return ctx.requestsLoop(dst, src)
@@ -144,9 +138,23 @@ type RequestsLoopContext struct {
144138
brokerAddress string
145139
forbiddenApiKeys map[int16]struct{}
146140
buf []byte // bufSize
141+
142+
localSasl *LocalSasl
143+
localSaslDone bool
144+
}
145+
146+
// used by local authentication
147+
func (ctx *RequestsLoopContext) putNextRequestHandler(nextRequestHandler RequestHandler) error {
148+
149+
select {
150+
case ctx.nextRequestHandlerChannel <- nextRequestHandler:
151+
default:
152+
return errors.New("next request handler channel is full")
153+
}
154+
return nil
147155
}
148156

149-
func (ctx *RequestsLoopContext) nextHandlers(nextRequestHandler RequestHandler, nextResponseHandler ResponseHandler) error {
157+
func (ctx *RequestsLoopContext) putNextHandlers(nextRequestHandler RequestHandler, nextResponseHandler ResponseHandler) error {
150158

151159
select {
152160
case ctx.nextRequestHandlerChannel <- nextRequestHandler:
@@ -157,8 +165,7 @@ func (ctx *RequestsLoopContext) nextHandlers(nextRequestHandler RequestHandler,
157165
select {
158166
case ctx.nextResponseHandlerChannel <- nextResponseHandler:
159167
default:
160-
// timer.Stop() will be invoked only after nextHandlers is finished
161-
timer := time.NewTimer(openRequestReceiveTimeout) // reuse openRequestReceiveTimeout
168+
timer := time.NewTimer(openRequestSendTimeout)
162169
defer timer.Stop()
163170

164171
select {
@@ -170,19 +177,27 @@ func (ctx *RequestsLoopContext) nextHandlers(nextRequestHandler RequestHandler,
170177
return nil
171178
}
172179

180+
func (r *RequestsLoopContext) getNextRequestHandler() (RequestHandler, error) {
181+
select {
182+
case nextRequestHandler := <-r.nextRequestHandlerChannel:
183+
return nextRequestHandler, nil
184+
default:
185+
return nil, errors.New("next request handler is missing")
186+
}
187+
}
188+
173189
type RequestHandler interface {
174-
handleRequest(dst DeadlineWriter, src DeadlineReader, ctx *RequestsLoopContext) (readErr bool, err error)
190+
handleRequest(dst DeadlineWriter, src DeadlineReaderWriter, ctx *RequestsLoopContext) (readErr bool, err error)
175191
}
176192

177-
func (r *RequestsLoopContext) requestsLoop(dst DeadlineWriter, src DeadlineReader) (readErr bool, err error) {
193+
func (r *RequestsLoopContext) requestsLoop(dst DeadlineWriter, src DeadlineReaderWriter) (readErr bool, err error) {
194+
var nextRequestHandler RequestHandler
178195
for {
179-
select {
180-
case nextRequestHandler := <-r.nextRequestHandlerChannel:
181-
if readErr, err = nextRequestHandler.handleRequest(dst, src, r); err != nil {
182-
return readErr, err
183-
}
184-
default:
185-
return false, errors.New("internal error: next request handler expected")
196+
if nextRequestHandler, err = r.getNextRequestHandler(); err != nil {
197+
return false, nil
198+
}
199+
if readErr, err = nextRequestHandler.handleRequest(dst, src, r); err != nil {
200+
return readErr, err
186201
}
187202
}
188203
}
@@ -213,11 +228,30 @@ type ResponseHandler interface {
213228
}
214229

215230
func (r *ResponsesLoopContext) responsesLoop(dst DeadlineWriter, src DeadlineReader) (readErr bool, err error) {
231+
var nextResponseHandler ResponseHandler
216232
for {
217-
//TODO: timeout noting was received
218-
nextResponseHandler := <-r.nextResponseHandlerChannel
233+
if nextResponseHandler, err = r.getNextResponseHandler(); err != nil {
234+
return false, err
235+
}
219236
if readErr, err = nextResponseHandler.handleResponse(dst, src, r); err != nil {
220237
return readErr, err
221238
}
222239
}
223240
}
241+
242+
func (r *ResponsesLoopContext) getNextResponseHandler() (ResponseHandler, error) {
243+
select {
244+
case handler := <-r.nextResponseHandlerChannel:
245+
return handler, nil
246+
default:
247+
timer := time.NewTimer(openRequestReceiveTimeout)
248+
defer timer.Stop()
249+
250+
select {
251+
case handler := <-r.nextResponseHandlerChannel:
252+
return handler, nil
253+
case <-timer.C:
254+
return nil, errors.New("next response handler is missing")
255+
}
256+
}
257+
}

proxy/processor_default.go

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ type DefaultRequestHandler struct {
1515
type DefaultResponseHandler struct {
1616
}
1717

18-
func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src DeadlineReader, ctx *RequestsLoopContext) (readErr bool, err error) {
18+
func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src DeadlineReaderWriter, ctx *RequestsLoopContext) (readErr bool, err error) {
1919
// logrus.Println("Await Kafka request")
2020

2121
// waiting for first bytes or EOF - reset deadlines
@@ -45,6 +45,31 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
4545
return true, fmt.Errorf("api key %d is forbidden", requestKeyVersion.ApiKey)
4646
}
4747

48+
if ctx.localSasl.enabled {
49+
if ctx.localSaslDone {
50+
if requestKeyVersion.ApiKey == apiKeySaslHandshake {
51+
return false, errors.New("SASL Auth was already done")
52+
}
53+
} else {
54+
switch requestKeyVersion.ApiKey {
55+
case apiKeySaslHandshake:
56+
//TODO: this is only V0 version
57+
if err = ctx.localSasl.receiveAndSendSASLPlainAuth(src, keyVersionBuf); err != nil {
58+
return true, err
59+
}
60+
ctx.localSaslDone = true
61+
src.SetDeadline(time.Time{})
62+
63+
// defaultRequestHandler was consumed but due to local handling enqueued defaultResponseHandler will not be.
64+
return false, ctx.putNextRequestHandler(defaultRequestHandler)
65+
case apiKeyApiApiVersions:
66+
// continue processing
67+
default:
68+
return false, errors.New("SASL Auth is required. Only SaslHandshake or ApiVersions requests are allowed")
69+
}
70+
}
71+
}
72+
4873
// send inFlightRequest to channel before myCopyN to prevent race condition in proxyResponses
4974
if err = sendRequestKeyVersion(ctx.openRequestsChannel, openRequestSendTimeout, requestKeyVersion); err != nil {
5075
return true, err
@@ -69,19 +94,12 @@ func (handler *DefaultRequestHandler) handleRequest(dst DeadlineWriter, src Dead
6994
return readErr, err
7095
}
7196
if requestKeyVersion.ApiKey == apiKeySaslHandshake {
72-
if requestKeyVersion.ApiVersion == 0 {
73-
if err = ctx.nextHandlers(saslAuthV0RequestHandler, saslAuthV0ResponseHandler); err != nil {
74-
return false, err
75-
}
76-
return false, nil
77-
} else {
97+
if requestKeyVersion.ApiVersion != 0 {
7898
return false, errors.New("only SASL V0 Handshake is supported")
7999
}
100+
return false, ctx.putNextHandlers(saslAuthV0RequestHandler, saslAuthV0ResponseHandler)
80101
}
81-
if err = ctx.nextHandlers(defaultRequestHandler, defaultResponseHandler); err != nil {
82-
return false, err
83-
}
84-
return false, nil
102+
return false, ctx.putNextHandlers(defaultRequestHandler, defaultResponseHandler)
85103
}
86104

87105
func (handler *DefaultResponseHandler) handleResponse(dst DeadlineWriter, src DeadlineReader, ctx *ResponsesLoopContext) (readErr bool, err error) {

proxy/processor_sasl_auth_v0.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,11 @@ type SaslAuthV0RequestHandler struct {
1414
type SaslAuthV0ResponseHandler struct {
1515
}
1616

17-
func (handler *SaslAuthV0RequestHandler) handleRequest(dst DeadlineWriter, src DeadlineReader, ctx *RequestsLoopContext) (readErr bool, err error) {
17+
func (handler *SaslAuthV0RequestHandler) handleRequest(dst DeadlineWriter, src DeadlineReaderWriter, ctx *RequestsLoopContext) (readErr bool, err error) {
1818
if readErr, err = copySaslAuthRequest(dst, src, ctx.timeout, ctx.buf); err != nil {
1919
return readErr, err
2020
}
21-
if err = ctx.nextHandlers(defaultRequestHandler, defaultResponseHandler); err != nil {
21+
if err = ctx.putNextHandlers(defaultRequestHandler, defaultResponseHandler); err != nil {
2222
return false, err
2323
}
2424
return false, nil

proxy/sasl_local.go

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ type LocalSasl struct {
1919
localAuthenticator apis.PasswordAuthenticator
2020
}
2121

22-
func (p *LocalSasl) receiveAndSendSASLPlainAuth(conn DeadlineReaderWriter) (err error) {
23-
if err = p.receiveAndSendSasl(conn); err != nil {
22+
func (p *LocalSasl) receiveAndSendSASLPlainAuth(conn DeadlineReaderWriter, readKeyVersionBuf []byte) (err error) {
23+
if err = p.receiveAndSendSasl(conn, readKeyVersionBuf); err != nil {
2424
return err
2525
}
2626
if err = p.receiveAndSendAuth(conn); err != nil {
@@ -29,17 +29,17 @@ func (p *LocalSasl) receiveAndSendSASLPlainAuth(conn DeadlineReaderWriter) (err
2929
return nil
3030
}
3131

32-
func (p *LocalSasl) receiveAndSendSasl(conn DeadlineReaderWriter) (err error) {
32+
func (p *LocalSasl) receiveAndSendSasl(conn DeadlineReaderWriter, keyVersionBuf []byte) (err error) {
3333
requestDeadline := time.Now().Add(p.timeout)
3434
err = conn.SetDeadline(requestDeadline)
3535
if err != nil {
3636
return err
3737
}
3838

39-
keyVersionBuf := make([]byte, 8) // Size => int32 + ApiKey => int16 + ApiVersion => int16
40-
if _, err = io.ReadFull(conn, keyVersionBuf); err != nil {
41-
return err
39+
if len(keyVersionBuf) != 8 {
40+
return errors.New("length of keyVersionBuf should be 8")
4241
}
42+
// keyVersionBuf has already been read from connection
4343
requestKeyVersion := &protocol.RequestKeyVersion{}
4444
if err = protocol.Decode(keyVersionBuf, requestKeyVersion); err != nil {
4545
return err

0 commit comments

Comments
 (0)