Skip to content

Commit bc72398

Browse files
authored
Fix Astra timeouts (#96)
This adds Context and timeouts to several paths where calls could block when contacting Astra services. The timeout can be changed by providing the --astra-timeout flag with a duration.
1 parent 602acda commit bc72398

File tree

8 files changed

+62
-49
lines changed

8 files changed

+62
-49
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ Flags:
4545
or contact points option ($ASTRA_TOKEN).
4646
-i, --astra-database-id=STRING Database ID of the Astra database. Requires '--astra-token' ($ASTRA_DATABASE_ID)
4747
--astra-api-url="https://api.astra.datastax.com" URL for the Astra API ($ASTRA_API_URL)
48+
--astra-timeout=10s Timeout for contacting Astra when retrieving the bundle and metadata ($ASTRA_TIMEOUT)
4849
-c, --contact-points=CONTACT-POINTS,... Contact points for cluster. Ignored if using the bundle path or token option ($CONTACT_POINTS).
4950
-u, --username=STRING Username to use for authentication ($USERNAME)
5051
-p, --password=STRING Password to use for authentication ($PASSWORD)

astra/bundle.go

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,9 @@ import (
3333
)
3434

3535
type Bundle struct {
36-
tlsConfig *tls.Config
37-
host string
38-
port int
36+
TLSConfig *tls.Config
37+
Host string
38+
Port int
3939
}
4040

4141
func LoadBundleZip(reader *zip.Reader) (*Bundle, error) {
@@ -69,13 +69,13 @@ func LoadBundleZip(reader *zip.Reader) (*Bundle, error) {
6969
}
7070

7171
return &Bundle{
72-
tlsConfig: &tls.Config{
72+
TLSConfig: &tls.Config{
7373
RootCAs: rootCAs,
7474
Certificates: []tls.Certificate{cert},
7575
ServerName: config.Host,
7676
},
77-
host: config.Host,
78-
port: config.Port,
77+
Host: config.Host,
78+
Port: config.Port,
7979
}, nil
8080
}
8181

@@ -93,7 +93,7 @@ func LoadBundleZipFromPath(path string) (*Bundle, error) {
9393
}
9494

9595
func LoadBundleZipFromURL(url, databaseID, token string, timeout time.Duration) (*Bundle, error) {
96-
ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(timeout))
96+
ctx, cancel := context.WithTimeout(context.Background(), timeout)
9797
defer cancel()
9898

9999
credsURL, err := generateSecureBundleURLWithResponse(url, databaseID, token, ctx)
@@ -154,18 +154,6 @@ func generateSecureBundleURLWithResponse(url, databaseID, token string, ctx cont
154154
return res.JSON200, nil
155155
}
156156

157-
func (b *Bundle) Host() string {
158-
return b.host
159-
}
160-
161-
func (b *Bundle) Port() int {
162-
return b.port
163-
}
164-
165-
func (b *Bundle) TLSConfig() *tls.Config {
166-
return b.tlsConfig.Clone()
167-
}
168-
169157
func extract(reader *zip.Reader) (map[string][]byte, error) {
170158
contents := make(map[string][]byte)
171159

astra/bundle_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,22 +40,22 @@ func TestLoadBundleZip(t *testing.T) {
4040
b, err := LoadBundleZipFromPath(path)
4141
require.NoError(t, err)
4242

43-
assert.Equal(t, hostname, b.Host())
44-
assert.Equal(t, port, b.Port())
43+
assert.Equal(t, hostname, b.Host)
44+
assert.Equal(t, port, b.Port)
4545

4646
block, _ := pem.Decode(testCAPEM)
4747
ca, _ := x509.ParseCertificate(block.Bytes)
4848

4949
// Verify CA added to cert pool
5050
caSub, err := asn1.Marshal(ca.Subject.ToRDNSequence())
5151
found := false
52-
for _, sub := range b.TLSConfig().RootCAs.Subjects() {
52+
for _, sub := range b.TLSConfig.RootCAs.Subjects() {
5353
if bytes.Compare(caSub, sub) == 0 {
5454
found = true
5555
}
5656
}
5757
assert.True(t, found)
58-
require.Equal(t, 1, len(b.TLSConfig().Certificates))
58+
require.Equal(t, 1, len(b.TLSConfig.Certificates))
5959
}
6060

6161
func TestLoadBundleZip_InvalidJson(t *testing.T) {

astra/endpoint.go

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,12 @@
1515
package astra
1616

1717
import (
18+
"context"
1819
"crypto/tls"
1920
"crypto/x509"
2021
"encoding/json"
2122
"errors"
2223
"fmt"
23-
"io/ioutil"
2424
"net/http"
2525
"sync"
2626
"time"
@@ -33,6 +33,7 @@ type astraResolver struct {
3333
region string
3434
bundle *Bundle
3535
mu *sync.Mutex
36+
timeout time.Duration
3637
}
3738

3839
type astraEndpoint struct {
@@ -41,28 +42,38 @@ type astraEndpoint struct {
4142
tlsConfig *tls.Config
4243
}
4344

44-
func NewResolver(bundle *Bundle) proxycore.EndpointResolver {
45+
func NewResolver(bundle *Bundle, timeout time.Duration) proxycore.EndpointResolver {
4546
return &astraResolver{
46-
bundle: bundle,
47-
mu: &sync.Mutex{},
47+
bundle: bundle,
48+
mu: &sync.Mutex{},
49+
timeout: timeout,
4850
}
4951
}
5052

51-
func (r *astraResolver) Resolve() ([]proxycore.Endpoint, error) {
53+
func (r *astraResolver) Resolve(ctx context.Context) ([]proxycore.Endpoint, error) {
5254
var metadata *astraMetadata
5355

54-
url := fmt.Sprintf("https://%s:%d/metadata", r.bundle.Host(), r.bundle.Port())
56+
ctx, cancel := context.WithTimeout(ctx, r.timeout)
57+
defer cancel()
58+
5559
httpsClient := &http.Client{
5660
Transport: &http.Transport{
57-
TLSClientConfig: r.bundle.TLSConfig(),
61+
TLSClientConfig: r.bundle.TLSConfig.Clone(),
5862
},
5963
}
60-
response, err := httpsClient.Get(url)
64+
65+
url := fmt.Sprintf("https://%s:%d/metadata", r.bundle.Host, r.bundle.Port)
66+
req, err := http.NewRequestWithContext(ctx, "GET", url, http.NoBody)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
response, err := httpsClient.Do(req)
6172
if err != nil {
62-
return nil, fmt.Errorf("unable to get metadata from %s: %v", url, err)
73+
return nil, fmt.Errorf("unable to get metadata from %s: %w", url, err)
6374
}
6475

65-
body, err := ioutil.ReadAll(response.Body)
76+
body, err := readAllWithTimeout(response.Body, ctx)
6677
if err != nil {
6778
return nil, err
6879
}
@@ -145,7 +156,7 @@ func (a astraEndpoint) TLSConfig() *tls.Config {
145156
}
146157

147158
func copyTLSConfig(bundle *Bundle, serverName string) *tls.Config {
148-
tlsConfig := bundle.TLSConfig()
159+
tlsConfig := bundle.TLSConfig.Clone()
149160
tlsConfig.ServerName = serverName
150161
tlsConfig.InsecureSkipVerify = true
151162
tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
@@ -161,7 +172,7 @@ func copyTLSConfig(bundle *Bundle, serverName string) *tls.Config {
161172
opts := x509.VerifyOptions{
162173
Roots: tlsConfig.RootCAs,
163174
CurrentTime: time.Now(),
164-
DNSName: bundle.Host(),
175+
DNSName: bundle.Host,
165176
Intermediates: x509.NewCertPool(),
166177
}
167178
for _, cert := range certs[1:] {

astra/endpoint_test.go

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@
1515
package astra
1616

1717
import (
18+
"context"
1819
"crypto/tls"
1920
"encoding/json"
2021
"errors"
2122
"net"
2223
"net/http"
2324
"os"
2425
"testing"
26+
"time"
2527

2628
"github.com/datastax/cql-proxy/proxycore"
2729
"github.com/datastax/go-cassandra-native-protocol/datatype"
@@ -51,7 +53,7 @@ func TestMain(m *testing.M) {
5153

5254
func TestAstraResolver_Resolve(t *testing.T) {
5355
resolver := createResolver(t)
54-
endpoints, err := resolver.Resolve()
56+
endpoints, err := resolver.Resolve(context.Background())
5557
require.NoError(t, err)
5658

5759
for _, endpoint := range endpoints {
@@ -63,7 +65,7 @@ func TestAstraResolver_Resolve(t *testing.T) {
6365

6466
func TestAstraResolver_NewEndpoint(t *testing.T) {
6567
resolver := createResolver(t)
66-
_, err := resolver.Resolve()
68+
_, err := resolver.Resolve(context.Background())
6769
require.NoError(t, err)
6870

6971
const hostId = "a2e24181-d732-402a-ab06-894a8b2f6094"
@@ -101,7 +103,7 @@ func TestAstraResolver_NewEndpoint(t *testing.T) {
101103

102104
func TestAstraResolver_NewEndpoint_Ignored(t *testing.T) {
103105
resolver := createResolver(t)
104-
_, err := resolver.Resolve()
106+
_, err := resolver.Resolve(context.Background())
105107
require.NoError(t, err)
106108

107109
const hostId = "a2e24181-d732-402a-ab06-894a8b2f6094"
@@ -138,7 +140,7 @@ func TestAstraResolver_NewEndpoint_Ignored(t *testing.T) {
138140

139141
func TestAstraResolver_NewEndpointInvalidHostID(t *testing.T) {
140142
resolver := createResolver(t)
141-
_, err := resolver.Resolve()
143+
_, err := resolver.Resolve(context.Background())
142144
require.NoError(t, err)
143145

144146
rs := proxycore.NewResultSet(&message.RowsResult{
@@ -164,14 +166,23 @@ func TestAstraResolver_NewEndpointInvalidHostID(t *testing.T) {
164166
assert.Error(t, err, "ignoring host because its `host_id` is not set or is invalid")
165167
}
166168

169+
func TestAstraResolver_Timeout(t *testing.T) {
170+
ctx, cancel := context.WithTimeout(context.Background(), 1) // Very short timeout
171+
defer cancel()
172+
173+
resolver := createResolver(t)
174+
_, err := resolver.Resolve(ctx)
175+
assert.ErrorIs(t, err, context.DeadlineExceeded) // Expect a timeout
176+
}
177+
167178
func createResolver(t *testing.T) proxycore.EndpointResolver {
168179
path, err := writeBundle("127.0.0.1", 8080)
169180
require.NoError(t, err)
170181

171182
bundle, err := LoadBundleZipFromPath(path)
172183
require.NoError(t, err)
173184

174-
return NewResolver(bundle)
185+
return NewResolver(bundle, 10*time.Second)
175186
}
176187

177188
func runTestMetaSvcAsync(sniProxyAddr string, contactPoints []string) (*http.Server, error) {

proxy/run.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ type runConfig struct {
4444
AstraToken string `yaml:"astra-token" help:"Token used to authenticate to an Astra database. Requires '--astra-database-id'. Ignored if using the bundle path or contact points option." short:"t" env:"ASTRA_TOKEN"`
4545
AstraDatabaseID string `yaml:"astra-database-id" help:"Database ID of the Astra database. Requires '--astra-token'" short:"i" env:"ASTRA_DATABASE_ID"`
4646
AstraApiURL string `yaml:"astra-api-url" help:"URL for the Astra API" default:"https://api.astra.datastax.com" env:"ASTRA_API_URL"`
47+
AstraTimeout time.Duration `yaml:"astra-timeout" help:"Timeout for contacting Astra when retrieving the bundle and metadata" default:"10s" env:"ASTRA_TIMEOUT"`
4748
ContactPoints []string `yaml:"contact-points" help:"Contact points for cluster. Ignored if using the bundle path or token option." short:"c" env:"CONTACT_POINTS"`
4849
Username string `yaml:"username" help:"Username to use for authentication" short:"u" env:"USERNAME"`
4950
Password string `yaml:"password" help:"Password to use for authentication" short:"p" env:"PASSWORD"`
@@ -101,19 +102,18 @@ func Run(ctx context.Context, args []string) int {
101102
if len(cfg.AstraBundle) > 0 {
102103
if bundle, err := astra.LoadBundleZipFromPath(cfg.AstraBundle); err != nil {
103104
cliCtx.Errorf("unable to open bundle %s from file: %v", cfg.AstraBundle, err)
104-
return 1
105105
} else {
106-
resolver = astra.NewResolver(bundle)
106+
resolver = astra.NewResolver(bundle, cfg.AstraTimeout)
107107
}
108108
} else if len(cfg.AstraToken) > 0 {
109109
if len(cfg.AstraDatabaseID) == 0 {
110110
cliCtx.Fatalf("database ID is required when using a token")
111111
}
112-
bundle, err := astra.LoadBundleZipFromURL(cfg.AstraApiURL, cfg.AstraDatabaseID, cfg.AstraToken, 10*time.Second)
113-
if err != nil {
112+
if bundle, err := astra.LoadBundleZipFromURL(cfg.AstraApiURL, cfg.AstraDatabaseID, cfg.AstraToken, cfg.AstraTimeout); err != nil {
114113
cliCtx.Fatalf("unable to load bundle %s from astra: %v", cfg.AstraBundle, err)
114+
} else {
115+
resolver = astra.NewResolver(bundle, cfg.AstraTimeout)
115116
}
116-
resolver = astra.NewResolver(bundle)
117117
cfg.Username = "token"
118118
cfg.Password = cfg.AstraToken
119119
} else if len(cfg.ContactPoints) > 0 {

proxycore/cluster.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func ConnectCluster(ctx context.Context, config ClusterConfig) (*Cluster, error)
150150
listeners: make([]ClusterListener, 0),
151151
}
152152

153-
endpoints, err := config.Resolver.Resolve()
153+
endpoints, err := config.Resolver.Resolve(ctx)
154154
if err != nil {
155155
return nil, err
156156
}

proxycore/endpoint.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
package proxycore
1616

1717
import (
18+
"context"
1819
"crypto/tls"
1920
"errors"
2021
"fmt"
@@ -59,7 +60,7 @@ func (e defaultEndpoint) TLSConfig() *tls.Config {
5960
}
6061

6162
type EndpointResolver interface {
62-
Resolve() ([]Endpoint, error)
63+
Resolve(ctx context.Context) ([]Endpoint, error)
6364
NewEndpoint(row Row) (Endpoint, error)
6465
}
6566

@@ -87,14 +88,15 @@ func NewResolverWithDefaultPort(contactPoints []string, defaultPort int) Endpoin
8788
}
8889
}
8990

90-
func (r *defaultEndpointResolver) Resolve() ([]Endpoint, error) {
91+
func (r *defaultEndpointResolver) Resolve(ctx context.Context) ([]Endpoint, error) {
9192
var endpoints []Endpoint
93+
var resolver net.Resolver
9294
for _, cp := range r.contactPoints {
9395
host, port, err := net.SplitHostPort(cp)
9496
if err != nil {
9597
host = cp
9698
}
97-
addrs, err := net.LookupHost(host)
99+
addrs, err := resolver.LookupHost(ctx, host)
98100
if err != nil {
99101
return nil, fmt.Errorf("unable to resolve contact point %s: %v", cp, err)
100102
}

0 commit comments

Comments
 (0)