@@ -17,13 +17,19 @@ package proxy
17
17
import (
18
18
"bytes"
19
19
"context"
20
+ "crypto"
20
21
"crypto/md5"
21
22
"encoding/hex"
22
23
"errors"
23
24
"fmt"
24
25
"io"
26
+ "math"
27
+ "math/big"
25
28
"net"
26
29
"reflect"
30
+ "sort"
31
+ "strconv"
32
+ "strings"
27
33
"sync"
28
34
"sync/atomic"
29
35
"time"
@@ -40,14 +46,19 @@ import (
40
46
)
41
47
42
48
// encodedOneValue is the CQL int value 1, encoded once with protocol version 4.
// filterSystemLocalValues returns it for the count column (parser.CountValueName).
// The encoding error is discarded with `_`.
var (
43
- encodedOneValue , _ = proxycore .EncodeType (datatype .Int , primitive .ProtocolVersion4 , 1 )
44
- encodedZeroValue , _ = proxycore .EncodeType (datatype .Int , primitive .ProtocolVersion4 , 0 )
49
+ encodedOneValue , _ = proxycore .EncodeType (datatype .Int , primitive .ProtocolVersion4 , 1 )
45
50
)
46
51
47
52
// ErrProxyClosed is a sentinel error carrying the message "proxy closed".
// NOTE(review): presumably returned once the proxy has been shut down; the call
// sites are outside this view — confirm.
var ErrProxyClosed = errors .New ("proxy closed" )
48
53
49
54
// preparedIdSize is the length of the fixed-size array returned by preparedIdKey.
// NOTE(review): 16 presumably matches the MD5 digest size used for prepared IDs — confirm.
const preparedIdSize = 16
50
55
56
// PeerConfig describes one peer proxy as loaded from YAML configuration.
// buildNodes turns each entry into a node of the advertised ring.
+ type PeerConfig struct {
57
// RPCAddr is the peer's address. buildNodes rejects an empty value and resolves
// the rest with net.ResolveIPAddr.
+ RPCAddr string `yaml:"rpc-address"`
58
// DC is the peer's data center. When empty, buildNodes uses the local DC.
+ DC string `yaml:"data-center"`
59
// Tokens are the peer's tokens. buildNodes requires them for every peer when
// tokens are configured for the local proxy.
+ Tokens []string `yaml:"tokens"`
60
+ }
61
+
51
62
type Config struct {
52
63
Version primitive.ProtocolVersion
53
64
MaxVersion primitive.ProtocolVersion
@@ -59,6 +70,10 @@ type Config struct {
59
70
Logger * zap.Logger
60
71
HeartBeatInterval time.Duration
61
72
IdleTimeout time.Duration
73
+ RPCAddr string
74
+ DC string
75
+ Tokens []string
76
+ Peers []PeerConfig
62
77
// PreparedCache a cache that stores prepared queries. If not set it uses the default implementation with a max
63
78
// capacity of ~100MB.
64
79
PreparedCache proxycore.PreparedCache
@@ -80,6 +95,14 @@ type Proxy struct {
80
95
systemLocalValues map [string ]message.Column
81
96
closed chan struct {}
82
97
closingMu sync.Mutex
98
+ localNode * node
99
+ nodes []* node
100
+ }
101
+
102
// node is one proxy in the ring built by buildNodes: either the local proxy
// or a configured peer.
+ type node struct {
103
// addr is the resolved address. For the local node it is nil when no RPC
// address was configured (see buildNodes and localIP).
+ addr * net.IPAddr
104
// dc is the data center name reported for this node.
+ dc string
105
// tokens are the tokens reported for this node, as decimal strings.
+ tokens []string
83
106
}
84
107
85
108
func (p * Proxy ) OnEvent (event proxycore.Event ) {
@@ -157,6 +180,11 @@ func (p *Proxy) Listen(address string) error {
157
180
return fmt .Errorf ("unable to register to listen for schema events %w" , err )
158
181
}
159
182
183
+ err = p .buildNodes ()
184
+ if err != nil {
185
+ return fmt .Errorf ("unable to build node information: %w" , err )
186
+ }
187
+
160
188
p .buildLocalRow ()
161
189
162
190
p .lb = proxycore .NewRoundRobinLoadBalancer ()
@@ -288,22 +316,102 @@ func (p *Proxy) newQueryPlan() proxycore.QueryPlan {
288
316
289
317
// schemaVersion is a fixed UUID that buildLocalRow reports in the
// "schema_version" column. The parse error is discarded with `_`.
var (
290
318
schemaVersion , _ = primitive .ParseUuid ("4f2b29e6-59b5-4e2d-8fd6-01e32e67f0d7" )
291
- hostId , _ = primitive .ParseUuid ("19e26944-ffb1-40a9-a184-a9b065e5e06b" )
292
319
)
293
320
321
+ func (p * Proxy ) buildNodes () (err error ) {
322
+ numPeers := len (p .config .Peers )
323
+ nodes := make ([]* node , 0 , numPeers + 1 )
324
+
325
+ var localAddr * net.IPAddr
326
+ if len (p .config .RPCAddr ) > 0 {
327
+ localAddr , err = net .ResolveIPAddr ("ip" , p .config .RPCAddr )
328
+ if err != nil {
329
+ return fmt .Errorf ("invalid RPC address: %w" , err )
330
+ }
331
+ } else if numPeers > 0 {
332
+ return errors .New ("peers provided, but RPC address is not set" )
333
+ }
334
+
335
+ localDC := p .config .DC
336
+ if len (localDC ) == 0 {
337
+ localDC = p .cluster .Info .LocalDC
338
+ p .logger .Info ("no local DC configured using DC from the first successful contact point" ,
339
+ zap .String ("dc" , localDC ))
340
+ }
341
+
342
+ var localTokens []string
343
+ calculateTokens := false
344
+ if len (p .config .Tokens ) > 0 {
345
+ localTokens = p .config .Tokens
346
+ } else {
347
+ calculateTokens = true
348
+ localTokens = []string {strconv .FormatInt (math .MinInt64 , 10 )}
349
+ }
350
+
351
+ p .localNode = & node {
352
+ addr : localAddr ,
353
+ dc : localDC ,
354
+ tokens : localTokens ,
355
+ }
356
+ nodes = append (nodes , p .localNode )
357
+
358
+ for i , peer := range p .config .Peers {
359
+ if len (peer .RPCAddr ) == 0 {
360
+ return fmt .Errorf ("no 'rpc-address' provided for peer #%d" , i + 1 )
361
+ }
362
+ addr , err := net .ResolveIPAddr ("ip" , peer .RPCAddr )
363
+ if err != nil {
364
+ return fmt .Errorf ("invalid peer address: %w" , err )
365
+ }
366
+ if compareIPAddr (localAddr , addr ) == 0 {
367
+ p .logger .Info ("ignoring local address in peers configuration" , zap .Stringer ("localAddr" , localAddr ))
368
+ continue
369
+ }
370
+ dc := peer .DC
371
+ if len (dc ) == 0 {
372
+ dc = localDC
373
+ }
374
+ if ! calculateTokens && len (peer .Tokens ) == 0 {
375
+ return errors .New ("tokens must be provided for all peer proxies if tokens are provided for this proxy" )
376
+ }
377
+ nodes = append (nodes , & node {
378
+ addr : addr ,
379
+ dc : dc ,
380
+ })
381
+ }
382
+
383
+ if calculateTokens && len (nodes ) > 1 {
384
+ sort .Slice (nodes , func (i , j int ) bool {
385
+ return compareIPAddr (nodes [i ].addr , nodes [j ].addr ) < 0
386
+ })
387
+
388
+ var numTokens big.Int
389
+ numTokens .SetUint64 (math .MaxUint64 / uint64 (numPeers + 1 ) + 1 )
390
+ start := big .NewInt (math .MinInt64 )
391
+
392
+ for _ , n := range nodes {
393
+ n .tokens = []string {start .Text (10 )}
394
+ start .Add (start , & numTokens )
395
+ }
396
+ }
397
+
398
+ p .nodes = nodes
399
+
400
+ return nil
401
+ }
402
+
294
403
func (p * Proxy ) buildLocalRow () {
295
404
p .systemLocalValues = map [string ]message.Column {
296
405
"key" : p .encodeTypeFatal (datatype .Varchar , "local" ),
297
- "data_center" : p .encodeTypeFatal (datatype .Varchar , "dc1" ),
406
+ "data_center" : p .encodeTypeFatal (datatype .Varchar , p . localNode . dc ),
298
407
"rack" : p .encodeTypeFatal (datatype .Varchar , "rack1" ),
299
- "tokens" : p .encodeTypeFatal (datatype .NewListType (datatype .Varchar ), [] string { "0" } ),
408
+ "tokens" : p .encodeTypeFatal (datatype .NewListType (datatype .Varchar ), p . localNode . tokens ),
300
409
"release_version" : p .encodeTypeFatal (datatype .Varchar , p .cluster .Info .ReleaseVersion ),
301
410
"partitioner" : p .encodeTypeFatal (datatype .Varchar , p .cluster .Info .Partitioner ),
302
411
"cluster_name" : p .encodeTypeFatal (datatype .Varchar , "cql-proxy" ),
303
412
"cql_version" : p .encodeTypeFatal (datatype .Varchar , p .cluster .Info .CQLVersion ),
304
413
"schema_version" : p .encodeTypeFatal (datatype .Uuid , schemaVersion ), // TODO: Make this match the downstream cluster(s)
305
414
"native_protocol_version" : p .encodeTypeFatal (datatype .Varchar , p .cluster .NegotiatedVersion .String ()),
306
- "host_id" : p .encodeTypeFatal (datatype .Uuid , hostId ),
307
415
}
308
416
}
309
417
@@ -508,33 +616,53 @@ func (c *client) handleQuery(raw *frame.RawFrame, msg *partialQuery) {
508
616
}
509
617
}
510
618
511
- func (c * client ) columnValue (values map [string ]message.Column , name string , table string ) message.Column {
512
- var val message.Column
513
- var ok bool
514
- if val , ok = values [name ]; ! ok {
515
- if name == "rpc_address" && table == "local" {
516
- switch addr := c .conn .LocalAddr ().(type ) {
517
- case * net.TCPAddr :
518
- val , _ = proxycore .EncodeType (datatype .Inet , c .proxy .cluster .NegotiatedVersion , addr .IP )
519
- }
619
+ func (c * client ) filterSystemLocalValues (stmt * parser.SelectStatement ) (row []message.Column , err error ) {
620
+ return parser .FilterValues (stmt , parser .SystemLocalColumns , func (name string ) (value message.Column , err error ) {
621
+ if name == "rpc_address" {
622
+ return proxycore .EncodeType (datatype .Inet , c .proxy .cluster .NegotiatedVersion , c .localIP ())
623
+ } else if name == "host_id" {
624
+ return proxycore .EncodeType (datatype .Uuid , c .proxy .cluster .NegotiatedVersion , nameBasedUUID (c .localIP ().String ()))
625
+ } else if val , ok := c .proxy .systemLocalValues [name ]; ok {
626
+ return val , nil
627
+ } else if name == parser .CountValueName {
628
+ return encodedOneValue , nil
629
+ } else {
630
+ return nil , fmt .Errorf ("no column value for %s" , name )
631
+ }
632
+ })
633
+ }
634
+
635
+ func (c * client ) localIP () net.IP {
636
+ if c .proxy .localNode .addr != nil {
637
+ return c .proxy .localNode .addr .IP
638
+ } else {
639
+ switch a := c .conn .LocalAddr ().(type ) {
640
+ case * net.TCPAddr :
641
+ return a .IP
642
+ case * net.IPAddr :
643
+ return a .IP
644
+ default :
645
+ panic ("unhandled local address type" )
520
646
}
521
647
}
522
- return val
523
648
}
524
649
525
- func (c * client ) filterSystemLocalValues (stmt * parser.SelectStatement ) (row []message.Column , err error ) {
526
- return parser .FilterValues (stmt , parser .SystemLocalColumns , func (name string ) (value message.Column , err error ) {
527
- if val , ok := c .proxy .systemLocalValues [name ]; ok {
528
- return val , nil
650
+ func (c * client ) filterSystemPeerValues (stmt * parser.SelectStatement , peer * node , peerCount int ) (row []message.Column , err error ) {
651
+ return parser .FilterValues (stmt , parser .SystemPeersColumns , func (name string ) (value message.Column , err error ) {
652
+ if name == "data_center" {
653
+ return proxycore .EncodeType (datatype .Varchar , c .proxy .cluster .NegotiatedVersion , peer .dc )
654
+ } else if name == "host_id" {
655
+ return proxycore .EncodeType (datatype .Uuid , c .proxy .cluster .NegotiatedVersion , nameBasedUUID (peer .addr .String ()))
656
+ } else if name == "tokens" {
657
+ return proxycore .EncodeType (datatype .NewListType (datatype .Varchar ), c .proxy .cluster .NegotiatedVersion , peer .tokens )
658
+ } else if name == "peer" {
659
+ return proxycore .EncodeType (datatype .Inet , c .proxy .cluster .NegotiatedVersion , peer .addr .IP )
529
660
} else if name == "rpc_address" {
530
- switch addr := c .conn .LocalAddr ().(type ) {
531
- case * net.TCPAddr :
532
- return proxycore .EncodeType (datatype .Inet , c .proxy .cluster .NegotiatedVersion , addr .IP )
533
- default :
534
- return nil , errors .New ("unhandled local address type" )
535
- }
661
+ return proxycore .EncodeType (datatype .Inet , c .proxy .cluster .NegotiatedVersion , peer .addr .IP )
662
+ } else if val , ok := c .proxy .systemLocalValues [name ]; ok {
663
+ return val , nil
536
664
} else if name == parser .CountValueName {
537
- return encodedOneValue , nil
665
+ return proxycore . EncodeType ( datatype . Int , c . proxy . cluster . NegotiatedVersion , peerCount )
538
666
} else {
539
667
return nil , fmt .Errorf ("no column value for %s" , name )
540
668
}
@@ -563,16 +691,27 @@ func (c *client) interceptSystemQuery(hdr *frame.Header, stmt interface{}) {
563
691
c .send (hdr , & message.Invalid {ErrorMessage : err .Error ()})
564
692
} else {
565
693
var data []message.Row
566
- if parser .IsCountStarQuery (s ) { // COUNT(*) always returns a value, even when there are no rows
567
- data = []message.Row {{encodedZeroValue }}
694
+ for _ , n := range c .proxy .nodes {
695
+ if n != c .proxy .localNode {
696
+ var row message.Row
697
+ row , err = c .filterSystemPeerValues (s , n , len (c .proxy .nodes )- 1 )
698
+ if err != nil {
699
+ break
700
+ }
701
+ data = append (data , row )
702
+ }
703
+ }
704
+ if err != nil {
705
+ c .send (hdr , & message.Invalid {ErrorMessage : err .Error ()})
706
+ } else {
707
+ c .send (hdr , & message.RowsResult {
708
+ Metadata : & message.RowsMetadata {
709
+ ColumnCount : int32 (len (columns )),
710
+ Columns : columns ,
711
+ },
712
+ Data : data ,
713
+ })
568
714
}
569
- c .send (hdr , & message.RowsResult {
570
- Metadata : & message.RowsMetadata {
571
- ColumnCount : int32 (len (columns )),
572
- Columns : columns ,
573
- },
574
- Data : data ,
575
- })
576
715
}
577
716
} else if columns , ok := parser .SystemColumnsByName [s .Table ]; ok {
578
717
c .send (hdr , & message.RowsResult {
@@ -642,3 +781,36 @@ func preparedIdKey(bytes []byte) [preparedIdSize]byte {
642
781
copy (buf [:], bytes )
643
782
return buf
644
783
}
784
+
785
+ func nameBasedUUID (name string ) primitive.UUID {
786
+ var uuid primitive.UUID
787
+ m := crypto .MD5 .New ()
788
+ _ , _ = io .WriteString (m , name )
789
+ hash := m .Sum (nil )
790
+ for i := 0 ; i < len (uuid ); i ++ {
791
+ uuid [i ] = hash [i ]
792
+ }
793
+ uuid [6 ] &= 0x0F
794
+ uuid [6 ] |= 0x30
795
+ uuid [8 ] &= 0x3F
796
+ uuid [8 ] |= 0x80
797
+ return uuid
798
+ }
799
+
800
// compareIPAddr orders two IP addresses. It returns a negative number when a
// sorts before b, zero when they are equal and a positive number when a sorts
// after b.
//
// A nil address sorts before any non-nil address, and two nil addresses are
// equal. IPs are compared in their 16-byte form, so an IPv4 address compares
// equal to itself whether it is held as 4 or 16 bytes. Addresses with equal IPs
// are ordered by zone.
func compareIPAddr(a *net.IPAddr, b *net.IPAddr) int {
	switch {
	case a == b:
		return 0
	case a == nil:
		return -1
	case b == nil:
		return 1
	}

	// net.IP may hold an IPv4 address as 4 or as 16 bytes. Normalise both sides
	// before the byte-wise comparison. An IP of any other length is left as is.
	canonical := func(ip net.IP) net.IP {
		if ip16 := ip.To16(); ip16 != nil {
			return ip16
		}
		return ip
	}

	if res := bytes.Compare(canonical(a.IP), canonical(b.IP)); res != 0 {
		return res
	}

	return strings.Compare(a.Zone, b.Zone)
}
0 commit comments