Skip to content

Commit 5edf5c4

Browse files
committed
Proxy initiated oauthbearer auth
1 parent 43640ba commit 5edf5c4

File tree

7 files changed

+233
-22
lines changed

7 files changed

+233
-22
lines changed

.travis.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ sudo: false
33
language: go
44

55
go:
6-
- "1.10.x"
6+
- "1.11.x"
77

88
env:
99
global:

Dockerfile.build

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
FROM golang:1.10 as builder
1+
FROM golang:1.11 as builder
22

33
ARG GOOS=linux
44
ARG GOARCH=amd64

proxy/client.go

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ type Client struct {
3434
stopRun chan struct{}
3535
stopOnce sync.Once
3636

37-
saslPlainAuth *SASLPlainAuth
38-
authClient *AuthClient
37+
saslAuthByProxy SASLAuthByProxy
38+
authClient *AuthClient
3939
}
4040

4141
func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.NetAddressMappingFunc, localPasswordAuthenticator apis.PasswordAuthenticator, localTokenAuthenticator apis.TokenInfo, gatewayTokenProvider apis.TokenProvider, gatewayTokenInfo apis.TokenInfo) (*Client, error) {
@@ -72,7 +72,7 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne
7272
}
7373

7474
return &Client{conns: conns, config: c, dialer: dialer, tcpConnOptions: tcpConnOptions, stopRun: make(chan struct{}, 1),
75-
saslPlainAuth: &SASLPlainAuth{
75+
saslAuthByProxy: &SASLPlainAuth{
7676
clientID: c.Kafka.ClientID,
7777
writeTimeout: c.Kafka.WriteTimeout,
7878
readTimeout: c.Kafka.ReadTimeout,
@@ -193,7 +193,7 @@ func (c *Client) handleConn(conn Conn) {
193193
server, err := c.DialAndAuth(conn.BrokerAddress)
194194
if err != nil {
195195
logrus.Infof("couldn't connect to %s: %v", conn.BrokerAddress, err)
196-
conn.LocalConnection.Close()
196+
_ = conn.LocalConnection.Close()
197197
return
198198
}
199199
if tcpConn, ok := server.(*net.TCPConn); ok {
@@ -215,7 +215,7 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) {
215215
return nil, err
216216
}
217217
if err := conn.SetDeadline(time.Time{}); err != nil {
218-
conn.Close()
218+
_ = conn.Close()
219219
return nil, err
220220
}
221221
err = c.auth(conn)
@@ -228,22 +228,22 @@ func (c *Client) DialAndAuth(brokerAddress string) (net.Conn, error) {
228228
func (c *Client) auth(conn net.Conn) error {
229229
if c.config.Auth.Gateway.Client.Enable {
230230
if err := c.authClient.sendAndReceiveGatewayAuth(conn); err != nil {
231-
conn.Close()
231+
_ = conn.Close()
232232
return err
233233
}
234234
if err := conn.SetDeadline(time.Time{}); err != nil {
235-
conn.Close()
235+
_ = conn.Close()
236236
return err
237237
}
238238
}
239239
if c.config.Kafka.SASL.Enable {
240-
err := c.saslPlainAuth.sendAndReceiveSASLPlainAuth(conn)
240+
err := c.saslAuthByProxy.sendAndReceiveSASLAuth(conn)
241241
if err != nil {
242-
conn.Close()
242+
_ = conn.Close()
243243
return err
244244
}
245245
if err := conn.SetDeadline(time.Time{}); err != nil {
246-
conn.Close()
246+
_ = conn.Close()
247247
return err
248248
}
249249
}

proxy/sasl.go renamed to proxy/sasl_by_proxy.go

Lines changed: 123 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@ package proxy
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/binary"
67
"fmt"
8+
"github.com/grepplabs/kafka-proxy/pkg/apis"
79
"github.com/grepplabs/kafka-proxy/proxy/protocol"
810
"github.com/pkg/errors"
911
"io"
@@ -15,6 +17,24 @@ const (
1517
SASLOAuthBearer = "OAUTHBEARER"
1618
)
1719

20+
type SASLHandshake struct {
21+
clientID string
22+
version int16
23+
mechanism string
24+
25+
writeTimeout time.Duration
26+
readTimeout time.Duration
27+
}
28+
29+
type SASLOAuthBearerAuth struct {
30+
clientID string
31+
32+
writeTimeout time.Duration
33+
readTimeout time.Duration
34+
35+
tokenProvider apis.TokenProvider
36+
}
37+
1838
type SASLPlainAuth struct {
1939
clientID string
2040

@@ -25,6 +45,10 @@ type SASLPlainAuth struct {
2545
password string
2646
}
2747

48+
type SASLAuthByProxy interface {
49+
sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error
50+
}
51+
2852
// In SASL Plain, Kafka expects the auth header to be in the following format
2953
// Message format (from https://tools.ietf.org/html/rfc4616):
3054
//
@@ -40,9 +64,16 @@ type SASLPlainAuth struct {
4064
// When credentials are valid, Kafka returns a 4 byte array of null characters.
4165
// When credentials are invalid, Kafka closes the connection. This does not seem to be the ideal way
4266
// of responding to bad credentials but thats how its being done today.
43-
func (b *SASLPlainAuth) sendAndReceiveSASLPlainAuth(conn DeadlineReaderWriter) error {
44-
45-
handshakeErr := b.sendAndReceiveSASLPlainHandshake(conn)
67+
func (b *SASLPlainAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error {
68+
69+
saslHandshake := &SASLHandshake{
70+
clientID: b.clientID,
71+
version: 0,
72+
mechanism: SASLPlain,
73+
writeTimeout: b.writeTimeout,
74+
readTimeout: b.readTimeout,
75+
}
76+
handshakeErr := saslHandshake.sendAndReceiveHandshake(conn)
4677
if handshakeErr != nil {
4778
return handshakeErr
4879
}
@@ -78,11 +109,11 @@ func (b *SASLPlainAuth) sendAndReceiveSASLPlainAuth(conn DeadlineReaderWriter) e
78109
return nil
79110
}
80111

81-
func (b *SASLPlainAuth) sendAndReceiveSASLPlainHandshake(conn DeadlineReaderWriter) error {
112+
func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error {
82113

83114
req := &protocol.Request{
84115
ClientID: b.clientID,
85-
Body: &protocol.SaslHandshakeRequestV0orV1{Version: 0, Mechanism: SASLPlain},
116+
Body: &protocol.SaslHandshakeRequestV0orV1{Version: b.version, Mechanism: b.mechanism},
86117
}
87118
reqBuf, err := protocol.Encode(req)
88119
if err != nil {
@@ -128,3 +159,90 @@ func (b *SASLPlainAuth) sendAndReceiveSASLPlainHandshake(conn DeadlineReaderWrit
128159
}
129160
return nil
130161
}
162+
163+
func (b *SASLOAuthBearerAuth) getOAuthBearerToken() (string, error) {
164+
resp, err := b.tokenProvider.GetToken(context.Background(), apis.TokenRequest{})
165+
if err != nil {
166+
return "", err
167+
}
168+
if !resp.Success {
169+
return "", fmt.Errorf("get sasl token failed with status: %d", resp.Status)
170+
}
171+
if resp.Token == "" {
172+
return "", errors.New("get sasl token returned empty token")
173+
}
174+
return resp.Token, nil
175+
}
176+
177+
func (b *SASLOAuthBearerAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error {
178+
179+
token, err := b.getOAuthBearerToken()
180+
if err != nil {
181+
return err
182+
}
183+
saslHandshake := &SASLHandshake{
184+
clientID: b.clientID,
185+
version: 1,
186+
mechanism: SASLOAuthBearer,
187+
writeTimeout: b.writeTimeout,
188+
readTimeout: b.readTimeout,
189+
}
190+
handshakeErr := saslHandshake.sendAndReceiveHandshake(conn)
191+
if handshakeErr != nil {
192+
return handshakeErr
193+
}
194+
return b.sendSaslAuthenticateRequest(token, conn)
195+
}
196+
197+
func (b *SASLOAuthBearerAuth) sendSaslAuthenticateRequest(token string, conn DeadlineReaderWriter) error {
198+
saslAuthReqV0 := protocol.SaslAuthenticateRequestV0{SaslAuthBytes: SaslOAuthBearer{}.ToBytes(token, "", make(map[string]string, 0))}
199+
200+
req := &protocol.Request{
201+
ClientID: b.clientID,
202+
Body: &saslAuthReqV0,
203+
}
204+
reqBuf, err := protocol.Encode(req)
205+
if err != nil {
206+
return err
207+
}
208+
sizeBuf := make([]byte, 4)
209+
binary.BigEndian.PutUint32(sizeBuf, uint32(len(reqBuf)))
210+
211+
err = conn.SetWriteDeadline(time.Now().Add(b.writeTimeout))
212+
if err != nil {
213+
return err
214+
}
215+
216+
_, err = conn.Write(bytes.Join([][]byte{sizeBuf, reqBuf}, nil))
217+
if err != nil {
218+
return errors.Wrap(err, "Failed to send SASL auth request")
219+
}
220+
221+
err = conn.SetReadDeadline(time.Now().Add(b.readTimeout))
222+
if err != nil {
223+
return err
224+
}
225+
226+
//wait for the response
227+
header := make([]byte, 8) // response header
228+
_, err = io.ReadFull(conn, header)
229+
if err != nil {
230+
return errors.Wrap(err, "Failed to read SASL auth header")
231+
}
232+
length := binary.BigEndian.Uint32(header[:4])
233+
payload := make([]byte, length-4)
234+
_, err = io.ReadFull(conn, payload)
235+
if err != nil {
236+
return errors.Wrap(err, "Failed to read SASL auth payload")
237+
}
238+
239+
res := &protocol.SaslAuthenticateResponseV0{}
240+
err = protocol.Decode(payload, res)
241+
if err != nil {
242+
return errors.Wrap(err, "Failed to parse SASL auth response")
243+
}
244+
if res.Err != protocol.ErrNoError {
245+
return errors.Wrapf(res.Err, "SASL authentication failed, error message is '%v'", res.ErrMsg)
246+
}
247+
return nil
248+
}

proxy/sasl_local_auth.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ func NewLocalSaslOauth(tokenAuthenticator apis.TokenInfo) *LocalSaslOauth {
6161

6262
// implements LocalSaslAuth
6363
func (p *LocalSaslOauth) doLocalAuth(saslAuthBytes []byte) (err error) {
64-
token, err := p.saslOAuthBearer.GetToken(saslAuthBytes)
64+
token, _, _, err := p.saslOAuthBearer.GetClientInitialResponse(saslAuthBytes)
6565
if err != nil {
6666
return err
6767
}

proxy/sasl_oauthbearer.go

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ import (
99

1010
// https://tools.ietf.org/html/rfc7628#section-3.1
1111
// https://tools.ietf.org/html/rfc5801#section-4
12+
// https://tools.ietf.org/html/rfc5801 (UTF8-1-safe)
1213
const (
1314
saslOauthSeparator = "\u0001"
14-
saslOauthSaslName = "(?:[\\x01-\\x7F&&[^=,]]|=2C|=3D)+"
15+
saslOauthSaslName = "(?:[\x01-\x2b]|[\x2d-\x3c]|[\x3e-\x7F]|=2C|=3D)+"
1516
saslOauthKey = "[A-Za-z]+"
1617
saslOauthValue = "[\\x21-\\x7E \t\r\n]+"
1718
saslOauthAuthKey = "auth"
@@ -25,18 +26,32 @@ var (
2526

2627
type SaslOAuthBearer struct{}
2728

28-
func (p SaslOAuthBearer) GetToken(saslAuthBytes []byte) (string, error) {
29+
func (p SaslOAuthBearer) GetClientInitialResponse(saslAuthBytes []byte) (token string, authzid string, extensions map[string]string, err error) {
2930
match := saslOauthClientInitialResponsePattern.FindSubmatch(saslAuthBytes)
31+
if len(match) == 0 {
32+
return "", "", nil, errors.New("invalid OAUTHBEARER initial client response: 'saslAuthBytes' parse error")
33+
}
3034

3135
result := make(map[string][]byte)
3236
for i, name := range saslOauthClientInitialResponsePattern.SubexpNames() {
3337
if i != 0 && name != "" {
38+
if i >= len(match) {
39+
return "", "", nil, errors.New("invalid OAUTHBEARER initial client response: 'SubexpNames' range error")
40+
}
3441
result[name] = match[i]
3542
}
3643
}
44+
45+
authzid = string(result["authzid"])
3746
kvpairs := result["kvpairs"]
3847
properties := p.parseMap(string(kvpairs), "=", saslOauthSeparator)
39-
return p.parseToken(properties[saslOauthAuthKey])
48+
49+
token, err = p.parseToken(properties[saslOauthAuthKey])
50+
if err != nil {
51+
return "", "", nil, err
52+
}
53+
delete(properties, saslOauthAuthKey)
54+
return token, authzid, properties, nil
4055
}
4156

4257
func (SaslOAuthBearer) parseToken(auth string) (string, error) {
@@ -73,3 +88,28 @@ func (SaslOAuthBearer) parseMap(mapStr string, keyValueSeparator string, element
7388
}
7489
return result
7590
}
91+
92+
func (SaslOAuthBearer) mkString(mapValues map[string]string, keyValueSeparator string, elementSeparator string) string {
93+
if len(mapValues) == 0 {
94+
return ""
95+
}
96+
elements := make([]string, 0, len(mapValues))
97+
for k, v := range mapValues {
98+
elements = append(elements, strings.Join([]string{k, v}, keyValueSeparator))
99+
}
100+
return strings.Join(elements, elementSeparator)
101+
}
102+
103+
func (p SaslOAuthBearer) ToBytes(tokenValue string, authorizationId string, saslExtensions map[string]string) []byte {
104+
authzid := authorizationId
105+
if authzid != "" {
106+
authzid = "a=" + authorizationId
107+
}
108+
extensions := p.mkString(saslExtensions, "=", saslOauthSeparator)
109+
if extensions != "" {
110+
extensions = saslOauthSeparator + extensions
111+
}
112+
message := fmt.Sprintf("n,%s,%sauth=Bearer %s%s%s%s", authzid,
113+
saslOauthSeparator, tokenValue, extensions, saslOauthSeparator, saslOauthSeparator)
114+
return []byte(message)
115+
}

proxy/sasl_oauthbearer_test.go

Lines changed: 54 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,60 @@ func TestSaslOauthParseToken(t *testing.T) {
1313
saslAuthBytes, err := hex.DecodeString(saslBytes)
1414
a.Nil(err)
1515

16-
token, err := SaslOAuthBearer{}.GetToken(saslAuthBytes)
16+
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse(saslAuthBytes)
1717
a.Nil(err)
18+
a.Empty(authzid)
19+
a.Empty(extensions)
1820
a.Equal("eyJhbGciOiJub25lIn0.eyJleHAiOjEuNTM5NTE2Njk0NDE4RTksImlhdCI6MS41Mzk1MTMwOTQ0MThFOSwic3ViIjoiYWxpY2UyIn0.", token)
21+
22+
a.Equal(saslAuthBytes, SaslOAuthBearer{}.ToBytes(token, authzid, extensions))
23+
24+
}
25+
func TestSaslOAuthBearerToBytes(t *testing.T) {
26+
a := assert.New(t)
27+
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,,\u0001auth=Bearer 123.345.567\u0001nineteen=42\u0001\u0001"))
28+
a.Nil(err)
29+
a.Equal("123.345.567", token)
30+
a.Empty(authzid)
31+
a.Equal(map[string]string{"nineteen": "42"}, extensions)
32+
}
33+
34+
func TestSaslOAuthBearerAuthorizationId(t *testing.T) {
35+
a := assert.New(t)
36+
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,a=myuser,\u0001auth=Bearer 345\u0001\u0001"))
37+
a.Nil(err)
38+
a.Equal("345", token)
39+
a.Equal("myuser", authzid)
40+
a.Empty(extensions)
41+
}
42+
43+
func TestSaslOAuthBearerExtensions(t *testing.T) {
44+
a := assert.New(t)
45+
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte("n,,\u0001propA=valueA1, valueA2\u0001auth=Bearer 567\u0001propB=valueB\u0001\u0001"))
46+
a.Nil(err)
47+
a.Equal("567", token)
48+
a.Empty(authzid)
49+
a.Equal(map[string]string{"propA": "valueA1, valueA2", "propB": "valueB"}, extensions)
50+
}
51+
52+
func TestSaslOAuthBearerRfc7688Example(t *testing.T) {
53+
a := assert.New(t)
54+
message := "n,[email protected],\u0001host=server.example.com\u0001port=143\u0001" +
55+
"auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001"
56+
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte(message))
57+
a.Nil(err)
58+
a.Equal("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", token)
59+
a.Equal("[email protected]", authzid)
60+
a.Equal(map[string]string{"host": "server.example.com", "port": "143"}, extensions)
61+
}
62+
63+
func TestSaslOAuthBearerNoExtensionsFromByteArray(t *testing.T) {
64+
a := assert.New(t)
65+
message := "n,[email protected],\u0001" +
66+
"auth=Bearer vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg\u0001\u0001"
67+
token, authzid, extensions, err := SaslOAuthBearer{}.GetClientInitialResponse([]byte(message))
68+
a.Nil(err)
69+
a.Equal("vF9dft4qmTc2Nvb3RlckBhbHRhdmlzdGEuY29tCg", token)
70+
a.Equal("[email protected]", authzid)
71+
a.Empty(extensions)
1972
}

0 commit comments

Comments
 (0)