Skip to content

Commit 057ab10

Browse files
authored
Merge pull request #26 from worms/master
Add SASL SCRAM support
2 parents 7d600af + 67884b3 commit 057ab10

File tree

5 files changed

+263
-9
lines changed

5 files changed

+263
-9
lines changed

cmd/kafka-proxy/server.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ func initFlags() {
148148
Server.Flags().StringVar(&c.Kafka.SASL.Username, "sasl-username", "", "SASL user name")
149149
Server.Flags().StringVar(&c.Kafka.SASL.Password, "sasl-password", "", "SASL user password")
150150
Server.Flags().StringVar(&c.Kafka.SASL.JaasConfigFile, "sasl-jaas-config-file", "", "Location of JAAS config file with SASL username and password")
151+
Server.Flags().StringVar(&c.Kafka.SASL.Method, "sasl-method", "PLAIN", "SASL method to use (PLAIN, SCRAM-SHA-256, SCRAM-SHA-512")
151152

152153
// SASL by Proxy plugin
153154
Server.Flags().BoolVar(&c.Kafka.SASL.Plugin.Enable, "sasl-plugin-enable", false, "Use plugin for SASL authentication")

config/config.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,7 @@ type Config struct {
120120
Username string
121121
Password string
122122
JaasConfigFile string
123+
Method string
123124
Plugin struct {
124125
Enable bool
125126
Command string

proxy/client.go

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,13 +83,26 @@ func NewClient(conns *ConnSet, c *config.Config, netAddressMappingFunc config.Ne
8383
return nil, errors.Errorf("SASLAuthByProxy plugin unsupported or plugin misconfiguration for mechanism '%s' ", c.Kafka.SASL.Plugin.Mechanism)
8484
}
8585

86-
} else {
87-
saslAuthByProxy = &SASLPlainAuth{
88-
clientID: c.Kafka.ClientID,
89-
writeTimeout: c.Kafka.WriteTimeout,
90-
readTimeout: c.Kafka.ReadTimeout,
91-
username: c.Kafka.SASL.Username,
92-
password: c.Kafka.SASL.Password,
86+
} else if c.Kafka.SASL.Enable {
87+
if c.Kafka.SASL.Method == SASLPlain {
88+
saslAuthByProxy = &SASLPlainAuth{
89+
clientID: c.Kafka.ClientID,
90+
writeTimeout: c.Kafka.WriteTimeout,
91+
readTimeout: c.Kafka.ReadTimeout,
92+
username: c.Kafka.SASL.Username,
93+
password: c.Kafka.SASL.Password,
94+
}
95+
} else if c.Kafka.SASL.Method == SASLSCRAM256 || c.Kafka.SASL.Method == SASLSCRAM512 {
96+
saslAuthByProxy = &SASLSCRAMAuth{
97+
clientID: c.Kafka.ClientID,
98+
writeTimeout: c.Kafka.WriteTimeout,
99+
readTimeout: c.Kafka.ReadTimeout,
100+
username: c.Kafka.SASL.Username,
101+
password: c.Kafka.SASL.Password,
102+
mechanism: c.Kafka.SASL.Method,
103+
}
104+
} else {
105+
return nil, errors.Errorf("SASL Mechanism not valid '%s'", c.Kafka.SASL.Method)
93106
}
94107
}
95108

proxy/sasl_by_proxy.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ import (
1616
const (
1717
SASLPlain = "PLAIN"
1818
SASLOAuthBearer = "OAUTHBEARER"
19+
SASLSCRAM256 = "SCRAM-SHA-256"
20+
SASLSCRAM512 = "SCRAM-SHA-512"
1921
)
2022

2123
type SASLHandshake struct {
@@ -117,8 +119,7 @@ func (b *SASLPlainAuth) sendSaslAuthenticateRequest(conn DeadlineReaderWriter) e
117119
}
118120

119121
func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error {
120-
logrus.Debugf("Sending SaslHandshakeRequest")
121-
122+
logrus.Debugf("Sending SaslHandshakeRequest mechanism: %v version: %v", b.mechanism, b.version)
122123
req := &protocol.Request{
123124
ClientID: b.clientID,
124125
Body: &protocol.SaslHandshakeRequestV0orV1{Version: b.version, Mechanism: b.mechanism},
@@ -165,6 +166,8 @@ func (b *SASLHandshake) sendAndReceiveHandshake(conn DeadlineReaderWriter) error
165166
if res.Err != protocol.ErrNoError {
166167
return errors.Wrap(res.Err, "Invalid SASL Mechanism")
167168
}
169+
170+
logrus.Debugf("Successful SASL handshake. Available mechanisms: %v", res.EnabledMechanisms)
168171
return nil
169172
}
170173

proxy/sasl_scram.go

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
package proxy
2+
3+
import (
4+
"bytes"
5+
"crypto/sha256"
6+
"crypto/sha512"
7+
"encoding/binary"
8+
"fmt"
9+
"github.com/grepplabs/kafka-proxy/proxy/protocol"
10+
"github.com/sirupsen/logrus"
11+
"github.com/xdg/scram"
12+
"hash"
13+
"io"
14+
"time"
15+
)
16+
17+
// Most of this is a direct copy from Shopify's Sarama found here:
18+
// https://github.com/Shopify/sarama
19+
// The commented out lines of code match with Sarama's
20+
21+
type SASLSCRAMAuth struct {
22+
clientID string
23+
24+
writeTimeout time.Duration
25+
readTimeout time.Duration
26+
27+
username string
28+
password string
29+
mechanism string
30+
correlationID int32
31+
32+
// authz id used for SASL/SCRAM authentication
33+
SCRAMAuthzID string
34+
}
35+
36+
// Workaround for xdg-go not having accepted this pull request:
37+
// https://github.com/xdg-go/scram/pull/1/commits
38+
var SHA256 scram.HashGeneratorFcn = func() hash.Hash { return sha256.New() }
39+
var SHA512 scram.HashGeneratorFcn = func() hash.Hash { return sha512.New() }
40+
41+
// Maps to Sarama sendAndReceiveSASLSCRAMv1
42+
func (b *SASLSCRAMAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error {
43+
44+
err := b.sendAndReceiveSASLHandshake(conn)
45+
if err != nil {
46+
logrus.Debugf("SASL Handshake fails")
47+
return err
48+
}
49+
50+
var scramClient *scram.Client
51+
if b.mechanism == "SCRAM-SHA-256" {
52+
scramClient, err = SHA256.NewClient(b.username, b.password, "")
53+
if err != nil {
54+
logrus.Debugf("Unable to make scram client for SCRAM-SHA-256: %v", err)
55+
return err
56+
}
57+
} else if b.mechanism == "SCRAM-SHA-512" {
58+
scramClient, err = SHA512.NewClient(b.username, b.password, "")
59+
if err != nil {
60+
logrus.Debugf("Unable to make scram client for SCRAM-SHA-512: %v", err)
61+
return err
62+
}
63+
} else {
64+
return fmt.Errorf("Invalid SCRAM specification provided: %s. Expected one of [\"SCRAM-SHA-256\",\"SCRAM-SHA-512\"]", b.mechanism)
65+
}
66+
67+
//if err := scramClient.Begin(b.username, b.password, b.SCRAMAuthzID); err != nil {
68+
// return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error())
69+
//}
70+
scramConversation := scramClient.NewConversation()
71+
72+
//msg, err := scramClient.Step("")
73+
msg, err := scramConversation.Step("")
74+
if err != nil {
75+
return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error())
76+
}
77+
78+
logrus.Debugf("Commencing scram loop")
79+
for !scramConversation.Done() {
80+
//requestTime := time.Now()
81+
correlationID := b.correlationID
82+
// bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg))
83+
_, err := b.sendSaslAuthenticateRequest(conn, correlationID, []byte(msg))
84+
if err != nil {
85+
//Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error())
86+
logrus.Debugf("Failed to write SASL auth header to broker: %s\n", err.Error())
87+
return err
88+
}
89+
90+
b.correlationID++
91+
//challenge, err := b.receiveSaslAuthenticateResponse(correlationID)
92+
challenge, err := b.receiveSaslAuthenticateResponse(conn, correlationID)
93+
if err != nil {
94+
//Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error())
95+
logrus.Debugf("Failed to read response while authenticating with SASL to broker: %s\n", err.Error())
96+
return err
97+
}
98+
99+
msg, err = scramConversation.Step(string(challenge))
100+
if err != nil {
101+
logrus.Debugf("SASL authentication failed", err)
102+
//Logger.Println("SASL authentication failed", err)
103+
return err
104+
}
105+
}
106+
107+
logrus.Debugf("SASL SCRAM authentication succeeded")
108+
return nil
109+
}
110+
111+
func (b *SASLSCRAMAuth) sendAndReceiveSASLHandshake(conn DeadlineReaderWriter) error {
112+
logrus.Debugf("SASLSCRAM: Doing handshake. Mechanism: %s", b.mechanism)
113+
114+
rb := &protocol.SaslHandshakeRequestV0orV1{
115+
Version: 1,
116+
Mechanism: b.mechanism,
117+
}
118+
119+
req := &protocol.Request{ClientID: b.clientID, Body: rb}
120+
//req := &protocol.Request{CorrelationID: b.correlationID, ClientID: b.clientID, Body: rb}
121+
buf, err := protocol.Encode(req)
122+
if err != nil {
123+
logrus.Debugf("Error encoding protocol.Request: %v", err)
124+
return err
125+
}
126+
sizeBuf := make([]byte, 4)
127+
binary.BigEndian.PutUint32(sizeBuf, uint32(len(buf)))
128+
129+
if err := conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil {
130+
return err
131+
}
132+
133+
bytes, err := conn.Write(bytes.Join([][]byte{sizeBuf, buf}, nil))
134+
//bytes, err := conn.Write(buf)
135+
if err != nil {
136+
logrus.Debugf("Failed to send SASL handshake: %s bytes: %v\n", err.Error(), bytes)
137+
return err
138+
}
139+
140+
b.correlationID++
141+
//wait for the response
142+
header := make([]byte, 8) // response header
143+
bytes, err = io.ReadFull(conn, header)
144+
if err != nil {
145+
logrus.Debugf("Failed to read SASL handshake header [%v]: %v\n", bytes, err)
146+
return err
147+
}
148+
149+
length := binary.BigEndian.Uint32(header[:4])
150+
payload := make([]byte, length-4)
151+
n, err := io.ReadFull(conn, payload)
152+
if err != nil {
153+
logrus.Debugf("Failed to read SASL handshake payload : %s bytes: %v\n", err.Error(), n)
154+
return err
155+
}
156+
157+
res := &protocol.SaslHandshakeResponseV0orV1{}
158+
159+
err = protocol.Decode(payload, res)
160+
if err != nil {
161+
logrus.Debugf("Failed to parse SASL handshake : %s\n", err.Error())
162+
return err
163+
}
164+
165+
if res.Err != protocol.ErrNoError {
166+
logrus.Debugf("Invalid SASL Mechanism : %s\n", res.Err.Error())
167+
return res.Err
168+
}
169+
170+
logrus.Debugf("Successful SASL handshake. Available mechanisms: %v", res.EnabledMechanisms)
171+
172+
return nil
173+
}
174+
175+
func (b *SASLSCRAMAuth) sendSaslAuthenticateRequest(conn DeadlineReaderWriter, correlationID int32, msg []byte) (int, error) {
176+
// rb := &SaslAuthenticateRequest{msg}
177+
rb := &protocol.SaslAuthenticateRequestV0{msg}
178+
//req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb}
179+
req := &protocol.Request{CorrelationID: correlationID, ClientID: b.clientID, Body: rb}
180+
//buf, err := encode(req, b.conf.MetricRegistry)
181+
buf, err := protocol.Encode(req)
182+
if err != nil {
183+
logrus.Debugf("Failed to encode")
184+
return 0, err
185+
}
186+
187+
if err := conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil {
188+
return 0, err
189+
}
190+
191+
sizeBuf := make([]byte, 4)
192+
binary.BigEndian.PutUint32(sizeBuf, uint32(len(buf)))
193+
return conn.Write(bytes.Join([][]byte{sizeBuf, buf}, nil))
194+
}
195+
196+
func (b *SASLSCRAMAuth) receiveSaslAuthenticateResponse(conn DeadlineReaderWriter, correlationID int32) ([]byte, error) {
197+
const responseLengthSize = 4
198+
const correlationIDSize = 4
199+
200+
buf := make([]byte, responseLengthSize+correlationIDSize)
201+
_, err := io.ReadFull(conn, buf)
202+
if err != nil {
203+
logrus.Debugf("Failed to read from broker: %v", err)
204+
return nil, err
205+
}
206+
207+
//header := responseHeader{}
208+
header := protocol.ResponseHeader{}
209+
//err = decode(buf, &header)
210+
err = protocol.Decode(buf, &header)
211+
if err != nil {
212+
return nil, err
213+
}
214+
215+
if header.CorrelationID != correlationID {
216+
return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", correlationID, header.CorrelationID)
217+
}
218+
219+
buf = make([]byte, header.Length-correlationIDSize)
220+
_, err = io.ReadFull(conn, buf)
221+
if err != nil {
222+
return nil, err
223+
}
224+
225+
//res := &SaslAuthenticateResponse{}
226+
res := &protocol.SaslAuthenticateResponseV0{}
227+
//if err := versionedDecode(buf, res, 0); err != nil {
228+
err = protocol.Decode(buf, res)
229+
if err != nil {
230+
return nil, err
231+
}
232+
if res.Err != protocol.ErrNoError {
233+
return nil, res.Err
234+
}
235+
return res.SaslAuthBytes, nil
236+
}

0 commit comments

Comments
 (0)