Skip to content

Commit 187088e

Browse files
authored
Merge pull request #86 from datastax/astra-use-local-dc-nodes
Ignore remote DC nodes when resolving endpoints (Astra only)
2 parents 3d061b7 + 8d1efdc commit 187088e

File tree

10 files changed

+132
-46
lines changed

10 files changed

+132
-46
lines changed

astra/endpoint.go

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,12 @@ import (
2525
"sync"
2626
"time"
2727

28-
"github.com/datastax/go-cassandra-native-protocol/primitive"
29-
3028
"github.com/datastax/cql-proxy/proxycore"
3129
)
3230

3331
type astraResolver struct {
3432
sniProxyAddress string
33+
region string
3534
bundle *Bundle
3635
mu *sync.Mutex
3736
}
@@ -77,6 +76,7 @@ func (r *astraResolver) Resolve() ([]proxycore.Endpoint, error) {
7776

7877
r.mu.Lock()
7978
r.sniProxyAddress = sniProxyAddress
79+
r.region = metadata.Region
8080
r.mu.Unlock()
8181

8282
var endpoints []proxycore.Endpoint
@@ -91,31 +91,35 @@ func (r *astraResolver) Resolve() ([]proxycore.Endpoint, error) {
9191
return endpoints, nil
9292
}
9393

94-
func (r *astraResolver) getSNIProxyAddress() (string, error) {
94+
func (r *astraResolver) getSNIProxyAddressAndRegion() (string, string, error) {
9595
r.mu.Lock()
9696
defer r.mu.Unlock()
9797
if len(r.sniProxyAddress) == 0 {
98-
return "", errors.New("SNI proxy address never resolved")
98+
return "", "", errors.New("SNI proxy address (and region) never resolved")
9999
}
100-
return r.sniProxyAddress, nil
100+
return r.sniProxyAddress, r.region, nil
101101
}
102102

103103
func (r *astraResolver) NewEndpoint(row proxycore.Row) (proxycore.Endpoint, error) {
104-
sniProxyAddress, err := r.getSNIProxyAddress()
104+
sniProxyAddress, region, err := r.getSNIProxyAddressAndRegion()
105105
if err != nil {
106106
return nil, err
107107
}
108-
hostId, err := row.ByName("host_id")
108+
dc, err := row.StringByName("data_center")
109109
if err != nil {
110110
return nil, err
111111
}
112-
if uuid, ok := hostId.(primitive.UUID); !ok {
113-
return nil, errors.New("ignoring host because its `host_id` is not set or is invalid")
112+
if len(region) > 0 && region != dc {
113+
return nil, proxycore.IgnoreEndpoint
114+
}
115+
hostId, err := row.UUIDByName("host_id")
116+
if err != nil {
117+
return nil, err
114118
} else {
115119
return &astraEndpoint{
116120
addr: sniProxyAddress,
117-
key: fmt.Sprintf("%s:%s", sniProxyAddress, &uuid),
118-
tlsConfig: copyTLSConfig(r.bundle, uuid.String()),
121+
key: fmt.Sprintf("%s:%s", sniProxyAddress, &hostId),
122+
tlsConfig: copyTLSConfig(r.bundle, hostId.String()),
119123
}, nil
120124
}
121125
}

astra/endpoint_test.go

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,10 +79,17 @@ func TestAstraResolver_NewEndpoint(t *testing.T) {
7979
Index: 0,
8080
Type: datatype.Uuid,
8181
},
82+
{
83+
Keyspace: "system",
84+
Table: "peers",
85+
Name: "data_center",
86+
Index: 1,
87+
Type: datatype.Varchar,
88+
},
8289
},
8390
},
8491
Data: message.RowSet{
85-
message.Row{makeUUID(hostId)},
92+
message.Row{makeUUID(hostId), makeVarchar("us-east1")},
8693
},
8794
}, primitive.ProtocolVersion4)
8895

@@ -92,6 +99,43 @@ func TestAstraResolver_NewEndpoint(t *testing.T) {
9299
assert.Contains(t, endpoint.Key(), hostId)
93100
}
94101

