diff --git a/proxy/proxy_test.go b/proxy/proxy_test.go index 3e7a2ac..0ba0022 100644 --- a/proxy/proxy_test.go +++ b/proxy/proxy_test.go @@ -341,7 +341,7 @@ func setupProxyTestWithConfig(ctx context.Context, numNodes int, cfg *proxyTestC tester.proxy = NewProxy(ctx, Config{ Version: primitive.ProtocolVersion4, - Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterAddr}, clusterPort), + Resolver: proxycore.NewResolverWithDefaultPort([]string{clusterAddr}, clusterPort, nil), ReconnectPolicy: proxycore.NewReconnectPolicyWithDelays(200*time.Millisecond, time.Second), NumConns: 2, HeartBeatInterval: 30 * time.Second, diff --git a/proxy/run.go b/proxy/run.go index e67f58d..01b46c8 100644 --- a/proxy/run.go +++ b/proxy/run.go @@ -17,6 +17,7 @@ package proxy import ( "context" "crypto/tls" + "crypto/x509" "encoding/json" "errors" "fmt" @@ -63,6 +64,9 @@ type runConfig struct { NumConns int `yaml:"num-conns" help:"Number of connection to create to each node of the backend cluster" default:"1" env:"NUM_CONNS"` ProxyCertFile string `yaml:"proxy-cert-file" help:"Path to a PEM encoded certificate file with its intermediate certificate chain. This is used to encrypt traffic for proxy clients" env:"PROXY_CERT_FILE"` ProxyKeyFile string `yaml:"proxy-key-file" help:"Path to a PEM encoded private key file. This is used to encrypt traffic for proxy clients" env:"PROXY_KEY_FILE"` + ClusterCAFile string `yaml:"cluster-ca-file" help:"Path to a PEM encoded file with CA certificates and their intermediate certificate chains. This is used to encrypt traffic between the proxy and the backend cluster" env:"CLUSTER_CA_FILE"` + ClusterCertFile string `yaml:"cluster-cert-file" help:"Path to a PEM encoded client certificate file with its intermediate certificate chain. This is used for mutual TLS when connecting to the backend cluster" env:"CLUSTER_CERT_FILE"` + ClusterKeyFile string `yaml:"cluster-key-file" help:"Path to a PEM encoded client private key file. This is used for mutual TLS when connecting to the backend cluster" env:"CLUSTER_KEY_FILE"` RpcAddress string `yaml:"rpc-address" help:"Address to advertise in the 'system.local' table for 'rpc_address'. It must be set if configuring peer proxies" env:"RPC_ADDRESS"` DataCenter string `yaml:"data-center" help:"Data center to use in system tables" env:"DATA_CENTER"` Tokens []string `yaml:"tokens" help:"Tokens to use in the system tables. It's not recommended" env:"TOKENS"` @@ -117,7 +121,32 @@ func Run(ctx context.Context, args []string) int { cfg.Username = "token" cfg.Password = cfg.AstraToken } else if len(cfg.ContactPoints) > 0 { - resolver = proxycore.NewResolverWithDefaultPort(cfg.ContactPoints, cfg.Port) + var tlsConfig *tls.Config + + if len(cfg.ClusterCAFile) > 0 { // Use proxy to cluster TLS + caCert, err := ioutil.ReadFile(cfg.ClusterCAFile) + + if err != nil { + cliCtx.Fatalf("unable to load cluster CA file %s: %v", cfg.ClusterCAFile, err) + } + caCertPool := x509.NewCertPool() + caCertPool.AppendCertsFromPEM(caCert) + + tlsConfig = &tls.Config{ + RootCAs: caCertPool, + InsecureSkipVerify: true, // Skip server name validation + } + + if len(cfg.ClusterCertFile) > 0 || len(cfg.ClusterKeyFile) > 0 { + cert, err := loadCertificate(cfg.ClusterCertFile, cfg.ClusterKeyFile) + if err != nil { + cliCtx.Fatalf("problem loading cluster TLS client certificate pair: %v", err) + } + tlsConfig.Certificates = []tls.Certificate{cert} + } + } + + resolver = proxycore.NewResolverWithDefaultPort(cfg.ContactPoints, cfg.Port, tlsConfig) } else { cliCtx.Errorf("must provide either bundle path, token, or contact points") return 1 @@ -342,14 +371,22 @@ func (c *runConfig) listenAndServe(p *Proxy, mux *http.ServeMux, ctx context.Con return err } -func resolveAndListen(address, cert, key string) (net.Listener, error) { - if len(cert) > 0 || len(key) > 0 { - if len(cert) == 0 || len(key) == 0 { - return nil, errors.New("both certificate and private key are required for TLS") - } - cert, err := tls.LoadX509KeyPair(cert, key) +func loadCertificate(certFile, keyFile string) (tls.Certificate, error) { + if len(certFile) == 0 || len(keyFile) == 0 { + return tls.Certificate{}, errors.New("both certificate and private key are required for TLS") + } + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return tls.Certificate{}, err + } + return cert, nil +} + +func resolveAndListen(address, certFile, keyFile string) (net.Listener, error) { + if len(certFile) > 0 || len(keyFile) > 0 { + cert, err := loadCertificate(certFile, keyFile) if err != nil { - return nil, fmt.Errorf("unable to load TLS certificate pair: %v", err) + return nil, fmt.Errorf("problem loading proxy TLS certificate pair: %v", err) } return tls.Listen("tcp", address, &tls.Config{Certificates: []tls.Certificate{cert}}) } else { diff --git a/proxycore/cluster.go b/proxycore/cluster.go index 79275ba..84fca4a 100644 --- a/proxycore/cluster.go +++ b/proxycore/cluster.go @@ -264,6 +264,12 @@ func (c *Cluster) mergeHosts(hosts []*Host) error { c.logger.Info("adding host to the cluster", zap.Stringer("host", host)) c.sendEvent(&AddEvent{host}) } + + endpoints, err := c.config.Resolver.Resolve(c.ctx) + if err != nil { + return err + } + host.Endpoint = endpoints[len(endpoints)-1] } for _, host := range existing { diff --git a/proxycore/endpoint.go b/proxycore/endpoint.go index aff0e1d..bd36d0d 100644 --- a/proxycore/endpoint.go +++ b/proxycore/endpoint.go @@ -67,6 +67,7 @@ type EndpointResolver interface { type defaultEndpointResolver struct { contactPoints []string defaultPort string + tlsConfig *tls.Config } func NewEndpoint(addr string) Endpoint { @@ -78,13 +79,14 @@ func NewEndpointTLS(addr string, cfg *tls.Config) Endpoint { } func NewResolver(contactPoints ...string) EndpointResolver { - return NewResolverWithDefaultPort(contactPoints, 9042) + return NewResolverWithDefaultPort(contactPoints, 9042, nil) } -func NewResolverWithDefaultPort(contactPoints []string, defaultPort int) EndpointResolver { +func NewResolverWithDefaultPort(contactPoints []string, defaultPort int, tlsConfig *tls.Config) EndpointResolver { return &defaultEndpointResolver{ contactPoints: contactPoints, defaultPort: strconv.Itoa(defaultPort), + tlsConfig: tlsConfig, } } @@ -106,6 +108,7 @@ func (r *defaultEndpointResolver) Resolve(ctx context.Context) ([]Endpoint, erro for _, addr := range addrs { endpoints = append(endpoints, &defaultEndpoint{ addr: net.JoinHostPort(addr, port), + tlsConfig: r.tlsConfig, }) } }