Skip to content

Commit a11af9f

Browse files
e-maxhashmap
authored andcommitted
Add tls (#85)
* add TLS encryption * accept TLS certificates directly, not though files * send ReplicaID = -1 * fix after review * add test for TLS * fix testserver * fix data race
1 parent cec0458 commit a11af9f

File tree

12 files changed

+540
-11
lines changed

12 files changed

+540
-11
lines changed

broker.go

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -168,6 +168,16 @@ type BrokerConf struct {
168168
// logging frameworks. Used to notify and as replacement for stdlib `log`
169169
// package.
170170
Logger Logger
171+
172+
//Settings for TLS encryption.
173+
//You need to set all these parameters to enable TLS
174+
175+
//TLS CA pem
176+
TLSCa []byte
177+
//TLS certificate
178+
TLSCert []byte
179+
//TLS key
180+
TLSKey []byte
171181
}
172182

173183
// NewBrokerConf returns the default broker configuration.
@@ -311,7 +321,7 @@ func (b *Broker) fetchMetadata(topics ...string) (*proto.MetadataResp, error) {
311321
if _, ok := checkednodes[nodeID]; ok {
312322
continue
313323
}
314-
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
324+
conn, err := b.getConnection(addr)
315325
if err != nil {
316326
b.conf.Logger.Debug("cannot connect",
317327
"address", addr,
@@ -336,7 +346,7 @@ func (b *Broker) fetchMetadata(topics ...string) (*proto.MetadataResp, error) {
336346
}
337347

338348
for _, addr := range b.getInitialAddresses() {
339-
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
349+
conn, err := b.getConnection(addr)
340350
if err != nil {
341351
b.conf.Logger.Debug("cannot connect to seed node",
342352
"address", addr,
@@ -475,7 +485,7 @@ func (b *Broker) muLeaderConnection(topic string, partition int32) (conn *connec
475485
delete(b.metadata.endpoints, tp)
476486
continue
477487
}
478-
conn, err = newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
488+
conn, err = b.getConnection(addr)
479489
if err != nil {
480490
b.conf.Logger.Info("cannot get leader connection: cannot connect to node",
481491
"address", addr,
@@ -490,6 +500,13 @@ func (b *Broker) muLeaderConnection(topic string, partition int32) (conn *connec
490500
return nil, err
491501
}
492502

503+
func (b *Broker) getConnection(addr string) (*connection, error) {
504+
if b.conf.TLSCa != nil && b.conf.TLSKey != nil && b.conf.TLSCert != nil {
505+
return newTLSConnection(addr, b.conf.TLSCa, b.conf.TLSCert, b.conf.TLSKey, b.conf.DialTimeout, b.conf.ReadTimeout)
506+
}
507+
return newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
508+
}
509+
493510
// coordinatorConnection returns connection to offset coordinator for given group.
494511
//
495512
// Failed connection retry is controlled by broker configuration.
@@ -526,7 +543,7 @@ func (b *Broker) muCoordinatorConnection(consumerGroup string) (conn *connection
526543
}
527544

528545
addr := fmt.Sprintf("%s:%d", resp.CoordinatorHost, resp.CoordinatorPort)
529-
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
546+
conn, err := b.getConnection(addr)
530547
if err != nil {
531548
b.conf.Logger.Debug("cannot connect to node",
532549
"coordinatorID", resp.CoordinatorID,
@@ -552,7 +569,7 @@ func (b *Broker) muCoordinatorConnection(consumerGroup string) (conn *connection
552569
// connection to node is cached so it was already checked
553570
continue
554571
}
555-
conn, err := newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
572+
conn, err := b.getConnection(addr)
556573
if err != nil {
557574
b.conf.Logger.Debug("cannot connect to node",
558575
"nodeID", nodeID,
@@ -583,7 +600,7 @@ func (b *Broker) muCoordinatorConnection(consumerGroup string) (conn *connection
583600
}
584601

585602
addr := fmt.Sprintf("%s:%d", resp.CoordinatorHost, resp.CoordinatorPort)
586-
conn, err = newTCPConnection(addr, b.conf.DialTimeout, b.conf.ReadTimeout)
603+
conn, err = b.getConnection(addr)
587604
if err != nil {
588605
b.conf.Logger.Debug("cannot connect to node",
589606
"coordinatorID", resp.CoordinatorID,

connection.go

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package kafka
33
import (
44
"bufio"
55
"bytes"
6+
"crypto/tls"
7+
"crypto/x509"
68
"errors"
79
"fmt"
810
"math"
@@ -29,6 +31,44 @@ type connection struct {
2931
readTimeout time.Duration
3032
}
3133

34+
func newTLSConnection(address string, ca, cert, key []byte, timeout, readTimeout time.Duration) (*connection, error) {
35+
roots := x509.NewCertPool()
36+
ok := roots.AppendCertsFromPEM(ca)
37+
if !ok {
38+
return nil, fmt.Errorf("Cannot parse root certificate")
39+
}
40+
41+
certificate, err := tls.X509KeyPair(cert, key)
42+
if err != nil {
43+
return nil, fmt.Errorf("Failed to parse key/cert for TLS: %s", err)
44+
}
45+
46+
conf := &tls.Config{
47+
Certificates: []tls.Certificate{certificate},
48+
RootCAs: roots,
49+
}
50+
51+
dialer := net.Dialer{
52+
Timeout: timeout,
53+
KeepAlive: 30 * time.Second,
54+
}
55+
conn, err := tls.DialWithDialer(&dialer, "tcp", address, conf)
56+
if err != nil {
57+
return nil, err
58+
}
59+
c := &connection{
60+
stop: make(chan struct{}),
61+
nextID: make(chan int32),
62+
rw: conn,
63+
respc: make(map[int32]chan []byte),
64+
logger: &nullLogger{},
65+
readTimeout: readTimeout,
66+
}
67+
go c.nextIDLoop()
68+
go c.readRespLoop()
69+
return c, nil
70+
}
71+
3272
// newConnection returns new, initialized connection or error
3373
func newTCPConnection(address string, timeout, readTimeout time.Duration) (*connection, error) {
3474
dialer := net.Dialer{

connection_test.go

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
package kafka
22

33
import (
4+
"crypto/tls"
5+
"crypto/x509"
6+
"fmt"
7+
"io"
8+
"io/ioutil"
9+
"log"
410
"net"
511
"reflect"
612
"strings"
@@ -10,10 +16,120 @@ import (
1016
"github.com/optiopay/kafka/proto"
1117
)
1218

19+
const TLSCaFile = "./testkeys/ca.crt"
20+
const TLSCertFile = "./testkeys/oats.crt"
21+
const TLSKeyFile = "./testkeys/oats.key"
22+
1323
type serializableMessage interface {
1424
Bytes() ([]byte, error)
1525
}
1626

27+
type TLSConf struct {
28+
ca []byte
29+
cert []byte
30+
key []byte
31+
}
32+
33+
func getTLSConf() (*TLSConf, error) {
34+
ca, err := ioutil.ReadFile(TLSCaFile)
35+
if err != nil {
36+
return nil, fmt.Errorf("Cannot read %s", TLSCaFile)
37+
}
38+
cert, err := ioutil.ReadFile(TLSCertFile)
39+
if err != nil {
40+
return nil, fmt.Errorf("Cannot read %s", TLSCertFile)
41+
}
42+
43+
key, err := ioutil.ReadFile(TLSKeyFile)
44+
if err != nil {
45+
return nil, fmt.Errorf("Cannot read %s", TLSKeyFile)
46+
}
47+
48+
return &TLSConf{ca: ca, cert: cert, key: key}, nil
49+
50+
}
51+
52+
//just read request before start to response
53+
func readRequest(r io.Reader) error {
54+
dec := proto.NewDecoder(r)
55+
size := dec.DecodeInt32()
56+
var read int32 = 0
57+
buf := make([]byte, size)
58+
59+
for read < size {
60+
n, err := r.Read(buf)
61+
if err != nil {
62+
return err
63+
}
64+
read += int32(n)
65+
}
66+
return nil
67+
}
68+
69+
func testTLSServer(messages ...serializableMessage) (net.Listener, error) {
70+
tlsConf, err := getTLSConf()
71+
if err != nil {
72+
return nil, err
73+
}
74+
75+
roots := x509.NewCertPool()
76+
ok := roots.AppendCertsFromPEM(tlsConf.ca)
77+
if !ok {
78+
return nil, fmt.Errorf("Cannot parse root certificate")
79+
}
80+
81+
certificate, err := tls.X509KeyPair(tlsConf.cert, tlsConf.key)
82+
if err != nil {
83+
return nil, fmt.Errorf("Failed to parse key/cert for TLS: %s", err)
84+
}
85+
86+
conf := &tls.Config{
87+
Certificates: []tls.Certificate{certificate},
88+
RootCAs: roots,
89+
}
90+
91+
_ = conf
92+
93+
ln, err := tls.Listen("tcp4", "localhost:22222", conf)
94+
if err != nil {
95+
return nil, err
96+
}
97+
98+
responses := make([][]byte, len(messages))
99+
for i, m := range messages {
100+
b, err := m.Bytes()
101+
if err != nil {
102+
_ = ln.Close()
103+
return nil, err
104+
}
105+
responses[i] = b
106+
}
107+
108+
go func() {
109+
for {
110+
cli, err := ln.Accept()
111+
112+
if err != nil {
113+
return
114+
}
115+
116+
go func(conn net.Conn) {
117+
err := readRequest(conn)
118+
if err != nil {
119+
log.Panic(err)
120+
}
121+
122+
time.Sleep(time.Millisecond * 50)
123+
for _, resp := range responses {
124+
_, _ = cli.Write(resp)
125+
}
126+
err = cli.Close()
127+
}(cli)
128+
}
129+
}()
130+
return ln, nil
131+
}
132+
17133
func testServer(messages ...serializableMessage) (net.Listener, error) {
18134
ln, err := net.Listen("tcp4", "")
19135
if err != nil {
@@ -620,3 +736,59 @@ func TestNoServerResponse(t *testing.T) {
620736
t.Fatalf("could not close test server: %s", err)
621737
}
622738
}
739+
740+
func TestTLSConnection(t *testing.T) {
741+
resp1 := &proto.MetadataResp{
742+
CorrelationID: 1,
743+
Brokers: []proto.MetadataRespBroker{
744+
{
745+
NodeID: 666,
746+
Host: "example.com",
747+
Port: 999,
748+
},
749+
},
750+
Topics: []proto.MetadataRespTopic{
751+
{
752+
Name: "foo",
753+
Partitions: []proto.MetadataRespPartition{
754+
{
755+
ID: 7,
756+
Leader: 7,
757+
Replicas: []int32{7},
758+
Isrs: []int32{7},
759+
},
760+
},
761+
},
762+
},
763+
}
764+
ln, err := testTLSServer(resp1)
765+
if err != nil {
766+
t.Fatalf("test server error: %s", err)
767+
}
768+
tlsConf, err := getTLSConf()
769+
if err != nil {
770+
t.Fatalf("cannot get tls parametes: %s", err)
771+
}
772+
_ = tlsConf
773+
conn, err := newTLSConnection(ln.Addr().String(), tlsConf.ca, tlsConf.cert, tlsConf.key, time.Second, time.Second)
774+
775+
if err != nil {
776+
t.Fatalf("could not conect to test server: %s", err)
777+
}
778+
resp, err := conn.Metadata(&proto.MetadataReq{
779+
ClientID: "tester",
780+
Topics: []string{"first", "second"},
781+
})
782+
if err != nil {
783+
t.Fatalf("could not fetch response: %s", err)
784+
}
785+
if !reflect.DeepEqual(resp, resp1) {
786+
t.Fatalf("expected different response %#v", resp)
787+
}
788+
if err := conn.Close(); err != nil {
789+
t.Fatalf("could not close kafka connection: %s", err)
790+
}
791+
if err := ln.Close(); err != nil {
792+
t.Fatalf("could not close test server: %s", err)
793+
}
794+
}

kafkatest/server.go

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func NewServer(middlewares ...Middleware) *Server {
5454
topics: make(map[string]map[int32][]*proto.Message),
5555
offsets: make(map[string]map[int32]map[string]*topicOffset),
5656
middlewares: middlewares,
57-
events: make(chan struct{}),
57+
events: make(chan struct{}, 1000),
5858
}
5959
return s
6060
}
@@ -370,8 +370,7 @@ func (s *Server) handleProduceRequest(nodeID int32, conn net.Conn, req *proto.Pr
370370
respParts[pi].Offset = int64(len(t[part.ID])) - 1
371371
}
372372
}
373-
close(s.events)
374-
s.events = make(chan struct{})
373+
s.events <- struct{}{}
375374
return resp
376375
}
377376

proto/messages.go

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,9 @@ func (r *FetchReq) Bytes() ([]byte, error) {
773773
enc.Encode(r.CorrelationID)
774774
enc.Encode(r.ClientID)
775775

776-
enc.Encode(r.ReplicaID)
776+
//enc.Encode(r.ReplicaID)
777+
enc.Encode(int32(-1))
778+
777779
enc.Encode(r.MaxWaitTime)
778780
enc.Encode(r.MinBytes)
779781

@@ -1808,7 +1810,8 @@ func (r *OffsetReq) Bytes() ([]byte, error) {
18081810
enc.Encode(r.CorrelationID)
18091811
enc.Encode(r.ClientID)
18101812

1811-
enc.Encode(r.ReplicaID)
1813+
//enc.Encode(r.ReplicaID)
1814+
enc.Encode(int32(-1))
18121815

18131816
if r.Version >= KafkaV2 {
18141817
enc.Encode(r.IsolationLevel)

0 commit comments

Comments
 (0)