102+
func TestAstraResolver_NewEndpoint_Ignored(t *testing.T) {
103+
resolver := createResolver(t)
104+
_, err := resolver.Resolve()
105+
require.NoError(t, err)
106+
107+
const hostId = "a2e24181-d732-402a-ab06-894a8b2f6094"
108+
109+
rs := proxycore.NewResultSet(&message.RowsResult{
110+
Metadata: &message.RowsMetadata{
111+
ColumnCount: 1,
112+
Columns: []*message.ColumnMetadata{
113+
{
114+
Keyspace: "system",
115+
Table: "peers",
116+
Name: "host_id",
117+
Index: 0,
118+
Type: datatype.Uuid,
119+
},
120+
{
121+
Keyspace: "system",
122+
Table: "peers",
123+
Name: "data_center",
124+
Index: 1,
125+
Type: datatype.Varchar,
126+
},
127+
},
128+
},
129+
Data: message.RowSet{
130+
message.Row{makeUUID(hostId), makeVarchar("ignored")},
131+
},
132+
}, primitive.ProtocolVersion4)
133+
134+
endpoint, err := resolver.NewEndpoint(rs.Row(0))
135+
assert.Nil(t, endpoint)
136+
assert.ErrorIs(t, err, proxycore.IgnoreEndpoint)
137+
}
138+
95139
func TestAstraResolver_NewEndpointInvalidHostID(t *testing.T) {
96140
resolver := createResolver(t)
97141
_, err := resolver.Resolve()
@@ -229,3 +273,8 @@ func makeUUID(uuid string) []byte {
229273
bytes, _ := proxycore.EncodeType(datatype.Uuid, primitive.ProtocolVersion4, parsedUuid)
230274
return bytes
231275
}
276+
277+
func makeVarchar(s string) []byte {
278+
bytes, _ := proxycore.EncodeType(datatype.Varchar, primitive.ProtocolVersion4, s)
279+
return bytes
280+
}

proxycore/cluster.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ func (c *Cluster) mergeHosts(hosts []*Host) error {
240240
existing := make(map[string]*Host)
241241

242242
for _, host := range c.hosts {
243-
existing[host.Endpoint().Key()] = host
243+
existing[host.Key()] = host
244244
}
245245

246246
c.currentHostIndex = -1
@@ -282,7 +282,6 @@ func (c *Cluster) sendEvent(event Event) {
282282

283283
func (c *Cluster) queryHosts(ctx context.Context, conn *ClientConn, version primitive.ProtocolVersion) (hosts []*Host, info ClusterInfo, err error) {
284284
var rs *ResultSet
285-
var val interface{}
286285

287286
rs, err = conn.Query(ctx, version, &message.Query{
288287
Query: "SELECT * FROM system.local",
@@ -299,23 +298,20 @@ func (c *Cluster) queryHosts(ctx context.Context, conn *ClientConn, version prim
299298
hosts = c.addHosts(hosts, rs)
300299
row := rs.Row(0)
301300

302-
val, err = row.ByName("partitioner")
301+
partitioner, err := row.StringByName("partitioner")
303302
if err != nil {
304303
return nil, ClusterInfo{}, err
305304
}
306-
partitioner := val.(string)
307305

308-
val, err = row.ByName("release_version")
306+
releaseVersion, err := row.StringByName("release_version")
309307
if err != nil {
310308
return nil, ClusterInfo{}, err
311309
}
312-
releaseVersion := val.(string)
313310

314-
val, err = row.ByName("cql_version")
311+
cqlVersion, err := row.StringByName("cql_version")
315312
if err != nil {
316313
return nil, ClusterInfo{}, err
317314
}
318-
cqlVersion := val.(string)
319315

320316
rs, err = conn.Query(ctx, version, &message.Query{
321317
Query: "SELECT * FROM system.peers",
@@ -341,7 +337,11 @@ func (c *Cluster) addHosts(hosts []*Host, rs *ResultSet) []*Host {
341337
if endpoint, err := c.config.Resolver.NewEndpoint(row); err == nil {
342338
if host, err := NewHostFromRow(endpoint, row); err == nil {
343339
hosts = append(hosts, host)
340+
} else {
341+
c.logger.Error("unable to create new host", zap.Stringer("endpoint", endpoint), zap.Error(err))
344342
}
343+
} else if err != IgnoreEndpoint {
344+
c.logger.Error("unable to create new endpoint", zap.Error(err))
345345
}
346346
}
347347
return hosts
@@ -350,7 +350,7 @@ func (c *Cluster) addHosts(hosts []*Host, rs *ResultSet) []*Host {
350350
func (c *Cluster) reconnect() bool {
351351
c.currentHostIndex = (c.currentHostIndex + 1) % len(c.hosts)
352352
host := c.hosts[c.currentHostIndex]
353-
err := c.connect(c.ctx, host.Endpoint(), false)
353+
err := c.connect(c.ctx, host.Endpoint, false)
354354
if err != nil {
355355
c.logger.Error("error reconnecting to host", zap.Stringer("host", host), zap.Error(err))
356356
return false

proxycore/endpoint.go

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ import (
2323
"strconv"
2424
)
2525

26+
var IgnoreEndpoint = errors.New("ignore endpoint")
27+
2628
type Endpoint interface {
2729
fmt.Stringer
2830
Addr() string
@@ -108,19 +110,14 @@ func (r *defaultEndpointResolver) NewEndpoint(row Row) (Endpoint, error) {
108110
if err != nil && !errors.Is(err, ColumnNameNotFound) {
109111
return nil, err
110112
}
111-
rpcAddress, err := row.ByName("rpc_address")
113+
rpcAddress, err := row.InetByName("rpc_address")
112114
if err != nil {
113-
return nil, err
114-
}
115-
116-
var ok bool
117-
var addr net.IP
118-
119-
if addr, ok = rpcAddress.(net.IP); !ok {
120-
return nil, errors.New("ignoring host because its `rpc_address` is not set or is invalid")
115+
return nil, fmt.Errorf("ignoring host because its `rpc_address` is not set or is invalid: %w", err)
121116
}
122117

118+
addr := rpcAddress
123119
if addr.Equal(net.IPv4zero) || addr.Equal(net.IPv6zero) {
120+
var ok bool
124121
if addr, ok = peer.(net.IP); !ok {
125122
return nil, errors.New("ignoring host because its `peer` is not set or is invalid")
126123
}

proxycore/host.go

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,17 @@
1515
package proxycore
1616

1717
type Host struct {
18-
endpoint Endpoint
18+
Endpoint
1919
}
2020

21-
func NewHostFromRow(endpoint Endpoint, _ Row) (*Host, error) {
21+
func NewHostFromRow(endpoint Endpoint, row Row) (*Host, error) {
2222
return &Host{endpoint}, nil
2323
}
2424

25-
func (h *Host) Endpoint() Endpoint {
26-
return h.endpoint
27-
}
28-
2925
func (h *Host) Key() string {
30-
return h.endpoint.Key()
26+
return h.Endpoint.Key()
3127
}
3228

3329
func (h *Host) String() string {
34-
return h.endpoint.String()
30+
return h.Endpoint.String()
3531
}

proxycore/lb.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func (l *roundRobinLoadBalancer) OnEvent(event Event) {
5454
case *RemoveEvent:
5555
cpy := l.copy()
5656
for i, h := range cpy {
57-
if h.Endpoint().Key() == evt.Host.Key() {
57+
if h.Key() == evt.Host.Key() {
5858
l.hosts.Store(append(cpy[:i], cpy[i+1:]...))
5959
break
6060
}

proxycore/lb_test.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
package proxycore
1616

1717
import (
18-
"github.com/stretchr/testify/assert"
1918
"testing"
19+
20+
"github.com/stretchr/testify/assert"
2021
)
2122

2223
func TestRoundRobinLoadBalancer_NewQueryPlan(t *testing.T) {
@@ -26,7 +27,7 @@ func TestRoundRobinLoadBalancer_NewQueryPlan(t *testing.T) {
2627
assert.Nil(t, qp.Next())
2728

2829
newHost := func(addr string) *Host {
29-
return &Host{endpoint: &defaultEndpoint{addr: addr}}
30+
return &Host{Endpoint: &defaultEndpoint{addr: addr}}
3031
}
3132

3233
lb.OnEvent(&BootstrapEvent{Hosts: []*Host{newHost("127.0.0.1"), newHost("127.0.0.2"), newHost("127.0.0.3")}})

proxycore/resultset.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ package proxycore
1616

1717
import (
1818
"errors"
19+
"fmt"
20+
"net"
21+
1922
"github.com/datastax/go-cassandra-native-protocol/message"
2023
"github.com/datastax/go-cassandra-native-protocol/primitive"
2124
)
@@ -72,3 +75,39 @@ func (r Row) ByName(n string) (interface{}, error) {
7275
return r.ByPos(i)
7376
}
7477
}
78+
79+
func (r Row) StringByName(n string) (string, error) {
80+
val, err := r.ByName(n)
81+
if err != nil {
82+
return "", err
83+
}
84+
if s, ok := val.(string); !ok {
85+
return "", fmt.Errorf("'%s' is not a string", n)
86+
} else {
87+
return s, nil
88+
}
89+
}
90+
91+
func (r Row) InetByName(n string) (net.IP, error) {
92+
val, err := r.ByName(n)
93+
if err != nil {
94+
return nil, err
95+
}
96+
if ip, ok := val.(net.IP); !ok {
97+
return nil, fmt.Errorf("'%s' is not an inet", n)
98+
} else {
99+
return ip, nil
100+
}
101+
}
102+
103+
func (r Row) UUIDByName(n string) (primitive.UUID, error) {
104+
val, err := r.ByName(n)
105+
if err != nil {
106+
return [16]byte{}, err
107+
}
108+
if u, ok := val.(primitive.UUID); !ok {
109+
return [16]byte{}, fmt.Errorf("'%s' is not a uuid", n)
110+
} else {
111+
return u, nil
112+
}
113+
}

proxycore/session.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ func (s *Session) Send(host *Host, request Request) error {
9999
}
100100

101101
func (s *Session) leastBusyConn(host *Host) *ClientConn {
102-
if p, ok := s.pools.Load(host.Endpoint().Key()); ok {
102+
if p, ok := s.pools.Load(host.Key()); ok {
103103
pool := p.(*connPool)
104104
return pool.leastBusyConn()
105105
}
@@ -118,7 +118,7 @@ func (s *Session) OnEvent(event Event) {
118118
for _, host := range evt.Hosts {
119119
go func(host *Host) {
120120
pool, err := connectPool(s.ctx, connPoolConfig{
121-
Endpoint: host.Endpoint(),
121+
Endpoint: host.Endpoint,
122122
SessionConfig: s.config,
123123
})
124124
if err != nil {
@@ -127,7 +127,7 @@ func (s *Session) OnEvent(event Event) {
127127
default:
128128
}
129129
}
130-
s.pools.Store(host.Endpoint().Key(), pool)
130+
s.pools.Store(host.Key(), pool)
131131
wg.Done()
132132
}(host)
133133
}
@@ -139,7 +139,7 @@ func (s *Session) OnEvent(event Event) {
139139
case *AddEvent:
140140
// There's no compute if absent for sync.Map, figure a better way to do this if the pool already exists.
141141
if pool, loaded := s.pools.LoadOrStore(evt.Host.Key(), connectPoolNoFail(s.ctx, connPoolConfig{
142-
Endpoint: evt.Host.Endpoint(),
142+
Endpoint: evt.Host.Endpoint,
143143
SessionConfig: s.config,
144144
})); loaded {
145145
p := pool.(*connPool)

proxycore/session_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestConnectSession(t *testing.T) {
7171
require.NoError(t, err)
7272

7373
newHost := func(addr string) *Host {
74-
return &Host{endpoint: &defaultEndpoint{addr: addr}}
74+
return &Host{Endpoint: &defaultEndpoint{addr: addr}}
7575
}
7676

7777
var wg sync.WaitGroup

0 commit comments

Comments
 (0)