|
| 1 | +package proxy |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "crypto/sha256" |
| 6 | + "crypto/sha512" |
| 7 | + "encoding/binary" |
| 8 | + "fmt" |
| 9 | + "github.com/grepplabs/kafka-proxy/proxy/protocol" |
| 10 | + "github.com/sirupsen/logrus" |
| 11 | + "github.com/xdg/scram" |
| 12 | + "hash" |
| 13 | + "io" |
| 14 | + "time" |
| 15 | +) |
| 16 | + |
| 17 | +// Most of this is a direct copy from Shopify's Sarama found here: |
| 18 | +// https://github.com/Shopify/sarama |
| 19 | +// The commented out lines of code match with Sarama's |
| 20 | + |
| 21 | +type SASLSCRAMAuth struct { |
| 22 | + clientID string |
| 23 | + |
| 24 | + writeTimeout time.Duration |
| 25 | + readTimeout time.Duration |
| 26 | + |
| 27 | + username string |
| 28 | + password string |
| 29 | + mechanism string |
| 30 | + correlationID int32 |
| 31 | + |
| 32 | + // authz id used for SASL/SCRAM authentication |
| 33 | + SCRAMAuthzID string |
| 34 | +} |
| 35 | + |
| 36 | +// Workaround for xdg-go not having accepted this pull request: |
| 37 | +// https://github.com/xdg-go/scram/pull/1/commits |
| 38 | +var SHA256 scram.HashGeneratorFcn = func() hash.Hash { return sha256.New() } |
| 39 | +var SHA512 scram.HashGeneratorFcn = func() hash.Hash { return sha512.New() } |
| 40 | + |
| 41 | +// Maps to Sarama sendAndReceiveSASLSCRAMv1 |
| 42 | +func (b *SASLSCRAMAuth) sendAndReceiveSASLAuth(conn DeadlineReaderWriter) error { |
| 43 | + |
| 44 | + err := b.sendAndReceiveSASLHandshake(conn) |
| 45 | + if err != nil { |
| 46 | + logrus.Debugf("SASL Handshake fails") |
| 47 | + return err |
| 48 | + } |
| 49 | + |
| 50 | + var scramClient *scram.Client |
| 51 | + if b.mechanism == "SCRAM-SHA-256" { |
| 52 | + scramClient, err = SHA256.NewClient(b.username, b.password, "") |
| 53 | + if err != nil { |
| 54 | + logrus.Debugf("Unable to make scram client for SCRAM-SHA-256: %v", err) |
| 55 | + return err |
| 56 | + } |
| 57 | + } else if b.mechanism == "SCRAM-SHA-512" { |
| 58 | + scramClient, err = SHA512.NewClient(b.username, b.password, "") |
| 59 | + if err != nil { |
| 60 | + logrus.Debugf("Unable to make scram client for SCRAM-SHA-512: %v", err) |
| 61 | + return err |
| 62 | + } |
| 63 | + } else { |
| 64 | + return fmt.Errorf("Invalid SCRAM specification provided: %s. Expected one of [\"SCRAM-SHA-256\",\"SCRAM-SHA-512\"]", b.mechanism) |
| 65 | + } |
| 66 | + |
| 67 | + //if err := scramClient.Begin(b.username, b.password, b.SCRAMAuthzID); err != nil { |
| 68 | + // return fmt.Errorf("failed to start SCRAM exchange with the server: %s", err.Error()) |
| 69 | + //} |
| 70 | + scramConversation := scramClient.NewConversation() |
| 71 | + |
| 72 | + //msg, err := scramClient.Step("") |
| 73 | + msg, err := scramConversation.Step("") |
| 74 | + if err != nil { |
| 75 | + return fmt.Errorf("failed to advance the SCRAM exchange: %s", err.Error()) |
| 76 | + } |
| 77 | + |
| 78 | + logrus.Debugf("Commencing scram loop") |
| 79 | + for !scramConversation.Done() { |
| 80 | + //requestTime := time.Now() |
| 81 | + correlationID := b.correlationID |
| 82 | + // bytesWritten, err := b.sendSaslAuthenticateRequest(correlationID, []byte(msg)) |
| 83 | + _, err := b.sendSaslAuthenticateRequest(conn, correlationID, []byte(msg)) |
| 84 | + if err != nil { |
| 85 | + //Logger.Printf("Failed to write SASL auth header to broker %s: %s\n", b.addr, err.Error()) |
| 86 | + logrus.Debugf("Failed to write SASL auth header to broker: %s\n", err.Error()) |
| 87 | + return err |
| 88 | + } |
| 89 | + |
| 90 | + b.correlationID++ |
| 91 | + //challenge, err := b.receiveSaslAuthenticateResponse(correlationID) |
| 92 | + challenge, err := b.receiveSaslAuthenticateResponse(conn, correlationID) |
| 93 | + if err != nil { |
| 94 | + //Logger.Printf("Failed to read response while authenticating with SASL to broker %s: %s\n", b.addr, err.Error()) |
| 95 | + logrus.Debugf("Failed to read response while authenticating with SASL to broker: %s\n", err.Error()) |
| 96 | + return err |
| 97 | + } |
| 98 | + |
| 99 | + msg, err = scramConversation.Step(string(challenge)) |
| 100 | + if err != nil { |
| 101 | + logrus.Debugf("SASL authentication failed", err) |
| 102 | + //Logger.Println("SASL authentication failed", err) |
| 103 | + return err |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + logrus.Debugf("SASL SCRAM authentication succeeded") |
| 108 | + return nil |
| 109 | +} |
| 110 | + |
| 111 | +func (b *SASLSCRAMAuth) sendAndReceiveSASLHandshake(conn DeadlineReaderWriter) error { |
| 112 | + logrus.Debugf("SASLSCRAM: Doing handshake. Mechanism: %s", b.mechanism) |
| 113 | + |
| 114 | + rb := &protocol.SaslHandshakeRequestV0orV1{ |
| 115 | + Version: 1, |
| 116 | + Mechanism: b.mechanism, |
| 117 | + } |
| 118 | + |
| 119 | + req := &protocol.Request{ClientID: b.clientID, Body: rb} |
| 120 | + //req := &protocol.Request{CorrelationID: b.correlationID, ClientID: b.clientID, Body: rb} |
| 121 | + buf, err := protocol.Encode(req) |
| 122 | + if err != nil { |
| 123 | + logrus.Debugf("Error encoding protocol.Request: %v", err) |
| 124 | + return err |
| 125 | + } |
| 126 | + sizeBuf := make([]byte, 4) |
| 127 | + binary.BigEndian.PutUint32(sizeBuf, uint32(len(buf))) |
| 128 | + |
| 129 | + if err := conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil { |
| 130 | + return err |
| 131 | + } |
| 132 | + |
| 133 | + bytes, err := conn.Write(bytes.Join([][]byte{sizeBuf, buf}, nil)) |
| 134 | + //bytes, err := conn.Write(buf) |
| 135 | + if err != nil { |
| 136 | + logrus.Debugf("Failed to send SASL handshake: %s bytes: %v\n", err.Error(), bytes) |
| 137 | + return err |
| 138 | + } |
| 139 | + |
| 140 | + b.correlationID++ |
| 141 | + //wait for the response |
| 142 | + header := make([]byte, 8) // response header |
| 143 | + bytes, err = io.ReadFull(conn, header) |
| 144 | + if err != nil { |
| 145 | + logrus.Debugf("Failed to read SASL handshake header [%v]: %v\n", bytes, err) |
| 146 | + return err |
| 147 | + } |
| 148 | + |
| 149 | + length := binary.BigEndian.Uint32(header[:4]) |
| 150 | + payload := make([]byte, length-4) |
| 151 | + n, err := io.ReadFull(conn, payload) |
| 152 | + if err != nil { |
| 153 | + logrus.Debugf("Failed to read SASL handshake payload : %s bytes: %v\n", err.Error(), n) |
| 154 | + return err |
| 155 | + } |
| 156 | + |
| 157 | + res := &protocol.SaslHandshakeResponseV0orV1{} |
| 158 | + |
| 159 | + err = protocol.Decode(payload, res) |
| 160 | + if err != nil { |
| 161 | + logrus.Debugf("Failed to parse SASL handshake : %s\n", err.Error()) |
| 162 | + return err |
| 163 | + } |
| 164 | + |
| 165 | + if res.Err != protocol.ErrNoError { |
| 166 | + logrus.Debugf("Invalid SASL Mechanism : %s\n", res.Err.Error()) |
| 167 | + return res.Err |
| 168 | + } |
| 169 | + |
| 170 | + logrus.Debugf("Successful SASL handshake. Available mechanisms: %v", res.EnabledMechanisms) |
| 171 | + |
| 172 | + return nil |
| 173 | +} |
| 174 | + |
| 175 | +func (b *SASLSCRAMAuth) sendSaslAuthenticateRequest(conn DeadlineReaderWriter, correlationID int32, msg []byte) (int, error) { |
| 176 | + // rb := &SaslAuthenticateRequest{msg} |
| 177 | + rb := &protocol.SaslAuthenticateRequestV0{msg} |
| 178 | + //req := &request{correlationID: correlationID, clientID: b.conf.ClientID, body: rb} |
| 179 | + req := &protocol.Request{CorrelationID: correlationID, ClientID: b.clientID, Body: rb} |
| 180 | + //buf, err := encode(req, b.conf.MetricRegistry) |
| 181 | + buf, err := protocol.Encode(req) |
| 182 | + if err != nil { |
| 183 | + logrus.Debugf("Failed to encode") |
| 184 | + return 0, err |
| 185 | + } |
| 186 | + |
| 187 | + if err := conn.SetWriteDeadline(time.Now().Add(b.writeTimeout)); err != nil { |
| 188 | + return 0, err |
| 189 | + } |
| 190 | + |
| 191 | + sizeBuf := make([]byte, 4) |
| 192 | + binary.BigEndian.PutUint32(sizeBuf, uint32(len(buf))) |
| 193 | + return conn.Write(bytes.Join([][]byte{sizeBuf, buf}, nil)) |
| 194 | +} |
| 195 | + |
| 196 | +func (b *SASLSCRAMAuth) receiveSaslAuthenticateResponse(conn DeadlineReaderWriter, correlationID int32) ([]byte, error) { |
| 197 | + const responseLengthSize = 4 |
| 198 | + const correlationIDSize = 4 |
| 199 | + |
| 200 | + buf := make([]byte, responseLengthSize+correlationIDSize) |
| 201 | + _, err := io.ReadFull(conn, buf) |
| 202 | + if err != nil { |
| 203 | + logrus.Debugf("Failed to read from broker: %v", err) |
| 204 | + return nil, err |
| 205 | + } |
| 206 | + |
| 207 | + //header := responseHeader{} |
| 208 | + header := protocol.ResponseHeader{} |
| 209 | + //err = decode(buf, &header) |
| 210 | + err = protocol.Decode(buf, &header) |
| 211 | + if err != nil { |
| 212 | + return nil, err |
| 213 | + } |
| 214 | + |
| 215 | + if header.CorrelationID != correlationID { |
| 216 | + return nil, fmt.Errorf("correlation ID didn't match, wanted %d, got %d", correlationID, header.CorrelationID) |
| 217 | + } |
| 218 | + |
| 219 | + buf = make([]byte, header.Length-correlationIDSize) |
| 220 | + _, err = io.ReadFull(conn, buf) |
| 221 | + if err != nil { |
| 222 | + return nil, err |
| 223 | + } |
| 224 | + |
| 225 | + //res := &SaslAuthenticateResponse{} |
| 226 | + res := &protocol.SaslAuthenticateResponseV0{} |
| 227 | + //if err := versionedDecode(buf, res, 0); err != nil { |
| 228 | + err = protocol.Decode(buf, res) |
| 229 | + if err != nil { |
| 230 | + return nil, err |
| 231 | + } |
| 232 | + if res.Err != protocol.ErrNoError { |
| 233 | + return nil, res.Err |
| 234 | + } |
| 235 | + return res.SaslAuthBytes, nil |
| 236 | +} |
0 commit comments