@@ -31,7 +31,6 @@ import (
31
31
"strconv"
32
32
"strings"
33
33
"sync"
34
- "sync/atomic"
35
34
"time"
36
35
37
36
"github.com/datastax/cql-proxy/parser"
50
49
)
51
50
52
51
var ErrProxyClosed = errors .New ("proxy closed" )
52
+ var ErrProxyAlreadyConnected = errors .New ("proxy already connected" )
53
+ var ErrProxyNotConnected = errors .New ("proxy not connected" )
53
54
54
55
const preparedIdSize = 16
55
56
@@ -83,18 +84,19 @@ type Proxy struct {
83
84
ctx context.Context
84
85
config Config
85
86
logger * zap.Logger
86
- listener * net.TCPListener
87
87
cluster * proxycore.Cluster
88
88
sessions [primitive .ProtocolVersionDse2 + 1 ]sync.Map // Cache sessions per protocol version
89
- sessMu sync.Mutex
90
- schemaEventClients sync.Map
89
+ mu sync.Mutex
90
+ isConnected bool
91
+ isClosing bool
92
+ clients map [* client ]struct {}
93
+ listeners map [* net.Listener ]struct {}
94
+ eventClients sync.Map
91
95
preparedCache proxycore.PreparedCache
92
96
preparedIdempotence sync.Map
93
- clientIdGen uint64
94
97
lb proxycore.LoadBalancer
95
98
systemLocalValues map [string ]message.Column
96
99
closed chan struct {}
97
- closingMu sync.Mutex
98
100
localNode * node
99
101
nodes []* node
100
102
}
@@ -109,16 +111,15 @@ func (p *Proxy) OnEvent(event proxycore.Event) {
109
111
switch evt := event .(type ) {
110
112
case * proxycore.SchemaChangeEvent :
111
113
frm := frame .NewFrame (p .cluster .NegotiatedVersion , - 1 , evt .Message )
112
- p .schemaEventClients .Range (func (key , value interface {}) bool {
113
- cl := value .(* client )
114
+ p .eventClients .Range (func (key , _ interface {}) bool {
115
+ cl := key .(* client )
114
116
err := cl .conn .Write (proxycore .SenderFunc (func (writer io.Writer ) error {
115
117
return codec .EncodeFrame (frm , writer )
116
118
}))
117
119
cl .conn .LocalAddr ()
118
120
if err != nil {
119
121
p .logger .Error ("unable to send schema change event" ,
120
122
zap .Stringer ("client" , cl .conn .RemoteAddr ()),
121
- zap .Uint64 ("id" , cl .id ),
122
123
zap .Error (err ))
123
124
_ = cl .conn .Close ()
124
125
}
@@ -138,24 +139,24 @@ func NewProxy(ctx context.Context, config Config) *Proxy {
138
139
config .RetryPolicy = NewDefaultRetryPolicy ()
139
140
}
140
141
return & Proxy {
141
- ctx : ctx ,
142
- config : config ,
143
- logger : proxycore .GetOrCreateNopLogger (config .Logger ),
144
- closed : make (chan struct {}),
142
+ ctx : ctx ,
143
+ config : config ,
144
+ logger : proxycore .GetOrCreateNopLogger (config .Logger ),
145
+ clients : make (map [* client ]struct {}),
146
+ listeners : make (map [* net.Listener ]struct {}),
147
+ closed : make (chan struct {}),
145
148
}
146
149
}
147
150
148
- func (p * Proxy ) ListenAndServe (address string ) error {
149
- err := p .Listen (address )
150
- if err != nil {
151
- return err
151
+ func (p * Proxy ) Connect () error {
152
+ p .mu .Lock ()
153
+ defer p .mu .Unlock ()
154
+
155
+ if p .isConnected {
156
+ return ErrProxyAlreadyConnected
152
157
}
153
- return p .Serve ()
154
- }
155
158
156
- func (p * Proxy ) Listen (address string ) error {
157
159
var err error
158
-
159
160
p .preparedCache , err = getOrCreateDefaultPreparedCache (p .config .PreparedCache )
160
161
if err != nil {
161
162
return fmt .Errorf ("unable to create prepared cache %w" , err )
@@ -210,23 +211,23 @@ func (p *Proxy) Listen(address string) error {
210
211
211
212
p .sessions [p .cluster .NegotiatedVersion ].Store ("" , sess ) // No keyspace
212
213
213
- tcpAddr , err := net .ResolveTCPAddr ("tcp" , address )
214
- if err != nil {
215
- return err
216
- }
217
- p .listener , err = net .ListenTCP ("tcp" , tcpAddr )
218
- if err != nil {
219
- return err
220
- }
221
-
222
- p .logger .Info ("proxy is listening" , zap .Stringer ("address" , p .listener .Addr ()))
223
-
214
+ p .isConnected = true
224
215
return nil
225
216
}
226
217
227
- func (p * Proxy ) Serve () error {
218
+ // Serve the proxy using the specified listener. It can be called multiple times with different listeners allowing
219
+ // them to share the same backend clusters.
220
+ func (p * Proxy ) Serve (l net.Listener ) (err error ) {
221
+ l = & closeOnceListener {Listener : l }
222
+ defer l .Close ()
223
+
224
+ if err = p .addListener (& l ); err != nil {
225
+ return err
226
+ }
227
+ defer p .removeListener (& l )
228
+
228
229
for {
229
- conn , err := p . listener . AcceptTCP ()
230
+ conn , err := l . Accept ()
230
231
if err != nil {
231
232
select {
232
233
case <- p .closed :
@@ -239,15 +240,45 @@ func (p *Proxy) Serve() error {
239
240
}
240
241
}
241
242
243
+ func (p * Proxy ) addListener (l * net.Listener ) error {
244
+ p .mu .Lock ()
245
+ defer p .mu .Unlock ()
246
+ if p .isClosing {
247
+ return ErrProxyClosed
248
+ }
249
+ if ! p .isConnected {
250
+ return ErrProxyNotConnected
251
+ }
252
+ p .listeners [l ] = struct {}{}
253
+ return nil
254
+ }
255
+
256
+ func (p * Proxy ) removeListener (l * net.Listener ) {
257
+ p .mu .Lock ()
258
+ defer p .mu .Unlock ()
259
+ delete (p .listeners , l )
260
+ }
261
+
242
262
func (p * Proxy ) Close () error {
243
- p .closingMu .Lock ()
244
- defer p .closingMu .Unlock ()
263
+ p .mu .Lock ()
264
+ defer p .mu .Unlock ()
245
265
select {
246
266
case <- p .closed :
247
267
default :
248
268
close (p .closed )
249
269
}
250
- return p .listener .Close ()
270
+ var err error
271
+ for l := range p .listeners {
272
+ if closeErr := (* l ).Close (); closeErr != nil && err == nil {
273
+ err = closeErr
274
+ }
275
+ }
276
+ for cl := range p .clients {
277
+ _ = cl .conn .Close ()
278
+ p .eventClients .Delete (cl )
279
+ delete (p .clients , cl )
280
+ }
281
+ return err
251
282
}
252
283
253
284
func (p * Proxy ) Ready () bool {
@@ -258,28 +289,29 @@ func (p *Proxy) OutageDuration() time.Duration {
258
289
return p .cluster .OutageDuration ()
259
290
}
260
291
261
- func (p * Proxy ) handle (conn * net.TCPConn ) {
262
- if err := conn .SetKeepAlive (false ); err != nil {
263
- p .logger .Warn ("failed to disable keepalive on connection" , zap .Error (err ))
264
- }
265
-
266
- if err := conn .SetNoDelay (true ); err != nil {
267
- p .logger .Warn ("failed to set TCP_NODELAY on connection" , zap .Error (err ))
292
+ func (p * Proxy ) handle (conn net.Conn ) {
293
+ if tcpConn , ok := conn .(* net.TCPConn ); ok {
294
+ if err := tcpConn .SetKeepAlive (false ); err != nil {
295
+ p .logger .Warn ("failed to disable keepalive on connection" , zap .Error (err ))
296
+ }
297
+ if err := tcpConn .SetNoDelay (true ); err != nil {
298
+ p .logger .Warn ("failed to set TCP_NODELAY on connection" , zap .Error (err ))
299
+ }
268
300
}
269
301
270
302
cl := & client {
271
303
ctx : p .ctx ,
272
304
proxy : p ,
273
- id : atomic .AddUint64 (& p .clientIdGen , 1 ),
274
305
preparedSystemQuery : make (map [[preparedIdSize ]byte ]interface {}),
275
306
}
307
+ p .addClient (cl )
276
308
cl .conn = proxycore .NewConn (conn , cl )
277
309
cl .conn .Start ()
278
310
}
279
311
280
312
func (p * Proxy ) maybeCreateSession (version primitive.ProtocolVersion , keyspace string ) (* proxycore.Session , error ) {
281
- p .sessMu .Lock ()
282
- defer p .sessMu .Unlock ()
313
+ p .mu .Lock ()
314
+ defer p .mu .Unlock ()
283
315
if cachedSession , ok := p .sessions [version ].Load (keyspace ); ok {
284
316
return cachedSession .(* proxycore.Session ), nil
285
317
} else {
@@ -463,12 +495,30 @@ func (p *Proxy) maybeStorePreparedIdempotence(raw *frame.RawFrame, msg message.M
463
495
}
464
496
}
465
497
498
+ func (p * Proxy ) addClient (cl * client ) {
499
+ p .mu .Lock ()
500
+ defer p .mu .Unlock ()
501
+ p .clients [cl ] = struct {}{}
502
+ }
503
+
504
+ func (p * Proxy ) registerForEvents (cl * client ) {
505
+ p .eventClients .Store (cl , struct {}{})
506
+ }
507
+
508
+ func (p * Proxy ) removeClient (cl * client ) {
509
+ p .eventClients .Delete (cl )
510
+
511
+ p .mu .Lock ()
512
+ defer p .mu .Unlock ()
513
+ delete (p .clients , cl )
514
+
515
+ }
516
+
466
517
type client struct {
467
518
ctx context.Context
468
519
proxy * Proxy
469
520
conn * proxycore.Conn
470
521
keyspace string
471
- id uint64
472
522
preparedSystemQuery map [[16 ]byte ]interface {}
473
523
}
474
524
@@ -505,7 +555,7 @@ func (c *client) Receive(reader io.Reader) error {
505
555
case * message.Register :
506
556
for _ , t := range msg .EventTypes {
507
557
if t == primitive .EventTypeSchemaChange {
508
- c .proxy .schemaEventClients . Store ( c . id , c )
558
+ c .proxy .registerForEvents ( c )
509
559
}
510
560
}
511
561
c .send (raw .Header , & message.Ready {})
@@ -746,7 +796,7 @@ func (c *client) send(hdr *frame.Header, msg message.Message) {
746
796
}
747
797
748
798
func (c * client ) Closing (_ error ) {
749
- c .proxy .schemaEventClients . Delete ( c . id )
799
+ c .proxy .removeClient ( c )
750
800
}
751
801
752
802
func getOrCreateDefaultPreparedCache (cache proxycore.PreparedCache ) (proxycore.PreparedCache , error ) {
@@ -818,3 +868,17 @@ func compareIPAddr(a *net.IPAddr, b *net.IPAddr) int {
818
868
819
869
return 0
820
870
}
871
+
872
+ // Wrap the listener so that if it's closed in the serve loop it doesn't race with proxy Close()
873
+ type closeOnceListener struct {
874
+ net.Listener
875
+ once sync.Once
876
+ closeErr error
877
+ }
878
+
879
+ func (oc * closeOnceListener ) Close () error {
880
+ oc .once .Do (oc .close )
881
+ return oc .closeErr
882
+ }
883
+
884
+ func (oc * closeOnceListener ) close () { oc .closeErr = oc .Listener .Close () }
0 commit comments