Skip to content

Commit 7ae202e

Browse files
authored
[TA] Target Allocator TLS Unit-tests (#265)
* TLS tests
1 parent 31db083 commit 7ae202e

File tree

3 files changed

+278
-2
lines changed

3 files changed

+278
-2
lines changed

cmd/amazon-cloudwatch-agent-target-allocator/config/config.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,10 @@ import (
1111
"fmt"
1212
"io/fs"
1313
"os"
14-
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
1514
"time"
1615

16+
"sigs.k8s.io/controller-runtime/pkg/certwatcher"
17+
1718
"github.com/go-logr/logr"
1819
"github.com/prometheus/common/model"
1920
promconfig "github.com/prometheus/prometheus/config"

cmd/amazon-cloudwatch-agent-target-allocator/server/server.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ func WithTLSConfig(tlsConfig *tls.Config, httpsListenAddr string) Option {
7373
s.setRouter(httpsRouter)
7474

7575
s.httpsServer = &http.Server{Addr: httpsListenAddr, Handler: httpsRouter, ReadHeaderTimeout: 90 * time.Second, TLSConfig: tlsConfig}
76+
s.server.Shutdown(context.Background())
77+
s.server = s.httpsServer
7678
}
7779
}
7880

cmd/amazon-cloudwatch-agent-target-allocator/server/server_test.go

Lines changed: 274 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,22 @@
44
package server
55

66
import (
7+
"context"
8+
"crypto/ecdsa"
9+
"crypto/elliptic"
10+
"crypto/rand"
711
"crypto/tls"
12+
"crypto/x509"
13+
"crypto/x509/pkix"
814
"encoding/json"
15+
"encoding/pem"
916
"fmt"
1017
"io"
18+
"math/big"
1119
"net/http"
1220
"net/http/httptest"
1321
"net/url"
22+
"os"
1423
"testing"
1524
"time"
1625

@@ -185,7 +194,7 @@ func TestServer_TargetsHandler(t *testing.T) {
185194

186195
func TestServer_ScrapeConfigsHandler(t *testing.T) {
187196
svrConfig := allocatorconfig.HTTPSServerConfig{}
188-
tlsConfig, _ := svrConfig.NewTLSConfig()
197+
tlsConfig, _ := svrConfig.NewTLSConfig(context.TODO())
189198
tests := []struct {
190199
description string
191200
scrapeConfigs map[string]*promconfig.ScrapeConfig
@@ -605,6 +614,7 @@ func TestServer_JobHandler(t *testing.T) {
605614
})
606615
}
607616
}
617+
608618
func TestServer_Readiness(t *testing.T) {
609619
tests := []struct {
610620
description string
@@ -669,6 +679,269 @@ func TestServer_Readiness(t *testing.T) {
669679
}
670680
}
671681

682+
func TestServer_ValidCAonTLS(t *testing.T) {
683+
listenAddr := ":8443"
684+
server, clientTlsConfig, err := createTestTLSServer(listenAddr)
685+
assert.NoError(t, err)
686+
go func() {
687+
assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed)
688+
}()
689+
time.Sleep(100 * time.Millisecond) // wait for server to launch
690+
defer func() {
691+
err := server.ShutdownHTTPS(context.Background())
692+
if err != nil {
693+
assert.NoError(t, err)
694+
}
695+
}()
696+
tests := []struct {
697+
description string
698+
endpoint string
699+
expectedCode int
700+
}{
701+
{
702+
description: "with tls test for scrape config",
703+
endpoint: "scrape_configs",
704+
expectedCode: http.StatusOK,
705+
},
706+
{
707+
description: "with tls test for jobs",
708+
endpoint: "jobs",
709+
expectedCode: http.StatusOK,
710+
},
711+
}
712+
for _, tc := range tests {
713+
t.Run(tc.description, func(t *testing.T) {
714+
// Create a custom HTTP client with TLS transport
715+
client := &http.Client{
716+
Transport: &http.Transport{
717+
TLSClientConfig: clientTlsConfig,
718+
},
719+
}
720+
721+
// Make the GET request
722+
request, err := client.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint))
723+
724+
// Verify if a certificate verification error occurred
725+
require.NoError(t, err)
726+
727+
// Only check the status code if there was no error
728+
if err == nil {
729+
assert.Equal(t, tc.expectedCode, request.StatusCode)
730+
} else {
731+
t.Log(err)
732+
}
733+
})
734+
}
735+
}
736+
737+
func TestServer_MissingCAonTLS(t *testing.T) {
738+
listenAddr := ":8443"
739+
server, _, err := createTestTLSServer(listenAddr)
740+
assert.NoError(t, err)
741+
go func() {
742+
assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed)
743+
}()
744+
time.Sleep(100 * time.Millisecond) // wait for server to launch
745+
defer func() {
746+
err := server.ShutdownHTTPS(context.Background())
747+
if err != nil {
748+
assert.NoError(t, err)
749+
}
750+
}()
751+
tests := []struct {
752+
description string
753+
endpoint string
754+
expectedCode int
755+
}{
756+
{
757+
description: "no tls test for scrape config",
758+
endpoint: "scrape_configs",
759+
expectedCode: http.StatusBadRequest,
760+
},
761+
{
762+
description: "no tls test for jobs",
763+
endpoint: "jobs",
764+
expectedCode: http.StatusBadRequest,
765+
},
766+
}
767+
for _, tc := range tests {
768+
t.Run(tc.description, func(t *testing.T) {
769+
request, err := http.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint))
770+
771+
// Verify if a certificate verification error occurred
772+
require.Error(t, err)
773+
774+
// Only check the status code if there was no error
775+
if err == nil {
776+
assert.Equal(t, tc.expectedCode, request.StatusCode)
777+
}
778+
})
779+
}
780+
}
781+
782+
func TestServer_HTTPOnTLS(t *testing.T) {
783+
listenAddr := ":8443"
784+
server, _, err := createTestTLSServer(listenAddr)
785+
assert.NoError(t, err)
786+
go func() {
787+
assert.NoError(t, server.StartHTTPS())
788+
}()
789+
time.Sleep(100 * time.Millisecond) // wait for server to launch
790+
791+
defer func(s *Server, ctx context.Context) {
792+
err := s.Shutdown(ctx)
793+
if err != nil {
794+
assert.NoError(t, err)
795+
}
796+
}(server, context.Background())
797+
tests := []struct {
798+
description string
799+
endpoint string
800+
expectedCode int
801+
}{
802+
{
803+
description: "no tls test for scrape config",
804+
endpoint: "scrape_configs",
805+
expectedCode: http.StatusBadRequest,
806+
},
807+
{
808+
description: "no tls test for jobs",
809+
endpoint: "jobs",
810+
expectedCode: http.StatusBadRequest,
811+
},
812+
}
813+
for _, tc := range tests {
814+
t.Run(tc.description, func(t *testing.T) {
815+
request, err := http.Get(fmt.Sprintf("http://localhost%s/%s", listenAddr, tc.endpoint))
816+
817+
// Only check the status code if there was no error
818+
if err == nil {
819+
assert.Equal(t, tc.expectedCode, request.StatusCode)
820+
}
821+
})
822+
}
823+
}
824+
825+
func createTestTLSServer(listenAddr string) (*Server, *tls.Config, error) {
826+
//testing using this function replicates customer environment
827+
svrConfig := allocatorconfig.HTTPSServerConfig{}
828+
caBundle, caCert, caKey, err := generateTestingCerts()
829+
if err != nil {
830+
return nil, nil, err
831+
}
832+
svrConfig.TLSKeyFilePath = caKey
833+
svrConfig.TLSCertFilePath = caCert
834+
tlsConfig, err := svrConfig.NewTLSConfig(context.TODO())
835+
if err != nil {
836+
return nil, nil, err
837+
}
838+
httpOptions := []Option{}
839+
httpOptions = append(httpOptions, WithTLSConfig(tlsConfig, listenAddr))
840+
841+
//generate ca bundle
842+
bundle, err := readCABundle(caBundle)
843+
if err != nil {
844+
return nil, nil, err
845+
}
846+
allocator := &mockAllocator{targetItems: map[string]*target.Item{
847+
"a": target.NewItem("job1", "", model.LabelSet{}, ""),
848+
}}
849+
850+
return NewServer(logger, allocator, listenAddr, httpOptions...), bundle, nil
851+
}
852+
672853
func newLink(jobName string) target.LinkJSON {
673854
return target.LinkJSON{Link: fmt.Sprintf("/jobs/%s/targets", url.QueryEscape(jobName))}
674855
}
856+
857+
func readCABundle(caBundlePath string) (*tls.Config, error) {
858+
// Load the CA bundle
859+
caCert, err := os.ReadFile(caBundlePath)
860+
if err != nil {
861+
return nil, fmt.Errorf("failed to read CA bundle: %w", err)
862+
}
863+
864+
// Create a CA pool and add the CA certificate(s)
865+
caCertPool := x509.NewCertPool()
866+
if !caCertPool.AppendCertsFromPEM(caCert) {
867+
return nil, fmt.Errorf("failed to add CA certificates to pool")
868+
}
869+
870+
// Set up TLS configuration with the CA pool
871+
tlsConfig := &tls.Config{
872+
RootCAs: caCertPool,
873+
}
874+
return tlsConfig, nil
875+
}
876+
877+
func generateTestingCerts() (caBundlePath, caCertPath, caKeyPath string, err error) {
878+
// Generate private key
879+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
880+
if err != nil {
881+
return "", "", "", fmt.Errorf("error generating private key: %w", err)
882+
}
883+
884+
// Set up certificate template
885+
template := x509.Certificate{
886+
SerialNumber: big.NewInt(1),
887+
Subject: pkix.Name{
888+
CommonName: "localhost",
889+
},
890+
NotBefore: time.Now(),
891+
NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year validity
892+
KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature,
893+
ExtKeyUsage: []x509.ExtKeyUsage{
894+
x509.ExtKeyUsageServerAuth,
895+
},
896+
DNSNames: []string{"localhost"},
897+
}
898+
899+
// Self-sign the certificate
900+
certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey)
901+
if err != nil {
902+
return "", "", "", fmt.Errorf("error creating certificate: %w", err)
903+
}
904+
905+
// Create temporary files
906+
tempDir := os.TempDir()
907+
908+
caCertFile, err := os.CreateTemp(tempDir, "ca-cert-*.crt")
909+
if err != nil {
910+
return "", "", "", fmt.Errorf("error creating temp CA cert file: %w", err)
911+
}
912+
defer caCertFile.Close()
913+
914+
caKeyFile, err := os.CreateTemp(tempDir, "ca-key-*.key")
915+
if err != nil {
916+
return "", "", "", fmt.Errorf("error creating temp CA key file: %w", err)
917+
}
918+
defer caKeyFile.Close()
919+
920+
caBundleFile, err := os.CreateTemp(tempDir, "ca-bundle-*.crt")
921+
if err != nil {
922+
return "", "", "", fmt.Errorf("error creating temp CA bundle file: %w", err)
923+
}
924+
defer caBundleFile.Close()
925+
926+
// Write the private key to the key file
927+
privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey)
928+
if err != nil {
929+
return "", "", "", fmt.Errorf("error writing private key: %w", err)
930+
}
931+
err = pem.Encode(caKeyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privateKeyBytes})
932+
if err != nil {
933+
return "", "", "", fmt.Errorf("error writing private key: %w", err)
934+
}
935+
936+
// Write the certificate to the certificate and bundle files
937+
certPEM := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes}
938+
if err = pem.Encode(caCertFile, certPEM); err != nil {
939+
return "", "", "", fmt.Errorf("error writing certificate: %w", err)
940+
}
941+
if err = pem.Encode(caBundleFile, certPEM); err != nil {
942+
return "", "", "", fmt.Errorf("error writing bundle certificate: %w", err)
943+
}
944+
945+
// Return the file paths
946+
return caBundleFile.Name(), caCertFile.Name(), caKeyFile.Name(), nil
947+
}

0 commit comments

Comments
 (0)