
Commit 543e2cc

Improve listener handling and fix tests (#92)

* Improve listener handling and fix tests

  While improving the close handling of the listener in `Proxy`, I noticed that the logic could be generalized to handle multiple listeners (like `http.Server` does). This could be quite useful in the future for supporting multi-proxy features in a more user-friendly and resource-efficient way. So the `Proxy` type now allows multiple listeners in a single proxy; this hasn't been exposed as an external feature yet.

  This also improves cleanup of the servers, and their listeners, used during testing, and tests now generate unique, unbound ports for each test to avoid the sporadic errors I was seeing in CI.

* Fix comment

1 parent 1ede1d1 commit 543e2cc
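To make the new shape of the API concrete, here is a minimal sketch of a caller driving one proxy from several listeners. Only NewProxy, Connect, Serve, and the Config type come from this change; the import path, the empty Config literal, the addresses, and the logging are illustrative assumptions.

    package main

    import (
        "context"
        "log"
        "net"

        "github.com/datastax/cql-proxy/proxy" // assumed import path
    )

    func main() {
        // A real Config needs cluster contact points etc.; omitted here.
        p := proxy.NewProxy(context.Background(), proxy.Config{})

        // Connect sets up the cluster, sessions, and prepared cache exactly once.
        if err := p.Connect(); err != nil {
            log.Fatal(err)
        }

        // Serve can now be called once per listener; all of them share the
        // same backend cluster instead of the proxy owning a single listener.
        for _, addr := range []string{"127.0.0.1:9042", "127.0.0.1:9043"} {
            l, err := net.Listen("tcp", addr)
            if err != nil {
                log.Fatal(err)
            }
            go func() {
                if err := p.Serve(l); err != nil {
                    log.Printf("serve %s: %v", l.Addr(), err)
                }
            }()
        }

        select {} // a real caller would block on a signal and call p.Close()
    }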

File tree

5 files changed: +355 -207 lines changed


proxy/proxy.go

Lines changed: 114 additions & 50 deletions
@@ -31,7 +31,6 @@ import (
 	"strconv"
 	"strings"
 	"sync"
-	"sync/atomic"
 	"time"
 
 	"github.com/datastax/cql-proxy/parser"
@@ -50,6 +49,8 @@ var (
 )
 
 var ErrProxyClosed = errors.New("proxy closed")
+var ErrProxyAlreadyConnected = errors.New("proxy already connected")
+var ErrProxyNotConnected = errors.New("proxy not connected")
 
 const preparedIdSize = 16
 
@@ -83,18 +84,19 @@ type Proxy struct {
 	ctx                 context.Context
 	config              Config
 	logger              *zap.Logger
-	listener            *net.TCPListener
 	cluster             *proxycore.Cluster
 	sessions            [primitive.ProtocolVersionDse2 + 1]sync.Map // Cache sessions per protocol version
-	sessMu              sync.Mutex
-	schemaEventClients  sync.Map
+	mu                  sync.Mutex
+	isConnected         bool
+	isClosing           bool
+	clients             map[*client]struct{}
+	listeners           map[*net.Listener]struct{}
+	eventClients        sync.Map
 	preparedCache       proxycore.PreparedCache
 	preparedIdempotence sync.Map
-	clientIdGen         uint64
 	lb                  proxycore.LoadBalancer
 	systemLocalValues   map[string]message.Column
 	closed              chan struct{}
-	closingMu           sync.Mutex
 	localNode           *node
 	nodes               []*node
 }
@@ -109,16 +111,15 @@ func (p *Proxy) OnEvent(event proxycore.Event) {
 	switch evt := event.(type) {
 	case *proxycore.SchemaChangeEvent:
 		frm := frame.NewFrame(p.cluster.NegotiatedVersion, -1, evt.Message)
-		p.schemaEventClients.Range(func(key, value interface{}) bool {
-			cl := value.(*client)
+		p.eventClients.Range(func(key, _ interface{}) bool {
+			cl := key.(*client)
 			err := cl.conn.Write(proxycore.SenderFunc(func(writer io.Writer) error {
 				return codec.EncodeFrame(frm, writer)
 			}))
 			cl.conn.LocalAddr()
 			if err != nil {
 				p.logger.Error("unable to send schema change event",
 					zap.Stringer("client", cl.conn.RemoteAddr()),
-					zap.Uint64("id", cl.id),
 					zap.Error(err))
 				_ = cl.conn.Close()
 			}
@@ -138,24 +139,24 @@ func NewProxy(ctx context.Context, config Config) *Proxy {
 		config.RetryPolicy = NewDefaultRetryPolicy()
 	}
 	return &Proxy{
-		ctx:    ctx,
-		config: config,
-		logger: proxycore.GetOrCreateNopLogger(config.Logger),
-		closed: make(chan struct{}),
+		ctx:       ctx,
+		config:    config,
+		logger:    proxycore.GetOrCreateNopLogger(config.Logger),
+		clients:   make(map[*client]struct{}),
+		listeners: make(map[*net.Listener]struct{}),
+		closed:    make(chan struct{}),
 	}
 }
 
-func (p *Proxy) ListenAndServe(address string) error {
-	err := p.Listen(address)
-	if err != nil {
-		return err
+func (p *Proxy) Connect() error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	if p.isConnected {
+		return ErrProxyAlreadyConnected
 	}
-	return p.Serve()
-}
 
-func (p *Proxy) Listen(address string) error {
 	var err error
-
 	p.preparedCache, err = getOrCreateDefaultPreparedCache(p.config.PreparedCache)
 	if err != nil {
 		return fmt.Errorf("unable to create prepared cache %w", err)
@@ -210,23 +211,23 @@ func (p *Proxy) Listen(address string) error {
 
 	p.sessions[p.cluster.NegotiatedVersion].Store("", sess) // No keyspace
 
-	tcpAddr, err := net.ResolveTCPAddr("tcp", address)
-	if err != nil {
-		return err
-	}
-	p.listener, err = net.ListenTCP("tcp", tcpAddr)
-	if err != nil {
-		return err
-	}
-
-	p.logger.Info("proxy is listening", zap.Stringer("address", p.listener.Addr()))
-
+	p.isConnected = true
 	return nil
 }
 
-func (p *Proxy) Serve() error {
+// Serve the proxy using the specified listener. It can be called multiple times with different listeners allowing
+// them to share the same backend clusters.
+func (p *Proxy) Serve(l net.Listener) (err error) {
+	l = &closeOnceListener{Listener: l}
+	defer l.Close()
+
+	if err = p.addListener(&l); err != nil {
+		return err
+	}
+	defer p.removeListener(&l)
+
 	for {
-		conn, err := p.listener.AcceptTCP()
+		conn, err := l.Accept()
 		if err != nil {
 			select {
 			case <-p.closed:
@@ -239,15 +240,45 @@
 	}
 }
 
+func (p *Proxy) addListener(l *net.Listener) error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	if p.isClosing {
+		return ErrProxyClosed
+	}
+	if !p.isConnected {
+		return ErrProxyNotConnected
+	}
+	p.listeners[l] = struct{}{}
+	return nil
+}
+
+func (p *Proxy) removeListener(l *net.Listener) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	delete(p.listeners, l)
+}
+
 func (p *Proxy) Close() error {
-	p.closingMu.Lock()
-	defer p.closingMu.Unlock()
+	p.mu.Lock()
+	defer p.mu.Unlock()
 	select {
 	case <-p.closed:
 	default:
 		close(p.closed)
 	}
-	return p.listener.Close()
+	var err error
+	for l := range p.listeners {
+		if closeErr := (*l).Close(); closeErr != nil && err == nil {
+			err = closeErr
+		}
+	}
+	for cl := range p.clients {
+		_ = cl.conn.Close()
+		p.eventClients.Delete(cl)
+		delete(p.clients, cl)
+	}
+	return err
 }
 
 func (p *Proxy) Ready() bool {
@@ -258,28 +289,29 @@ func (p *Proxy) OutageDuration() time.Duration {
 	return p.cluster.OutageDuration()
 }
 
-func (p *Proxy) handle(conn *net.TCPConn) {
-	if err := conn.SetKeepAlive(false); err != nil {
-		p.logger.Warn("failed to disable keepalive on connection", zap.Error(err))
-	}
-
-	if err := conn.SetNoDelay(true); err != nil {
-		p.logger.Warn("failed to set TCP_NODELAY on connection", zap.Error(err))
+func (p *Proxy) handle(conn net.Conn) {
+	if tcpConn, ok := conn.(*net.TCPConn); ok {
+		if err := tcpConn.SetKeepAlive(false); err != nil {
+			p.logger.Warn("failed to disable keepalive on connection", zap.Error(err))
+		}
+		if err := tcpConn.SetNoDelay(true); err != nil {
+			p.logger.Warn("failed to set TCP_NODELAY on connection", zap.Error(err))
+		}
 	}
 
 	cl := &client{
 		ctx:                 p.ctx,
 		proxy:               p,
-		id:                  atomic.AddUint64(&p.clientIdGen, 1),
 		preparedSystemQuery: make(map[[preparedIdSize]byte]interface{}),
 	}
+	p.addClient(cl)
 	cl.conn = proxycore.NewConn(conn, cl)
 	cl.conn.Start()
 }
 
 func (p *Proxy) maybeCreateSession(version primitive.ProtocolVersion, keyspace string) (*proxycore.Session, error) {
-	p.sessMu.Lock()
-	defer p.sessMu.Unlock()
+	p.mu.Lock()
+	defer p.mu.Unlock()
 	if cachedSession, ok := p.sessions[version].Load(keyspace); ok {
 		return cachedSession.(*proxycore.Session), nil
 	} else {
@@ -463,12 +495,30 @@ func (p *Proxy) maybeStorePreparedIdempotence(raw *frame.RawFrame, msg message.Message) {
 	}
 }
 
+func (p *Proxy) addClient(cl *client) {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	p.clients[cl] = struct{}{}
+}
+
+func (p *Proxy) registerForEvents(cl *client) {
+	p.eventClients.Store(cl, struct{}{})
+}
+
+func (p *Proxy) removeClient(cl *client) {
+	p.eventClients.Delete(cl)
+
+	p.mu.Lock()
+	defer p.mu.Unlock()
+	delete(p.clients, cl)
+
+}
+
 type client struct {
 	ctx                 context.Context
 	proxy               *Proxy
 	conn                *proxycore.Conn
 	keyspace            string
-	id                  uint64
 	preparedSystemQuery map[[16]byte]interface{}
 }
 
@@ -505,7 +555,7 @@ func (c *client) Receive(reader io.Reader) error {
 		case *message.Register:
 			for _, t := range msg.EventTypes {
 				if t == primitive.EventTypeSchemaChange {
-					c.proxy.schemaEventClients.Store(c.id, c)
+					c.proxy.registerForEvents(c)
 				}
 			}
 			c.send(raw.Header, &message.Ready{})
@@ -746,7 +796,7 @@ func (c *client) send(hdr *frame.Header, msg message.Message) {
 }
 
 func (c *client) Closing(_ error) {
-	c.proxy.schemaEventClients.Delete(c.id)
+	c.proxy.removeClient(c)
 }
 
 func getOrCreateDefaultPreparedCache(cache proxycore.PreparedCache) (proxycore.PreparedCache, error) {
@@ -818,3 +868,17 @@ func compareIPAddr(a *net.IPAddr, b *net.IPAddr) int {
 
 	return 0
 }
+
+// Wrap the listener so that if it's closed in the serve loop it doesn't race with proxy Close()
+type closeOnceListener struct {
+	net.Listener
+	once     sync.Once
+	closeErr error
+}
+
+func (oc *closeOnceListener) Close() error {
+	oc.once.Do(oc.close)
+	return oc.closeErr
+}
+
+func (oc *closeOnceListener) close() { oc.closeErr = oc.Listener.Close() }
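The once-guard matters because each listener can now be closed from two paths: the deferred l.Close() at the top of Serve, and the loop over p.listeners in Proxy.Close. Below is a standalone sketch of the guarantee the wrapper provides (the throwaway TCP listener and the printout are purely illustrative; the wrapper's shape is copied from this commit and is similar in spirit to the unexported once-close listener wrapper inside net/http):

    package main

    import (
        "fmt"
        "net"
        "sync"
    )

    type closeOnceListener struct {
        net.Listener
        once     sync.Once
        closeErr error
    }

    func (oc *closeOnceListener) Close() error {
        oc.once.Do(func() { oc.closeErr = oc.Listener.Close() })
        return oc.closeErr
    }

    func main() {
        inner, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
            panic(err)
        }
        l := &closeOnceListener{Listener: inner}

        // Close concurrently from two goroutines, just as Serve's defer and
        // Proxy.Close can. The underlying Close runs exactly once, and both
        // callers observe the same result instead of the second one getting
        // a spurious "use of closed network connection" error.
        var wg sync.WaitGroup
        for i := 0; i < 2; i++ {
            wg.Add(1)
            go func() {
                defer wg.Done()
                fmt.Println(l.Close())
            }()
        }
        wg.Wait()
    }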

proxy/proxy_retries_test.go

Lines changed: 12 additions & 12 deletions
@@ -25,6 +25,7 @@ import (
 	"github.com/datastax/go-cassandra-native-protocol/message"
 	"github.com/datastax/go-cassandra-native-protocol/primitive"
 	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
 )
 
 const idempotentQuery = "INSERT INTO test.test (k, v) VALUES ('a', 123e4567-e89b-12d3-a456-426614174000)"
@@ -195,7 +196,7 @@ func TestProxy_Retries(t *testing.T) {
 	}
 
 	for _, tt := range tests {
-		numNodesTried, retryCount, err := testProxyRetry(t, &message.Query{Query: tt.query}, tt.response)
+		numNodesTried, retryCount, err := testProxyRetry(t, &message.Query{Query: tt.query}, tt.response, tt.msg)
 		assert.Error(t, err, tt.msg)
 		assert.IsType(t, err, &proxycore.CqlError{}, tt.msg)
 		assert.Equal(t, tt.numNodesTried, numNodesTried, tt.msg)
@@ -228,7 +229,7 @@ func TestProxy_PreparedRetries(t *testing.T) {
 	}
 
 	for _, tt := range tests {
-		numNodesTried, retryCount, err := testProxyRetry(t, tt.execute, tt.response)
+		numNodesTried, retryCount, err := testProxyRetry(t, tt.execute, tt.response, tt.msg)
 		assert.Error(t, err, tt.msg)
 		assert.IsType(t, err, &proxycore.CqlError{}, tt.msg)
 		assert.Equal(t, tt.numNodesTried, numNodesTried, tt.msg)
@@ -307,23 +308,21 @@ func TestProxy_BatchRetries(t *testing.T) {
 	}
 
 	for _, tt := range tests {
-		numNodesTried, retryCount, err := testProxyRetry(t, tt.batch, tt.response)
+		numNodesTried, retryCount, err := testProxyRetry(t, tt.batch, tt.response, tt.msg)
 		assert.Error(t, err, tt.msg)
 		assert.IsType(t, err, &proxycore.CqlError{}, tt.msg)
 		assert.Equal(t, tt.numNodesTried, numNodesTried, tt.msg)
 		assert.Equal(t, tt.retryCount, retryCount, tt.msg)
 	}
 }
 
-func testProxyRetry(t *testing.T, query message.Message, response message.Error) (numNodesTried, retryCount int, responseError error) {
-	ctx, cancel := context.WithCancel(context.Background())
-	defer cancel()
-
+func testProxyRetry(t *testing.T, query message.Message, response message.Error, testMessage string) (numNodesTried, retryCount int, responseError error) {
 	var mu sync.Mutex
 	tried := make(map[string]int)
 	prepared := make(map[[16]byte]string)
 
-	cluster, proxy := setupProxyTest(t, ctx, 3, proxycore.MockRequestHandlers{
+	ctx, cancel := context.WithCancel(context.Background())
+	tester, proxyContactPoint, err := setupProxyTest(ctx, 3, proxycore.MockRequestHandlers{
 		primitive.OpCodeQuery: func(cl *proxycore.MockClient, frm *frame.Frame) message.Message {
 			if msg := cl.InterceptQuery(frm.Header, frm.Body.Message.(*message.Query)); msg != nil {
 				return msg
@@ -381,13 +380,14 @@ func testProxyRetry(t *testing.T, query message.Message, response message.Error)
 		},
 	})
 	defer func() {
-		cluster.Shutdown()
-		_ = proxy.Close()
+		cancel()
+		tester.shutdown()
 	}()
+	require.NoError(t, err, testMessage)
 
-	cl := connectTestClient(t, ctx)
+	cl := connectTestClient(t, ctx, proxyContactPoint)
 
-	_, err := cl.Query(ctx, primitive.ProtocolVersion4, query)
+	_, err = cl.Query(ctx, primitive.ProtocolVersion4, query)
 
 	if cqlErr, ok := err.(*proxycore.CqlError); ok {
 		if unprepared, ok := cqlErr.Message.(*message.Unprepared); ok {
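The diff above moves proxy setup into setupProxyTest, which now hands back the proxy's contact point; per the commit message, each test gets a unique, unbound port. The helper itself isn't shown in this commit page, but a common way to implement that in Go is to bind to port 0 and let the kernel choose. A sketch of the general technique (findFreePort is a hypothetical name, not necessarily what the helper does):

    package main

    import (
        "fmt"
        "net"
    )

    // findFreePort asks the kernel for an unused TCP port by binding to port 0,
    // then releases it. There is a small window in which another process could
    // grab the port after Close, which is why handing the live net.Listener
    // straight to the code under test (as the new Proxy.Serve(net.Listener)
    // signature allows) is the more robust variant.
    func findFreePort() (int, error) {
        l, err := net.Listen("tcp", "127.0.0.1:0")
        if err != nil {
            return 0, err
        }
        defer l.Close()
        return l.Addr().(*net.TCPAddr).Port, nil
    }

    func main() {
        port, err := findFreePort()
        if err != nil {
            panic(err)
        }
        fmt.Printf("proxy contact point for this test: 127.0.0.1:%d\n", port)
    }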
