Skip to content

Commit b93c1a6

Browse files
authored
Implement mTLS resources and configuration for Target Allocator server (#284)
1 parent f946810 commit b93c1a6

File tree

11 files changed

+627
-87
lines changed

11 files changed

+627
-87
lines changed
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package config
5+
6+
import (
7+
"context"
8+
"crypto/tls"
9+
"crypto/x509"
10+
"fmt"
11+
"os"
12+
"sync"
13+
"time"
14+
15+
"github.com/fsnotify/fsnotify"
16+
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
17+
)
18+
19+
type CertAndCAWatcher struct {
20+
certWatcher *certwatcher.CertWatcher
21+
22+
caFilePath string
23+
caPool *x509.CertPool
24+
caWatcher *fsnotify.Watcher
25+
26+
mu sync.RWMutex
27+
}
28+
29+
func NewCertAndCAWatcher(certPath, keyPath, caPath string) (*CertAndCAWatcher, error) {
30+
certWatcher, err := certwatcher.New(certPath, keyPath)
31+
if err != nil {
32+
return nil, fmt.Errorf("error creating cert watcher: %w", err)
33+
}
34+
35+
caPool, err := loadCAPool(caPath)
36+
if err != nil {
37+
return nil, fmt.Errorf("error loading CA pool: %w", err)
38+
}
39+
40+
caWatcher, err := fsnotify.NewWatcher()
41+
if err != nil {
42+
return nil, fmt.Errorf("error creating CA file watcher: %w", err)
43+
}
44+
if err := caWatcher.Add(caPath); err != nil {
45+
return nil, fmt.Errorf("error adding CA file to watcher: %w", err)
46+
}
47+
48+
return &CertAndCAWatcher{
49+
certWatcher: certWatcher,
50+
caFilePath: caPath,
51+
caPool: caPool,
52+
caWatcher: caWatcher,
53+
}, nil
54+
}
55+
56+
func loadCAPool(caPath string) (*x509.CertPool, error) {
57+
caCert, err := os.ReadFile(caPath)
58+
caCertPool := x509.NewCertPool()
59+
if err != nil {
60+
return nil, fmt.Errorf("error reading CA file: %w", err)
61+
}
62+
caCertPool.AppendCertsFromPEM(caCert)
63+
return caCertPool, nil
64+
}
65+
66+
func (w *CertAndCAWatcher) Start(ctx context.Context) error {
67+
go func() {
68+
_ = w.certWatcher.Start(ctx)
69+
}()
70+
71+
go w.watchCA(ctx)
72+
73+
<-ctx.Done()
74+
return nil
75+
}
76+
77+
func (w *CertAndCAWatcher) watchCA(ctx context.Context) {
78+
for {
79+
select {
80+
case event, ok := <-w.caWatcher.Events:
81+
if !ok {
82+
return
83+
}
84+
if event.Op.Has(fsnotify.Write) || event.Op.Has(fsnotify.Create) || event.Op.Has(fsnotify.Remove) {
85+
newPool, err := loadCAPool(w.caFilePath)
86+
if err != nil {
87+
continue
88+
}
89+
w.mu.Lock()
90+
w.caPool = newPool
91+
w.mu.Unlock()
92+
93+
// needed incase file removed
94+
if event.Op.Has(fsnotify.Remove) {
95+
time.Sleep(100 * time.Millisecond)
96+
_ = w.caWatcher.Add(w.caFilePath)
97+
}
98+
}
99+
case <-ctx.Done():
100+
return
101+
}
102+
}
103+
}
104+
105+
func (w *CertAndCAWatcher) GetCertificate(clientHello *tls.ClientHelloInfo) (*tls.Certificate, error) {
106+
return w.certWatcher.GetCertificate(clientHello)
107+
}
108+
109+
func (w *CertAndCAWatcher) GetCAPool() *x509.CertPool {
110+
w.mu.RLock()
111+
defer w.mu.RUnlock()
112+
return w.caPool
113+
}
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
4+
package config
5+
6+
import (
7+
"bytes"
8+
"context"
9+
"crypto/rand"
10+
"crypto/rsa"
11+
"crypto/x509"
12+
"crypto/x509/pkix"
13+
"encoding/pem"
14+
"math/big"
15+
"os"
16+
"path/filepath"
17+
"testing"
18+
"time"
19+
)
20+
21+
func generateSelfSignedCertAndKey(commonName string) (certPEM, keyPEM []byte, err error) {
22+
// Generate RSA key
23+
priv, err := rsa.GenerateKey(rand.Reader, 2048)
24+
if err != nil {
25+
return nil, nil, err
26+
}
27+
28+
// Create a minimal self-signed certificate template
29+
serial, err := rand.Int(rand.Reader, big.NewInt(1<<63-1))
30+
if err != nil {
31+
return nil, nil, err
32+
}
33+
34+
template := &x509.Certificate{
35+
SerialNumber: serial,
36+
Subject: pkix.Name{
37+
CommonName: commonName,
38+
},
39+
NotBefore: time.Now().Add(-time.Hour),
40+
NotAfter: time.Now().Add(time.Hour),
41+
42+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment | x509.KeyUsageCertSign,
43+
IsCA: true,
44+
BasicConstraintsValid: true,
45+
}
46+
47+
// Self-sign the certificate
48+
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
49+
if err != nil {
50+
return nil, nil, err
51+
}
52+
53+
// Encode cert + key to PEM
54+
var certBuf, keyBuf bytes.Buffer
55+
err = pem.Encode(&certBuf, &pem.Block{Type: "CERTIFICATE", Bytes: der})
56+
if err != nil {
57+
return nil, nil, err
58+
}
59+
err = pem.Encode(&keyBuf, &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
60+
if err != nil {
61+
return nil, nil, err
62+
}
63+
64+
return certBuf.Bytes(), keyBuf.Bytes(), nil
65+
}
66+
67+
func TestCertAndCAWatcher_UpdatesCA(t *testing.T) {
68+
t.Parallel()
69+
70+
// Generate a server cert/key for certwatcher
71+
certPEM, keyPEM, err := generateSelfSignedCertAndKey("test-server")
72+
if err != nil {
73+
t.Fatalf("failed to generate server cert/key: %v", err)
74+
}
75+
76+
// Generate two distinct self-signed certs to represent old CA vs new CA
77+
oldCAPEM, _, err := generateSelfSignedCertAndKey("old-ca")
78+
if err != nil {
79+
t.Fatalf("failed to generate old CA: %v", err)
80+
}
81+
newCAPEM, _, err := generateSelfSignedCertAndKey("new-ca")
82+
if err != nil {
83+
t.Fatalf("failed to generate new CA: %v", err)
84+
}
85+
86+
// Write all these PEM files into a temp dir
87+
tmpDir := t.TempDir()
88+
89+
certPath := filepath.Join(tmpDir, "tls.crt")
90+
keyPath := filepath.Join(tmpDir, "tls.key")
91+
caPath := filepath.Join(tmpDir, "ca.crt")
92+
93+
if err := os.WriteFile(certPath, certPEM, 0600); err != nil {
94+
t.Fatalf("failed to write cert file: %v", err)
95+
}
96+
if err := os.WriteFile(keyPath, keyPEM, 0600); err != nil {
97+
t.Fatalf("failed to write key file: %v", err)
98+
}
99+
if err := os.WriteFile(caPath, oldCAPEM, 0600); err != nil {
100+
t.Fatalf("failed to write initial CA file: %v", err)
101+
}
102+
103+
// Create the CertAndCAWatcher using our files
104+
watcher, err := NewCertAndCAWatcher(certPath, keyPath, caPath)
105+
if err != nil {
106+
t.Fatalf("failed to create CertAndCAWatcher: %v", err)
107+
}
108+
109+
// Start the watcher in the background
110+
ctx, cancel := context.WithCancel(context.Background())
111+
defer cancel()
112+
go func() {
113+
_ = watcher.Start(ctx)
114+
}()
115+
116+
// Record the initial CA pool pointer
117+
oldPool := watcher.GetCAPool()
118+
if oldPool == nil {
119+
t.Fatal("expected non-nil initial CA pool")
120+
}
121+
122+
// Overwrite the CA file with newCAPEM, triggering a reload
123+
if err := os.WriteFile(caPath, newCAPEM, 0600); err != nil {
124+
t.Fatalf("failed to write new CA file: %v", err)
125+
}
126+
127+
// Loop until the watcher updates the CA pool (or times out)
128+
deadline := time.Now().Add(2 * time.Second)
129+
for {
130+
newPool := watcher.GetCAPool()
131+
if newPool != oldPool {
132+
t.Log("CA pool successfully updated.")
133+
return
134+
}
135+
if time.Now().After(deadline) {
136+
t.Fatal("timed out waiting for CA pool to be updated")
137+
}
138+
time.Sleep(100 * time.Millisecond)
139+
}
140+
}

cmd/amazon-cloudwatch-agent-target-allocator/config/config.go

Lines changed: 24 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ package config
66
import (
77
"context"
88
"crypto/tls"
9-
"crypto/x509"
109
"errors"
1110
"fmt"
1211
"io/fs"
@@ -24,23 +23,22 @@ import (
2423
"k8s.io/client-go/tools/clientcmd"
2524
"k8s.io/klog/v2"
2625
ctrl "sigs.k8s.io/controller-runtime"
27-
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
2826
"sigs.k8s.io/controller-runtime/pkg/log/zap"
2927

3028
tamanifest "github.com/aws/amazon-cloudwatch-agent-operator/internal/manifests/targetallocator"
3129
)
3230

3331
const (
34-
DefaultResyncTime = 5 * time.Minute
35-
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
36-
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
37-
DefaultAllocationStrategy = "consistent-hashing"
38-
DefaultFilterStrategy = "relabel-config"
39-
DefaultListenAddr = ":8443"
40-
DefaultCertMountPath = tamanifest.TACertMountPath
41-
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
42-
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
43-
DefaultCABundlePath = ""
32+
DefaultResyncTime = 5 * time.Minute
33+
DefaultConfigFilePath string = "/conf/targetallocator.yaml"
34+
DefaultCRScrapeInterval model.Duration = model.Duration(time.Second * 30)
35+
DefaultAllocationStrategy = "consistent-hashing"
36+
DefaultListenAddr = ":8443"
37+
DefaultCertMountPath = tamanifest.TACertMountPath
38+
DefaultClientCertMountPath = tamanifest.ClientCertMountPath
39+
DefaultTLSKeyPath = DefaultCertMountPath + "/server.key"
40+
DefaultTLSCertPath = DefaultCertMountPath + "/server.crt"
41+
DefaultCABundlePath = DefaultClientCertMountPath + "/tls-ca.crt"
4442
)
4543

4644
type Config struct {
@@ -150,7 +148,6 @@ func LoadFromCLI(target *Config, flagSet *pflag.FlagSet) error {
150148
}
151149

152150
func unmarshal(cfg *Config, configFile string) error {
153-
154151
yamlFile, err := os.ReadFile(configFile)
155152
if err != nil {
156153
return err
@@ -217,31 +214,29 @@ func ValidateConfig(config *Config) error {
217214
}
218215

219216
func (c HTTPSServerConfig) NewTLSConfig(ctx context.Context) (*tls.Config, error) {
220-
tlsConfig := &tls.Config{
221-
MinVersion: tls.VersionTLS13,
222-
}
223-
224-
certWatcher, err := certwatcher.New(c.TLSCertFilePath, c.TLSKeyFilePath)
217+
certWatcher, err := NewCertAndCAWatcher(c.TLSCertFilePath, c.TLSKeyFilePath, c.CAFilePath)
225218
if err != nil {
226-
return nil, err
219+
return nil, fmt.Errorf("error creating certwatcher: %w", err)
227220
}
228-
tlsConfig.GetCertificate = certWatcher.GetCertificate
221+
229222
go func() {
230223
_ = certWatcher.Start(ctx)
231224
}()
232225

233-
if c.CAFilePath == "" {
234-
return tlsConfig, nil
226+
// Create the TLS config
227+
tlsConfig := &tls.Config{
228+
MinVersion: tls.VersionTLS13,
229+
GetCertificate: certWatcher.GetCertificate,
230+
ClientCAs: certWatcher.GetCAPool(),
231+
ClientAuth: tls.RequireAndVerifyClientCert,
235232
}
236233

237-
caCert, err := os.ReadFile(c.CAFilePath)
238-
caCertPool := x509.NewCertPool()
239-
if err != nil {
240-
return nil, err
234+
// Dynamically update the CA pool if needed
235+
tlsConfig.GetConfigForClient = func(clientHello *tls.ClientHelloInfo) (*tls.Config, error) {
236+
newTLSConfig := tlsConfig.Clone()
237+
newTLSConfig.ClientCAs = certWatcher.GetCAPool()
238+
return newTLSConfig, nil
241239
}
242-
caCertPool.AppendCertsFromPEM(caCert)
243-
tlsConfig.ClientCAs = caCertPool
244-
tlsConfig.ClientAuth = tls.RequireAndVerifyClientCert
245240

246241
return tlsConfig, nil
247242
}

0 commit comments

Comments
 (0)