@@ -2,8 +2,10 @@ package proxy
22
33import (
44 "bytes"
5+ "context"
56 "encoding/binary"
67 "fmt"
8+ "github.com/grepplabs/kafka-proxy/pkg/apis"
79 "github.com/grepplabs/kafka-proxy/proxy/protocol"
810 "github.com/pkg/errors"
911 "io"
@@ -15,6 +17,24 @@ const (
1517 SASLOAuthBearer = "OAUTHBEARER"
1618)
1719
20+ type SASLHandshake struct {
21+ clientID string
22+ version int16
23+ mechanism string
24+
25+ writeTimeout time.Duration
26+ readTimeout time.Duration
27+ }
28+
29+ type SASLOAuthBearerAuth struct {
30+ clientID string
31+
32+ writeTimeout time.Duration
33+ readTimeout time.Duration
34+
35+ tokenProvider apis.TokenProvider
36+ }
37+
1838type SASLPlainAuth struct {
1939 clientID string
2040
@@ -25,6 +45,10 @@ type SASLPlainAuth struct {
2545 password string
2646}
2747
48+ type SASLAuthByProxy interface {
49+ sendAndReceiveSASLAuth (conn DeadlineReaderWriter ) error
50+ }
51+
2852// In SASL Plain, Kafka expects the auth header to be in the following format
2953// Message format (from https://tools.ietf.org/html/rfc4616):
3054//
@@ -40,9 +64,16 @@ type SASLPlainAuth struct {
4064// When credentials are valid, Kafka returns a 4 byte array of null characters.
4165// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
4266// of responding to bad credentials but thats how its being done today.
43- func (b * SASLPlainAuth ) sendAndReceiveSASLPlainAuth (conn DeadlineReaderWriter ) error {
44-
45- handshakeErr := b .sendAndReceiveSASLPlainHandshake (conn )
67+ func (b * SASLPlainAuth ) sendAndReceiveSASLAuth (conn DeadlineReaderWriter ) error {
68+
69+ saslHandshake := & SASLHandshake {
70+ clientID : b .clientID ,
71+ version : 0 ,
72+ mechanism : SASLPlain ,
73+ writeTimeout : b .writeTimeout ,
74+ readTimeout : b .readTimeout ,
75+ }
76+ handshakeErr := saslHandshake .sendAndReceiveHandshake (conn )
4677 if handshakeErr != nil {
4778 return handshakeErr
4879 }
@@ -78,11 +109,11 @@ func (b *SASLPlainAuth) sendAndReceiveSASLPlainAuth(conn DeadlineReaderWriter) e
78109 return nil
79110}
80111
81- func (b * SASLPlainAuth ) sendAndReceiveSASLPlainHandshake (conn DeadlineReaderWriter ) error {
112+ func (b * SASLHandshake ) sendAndReceiveHandshake (conn DeadlineReaderWriter ) error {
82113
83114 req := & protocol.Request {
84115 ClientID : b .clientID ,
85- Body : & protocol.SaslHandshakeRequestV0orV1 {Version : 0 , Mechanism : SASLPlain },
116+ Body : & protocol.SaslHandshakeRequestV0orV1 {Version : b . version , Mechanism : b . mechanism },
86117 }
87118 reqBuf , err := protocol .Encode (req )
88119 if err != nil {
@@ -128,3 +159,90 @@ func (b *SASLPlainAuth) sendAndReceiveSASLPlainHandshake(conn DeadlineReaderWrit
128159 }
129160 return nil
130161}
162+
163+ func (b * SASLOAuthBearerAuth ) getOAuthBearerToken () (string , error ) {
164+ resp , err := b .tokenProvider .GetToken (context .Background (), apis.TokenRequest {})
165+ if err != nil {
166+ return "" , err
167+ }
168+ if ! resp .Success {
169+ return "" , fmt .Errorf ("get sasl token failed with status: %d" , resp .Status )
170+ }
171+ if resp .Token == "" {
172+ return "" , errors .New ("get sasl token returned empty token" )
173+ }
174+ return resp .Token , nil
175+ }
176+
177+ func (b * SASLOAuthBearerAuth ) sendAndReceiveSASLAuth (conn DeadlineReaderWriter ) error {
178+
179+ token , err := b .getOAuthBearerToken ()
180+ if err != nil {
181+ return err
182+ }
183+ saslHandshake := & SASLHandshake {
184+ clientID : b .clientID ,
185+ version : 1 ,
186+ mechanism : SASLOAuthBearer ,
187+ writeTimeout : b .writeTimeout ,
188+ readTimeout : b .readTimeout ,
189+ }
190+ handshakeErr := saslHandshake .sendAndReceiveHandshake (conn )
191+ if handshakeErr != nil {
192+ return handshakeErr
193+ }
194+ return b .sendSaslAuthenticateRequest (token , conn )
195+ }
196+
197+ func (b * SASLOAuthBearerAuth ) sendSaslAuthenticateRequest (token string , conn DeadlineReaderWriter ) error {
198+ saslAuthReqV0 := protocol.SaslAuthenticateRequestV0 {SaslAuthBytes : SaslOAuthBearer {}.ToBytes (token , "" , make (map [string ]string , 0 ))}
199+
200+ req := & protocol.Request {
201+ ClientID : b .clientID ,
202+ Body : & saslAuthReqV0 ,
203+ }
204+ reqBuf , err := protocol .Encode (req )
205+ if err != nil {
206+ return err
207+ }
208+ sizeBuf := make ([]byte , 4 )
209+ binary .BigEndian .PutUint32 (sizeBuf , uint32 (len (reqBuf )))
210+
211+ err = conn .SetWriteDeadline (time .Now ().Add (b .writeTimeout ))
212+ if err != nil {
213+ return err
214+ }
215+
216+ _ , err = conn .Write (bytes .Join ([][]byte {sizeBuf , reqBuf }, nil ))
217+ if err != nil {
218+ return errors .Wrap (err , "Failed to send SASL auth request" )
219+ }
220+
221+ err = conn .SetReadDeadline (time .Now ().Add (b .readTimeout ))
222+ if err != nil {
223+ return err
224+ }
225+
226+ //wait for the response
227+ header := make ([]byte , 8 ) // response header
228+ _ , err = io .ReadFull (conn , header )
229+ if err != nil {
230+ return errors .Wrap (err , "Failed to read SASL auth header" )
231+ }
232+ length := binary .BigEndian .Uint32 (header [:4 ])
233+ payload := make ([]byte , length - 4 )
234+ _ , err = io .ReadFull (conn , payload )
235+ if err != nil {
236+ return errors .Wrap (err , "Failed to read SASL auth payload" )
237+ }
238+
239+ res := & protocol.SaslAuthenticateResponseV0 {}
240+ err = protocol .Decode (payload , res )
241+ if err != nil {
242+ return errors .Wrap (err , "Failed to parse SASL auth response" )
243+ }
244+ if res .Err != protocol .ErrNoError {
245+ return errors .Wrapf (res .Err , "SASL authentication failed, error message is '%v'" , res .ErrMsg )
246+ }
247+ return nil
248+ }
0 commit comments