|
4 | 4 | package server |
5 | 5 |
|
6 | 6 | import ( |
| 7 | + "context" |
| 8 | + "crypto/ecdsa" |
| 9 | + "crypto/elliptic" |
| 10 | + "crypto/rand" |
7 | 11 | "crypto/tls" |
| 12 | + "crypto/x509" |
| 13 | + "crypto/x509/pkix" |
8 | 14 | "encoding/json" |
| 15 | + "encoding/pem" |
9 | 16 | "fmt" |
10 | 17 | "io" |
| 18 | + "math/big" |
11 | 19 | "net/http" |
12 | 20 | "net/http/httptest" |
13 | 21 | "net/url" |
| 22 | + "os" |
14 | 23 | "testing" |
15 | 24 | "time" |
16 | 25 |
|
@@ -185,7 +194,7 @@ func TestServer_TargetsHandler(t *testing.T) { |
185 | 194 |
|
186 | 195 | func TestServer_ScrapeConfigsHandler(t *testing.T) { |
187 | 196 | svrConfig := allocatorconfig.HTTPSServerConfig{} |
188 | | - tlsConfig, _ := svrConfig.NewTLSConfig() |
| 197 | + tlsConfig, _ := svrConfig.NewTLSConfig(context.TODO()) |
189 | 198 | tests := []struct { |
190 | 199 | description string |
191 | 200 | scrapeConfigs map[string]*promconfig.ScrapeConfig |
@@ -605,6 +614,7 @@ func TestServer_JobHandler(t *testing.T) { |
605 | 614 | }) |
606 | 615 | } |
607 | 616 | } |
| 617 | + |
608 | 618 | func TestServer_Readiness(t *testing.T) { |
609 | 619 | tests := []struct { |
610 | 620 | description string |
@@ -669,6 +679,269 @@ func TestServer_Readiness(t *testing.T) { |
669 | 679 | } |
670 | 680 | } |
671 | 681 |
|
| 682 | +func TestServer_ValidCAonTLS(t *testing.T) { |
| 683 | + listenAddr := ":8443" |
| 684 | + server, clientTlsConfig, err := createTestTLSServer(listenAddr) |
| 685 | + assert.NoError(t, err) |
| 686 | + go func() { |
| 687 | + assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed) |
| 688 | + }() |
| 689 | + time.Sleep(100 * time.Millisecond) // wait for server to launch |
| 690 | + defer func() { |
| 691 | + err := server.ShutdownHTTPS(context.Background()) |
| 692 | + if err != nil { |
| 693 | + assert.NoError(t, err) |
| 694 | + } |
| 695 | + }() |
| 696 | + tests := []struct { |
| 697 | + description string |
| 698 | + endpoint string |
| 699 | + expectedCode int |
| 700 | + }{ |
| 701 | + { |
| 702 | + description: "with tls test for scrape config", |
| 703 | + endpoint: "scrape_configs", |
| 704 | + expectedCode: http.StatusOK, |
| 705 | + }, |
| 706 | + { |
| 707 | + description: "with tls test for jobs", |
| 708 | + endpoint: "jobs", |
| 709 | + expectedCode: http.StatusOK, |
| 710 | + }, |
| 711 | + } |
| 712 | + for _, tc := range tests { |
| 713 | + t.Run(tc.description, func(t *testing.T) { |
| 714 | + // Create a custom HTTP client with TLS transport |
| 715 | + client := &http.Client{ |
| 716 | + Transport: &http.Transport{ |
| 717 | + TLSClientConfig: clientTlsConfig, |
| 718 | + }, |
| 719 | + } |
| 720 | + |
| 721 | + // Make the GET request |
| 722 | + request, err := client.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint)) |
| 723 | + |
| 724 | + // Verify if a certificate verification error occurred |
| 725 | + require.NoError(t, err) |
| 726 | + |
| 727 | + // Only check the status code if there was no error |
| 728 | + if err == nil { |
| 729 | + assert.Equal(t, tc.expectedCode, request.StatusCode) |
| 730 | + } else { |
| 731 | + t.Log(err) |
| 732 | + } |
| 733 | + }) |
| 734 | + } |
| 735 | +} |
| 736 | + |
| 737 | +func TestServer_MissingCAonTLS(t *testing.T) { |
| 738 | + listenAddr := ":8443" |
| 739 | + server, _, err := createTestTLSServer(listenAddr) |
| 740 | + assert.NoError(t, err) |
| 741 | + go func() { |
| 742 | + assert.ErrorIs(t, server.StartHTTPS(), http.ErrServerClosed) |
| 743 | + }() |
| 744 | + time.Sleep(100 * time.Millisecond) // wait for server to launch |
| 745 | + defer func() { |
| 746 | + err := server.ShutdownHTTPS(context.Background()) |
| 747 | + if err != nil { |
| 748 | + assert.NoError(t, err) |
| 749 | + } |
| 750 | + }() |
| 751 | + tests := []struct { |
| 752 | + description string |
| 753 | + endpoint string |
| 754 | + expectedCode int |
| 755 | + }{ |
| 756 | + { |
| 757 | + description: "no tls test for scrape config", |
| 758 | + endpoint: "scrape_configs", |
| 759 | + expectedCode: http.StatusBadRequest, |
| 760 | + }, |
| 761 | + { |
| 762 | + description: "no tls test for jobs", |
| 763 | + endpoint: "jobs", |
| 764 | + expectedCode: http.StatusBadRequest, |
| 765 | + }, |
| 766 | + } |
| 767 | + for _, tc := range tests { |
| 768 | + t.Run(tc.description, func(t *testing.T) { |
| 769 | + request, err := http.Get(fmt.Sprintf("https://localhost%s/%s", listenAddr, tc.endpoint)) |
| 770 | + |
| 771 | + // Verify if a certificate verification error occurred |
| 772 | + require.Error(t, err) |
| 773 | + |
| 774 | + // Only check the status code if there was no error |
| 775 | + if err == nil { |
| 776 | + assert.Equal(t, tc.expectedCode, request.StatusCode) |
| 777 | + } |
| 778 | + }) |
| 779 | + } |
| 780 | +} |
| 781 | + |
| 782 | +func TestServer_HTTPOnTLS(t *testing.T) { |
| 783 | + listenAddr := ":8443" |
| 784 | + server, _, err := createTestTLSServer(listenAddr) |
| 785 | + assert.NoError(t, err) |
| 786 | + go func() { |
| 787 | + assert.NoError(t, server.StartHTTPS()) |
| 788 | + }() |
| 789 | + time.Sleep(100 * time.Millisecond) // wait for server to launch |
| 790 | + |
| 791 | + defer func(s *Server, ctx context.Context) { |
| 792 | + err := s.Shutdown(ctx) |
| 793 | + if err != nil { |
| 794 | + assert.NoError(t, err) |
| 795 | + } |
| 796 | + }(server, context.Background()) |
| 797 | + tests := []struct { |
| 798 | + description string |
| 799 | + endpoint string |
| 800 | + expectedCode int |
| 801 | + }{ |
| 802 | + { |
| 803 | + description: "no tls test for scrape config", |
| 804 | + endpoint: "scrape_configs", |
| 805 | + expectedCode: http.StatusBadRequest, |
| 806 | + }, |
| 807 | + { |
| 808 | + description: "no tls test for jobs", |
| 809 | + endpoint: "jobs", |
| 810 | + expectedCode: http.StatusBadRequest, |
| 811 | + }, |
| 812 | + } |
| 813 | + for _, tc := range tests { |
| 814 | + t.Run(tc.description, func(t *testing.T) { |
| 815 | + request, err := http.Get(fmt.Sprintf("http://localhost%s/%s", listenAddr, tc.endpoint)) |
| 816 | + |
| 817 | + // Only check the status code if there was no error |
| 818 | + if err == nil { |
| 819 | + assert.Equal(t, tc.expectedCode, request.StatusCode) |
| 820 | + } |
| 821 | + }) |
| 822 | + } |
| 823 | +} |
| 824 | + |
| 825 | +func createTestTLSServer(listenAddr string) (*Server, *tls.Config, error) { |
| 826 | + //testing using this function replicates customer environment |
| 827 | + svrConfig := allocatorconfig.HTTPSServerConfig{} |
| 828 | + caBundle, caCert, caKey, err := generateTestingCerts() |
| 829 | + if err != nil { |
| 830 | + return nil, nil, err |
| 831 | + } |
| 832 | + svrConfig.TLSKeyFilePath = caKey |
| 833 | + svrConfig.TLSCertFilePath = caCert |
| 834 | + tlsConfig, err := svrConfig.NewTLSConfig(context.TODO()) |
| 835 | + if err != nil { |
| 836 | + return nil, nil, err |
| 837 | + } |
| 838 | + httpOptions := []Option{} |
| 839 | + httpOptions = append(httpOptions, WithTLSConfig(tlsConfig, listenAddr)) |
| 840 | + |
| 841 | + //generate ca bundle |
| 842 | + bundle, err := readCABundle(caBundle) |
| 843 | + if err != nil { |
| 844 | + return nil, nil, err |
| 845 | + } |
| 846 | + allocator := &mockAllocator{targetItems: map[string]*target.Item{ |
| 847 | + "a": target.NewItem("job1", "", model.LabelSet{}, ""), |
| 848 | + }} |
| 849 | + |
| 850 | + return NewServer(logger, allocator, listenAddr, httpOptions...), bundle, nil |
| 851 | +} |
| 852 | + |
672 | 853 | func newLink(jobName string) target.LinkJSON { |
673 | 854 | return target.LinkJSON{Link: fmt.Sprintf("/jobs/%s/targets", url.QueryEscape(jobName))} |
674 | 855 | } |
| 856 | + |
| 857 | +func readCABundle(caBundlePath string) (*tls.Config, error) { |
| 858 | + // Load the CA bundle |
| 859 | + caCert, err := os.ReadFile(caBundlePath) |
| 860 | + if err != nil { |
| 861 | + return nil, fmt.Errorf("failed to read CA bundle: %w", err) |
| 862 | + } |
| 863 | + |
| 864 | + // Create a CA pool and add the CA certificate(s) |
| 865 | + caCertPool := x509.NewCertPool() |
| 866 | + if !caCertPool.AppendCertsFromPEM(caCert) { |
| 867 | + return nil, fmt.Errorf("failed to add CA certificates to pool") |
| 868 | + } |
| 869 | + |
| 870 | + // Set up TLS configuration with the CA pool |
| 871 | + tlsConfig := &tls.Config{ |
| 872 | + RootCAs: caCertPool, |
| 873 | + } |
| 874 | + return tlsConfig, nil |
| 875 | +} |
| 876 | + |
| 877 | +func generateTestingCerts() (caBundlePath, caCertPath, caKeyPath string, err error) { |
| 878 | + // Generate private key |
| 879 | + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) |
| 880 | + if err != nil { |
| 881 | + return "", "", "", fmt.Errorf("error generating private key: %w", err) |
| 882 | + } |
| 883 | + |
| 884 | + // Set up certificate template |
| 885 | + template := x509.Certificate{ |
| 886 | + SerialNumber: big.NewInt(1), |
| 887 | + Subject: pkix.Name{ |
| 888 | + CommonName: "localhost", |
| 889 | + }, |
| 890 | + NotBefore: time.Now(), |
| 891 | + NotAfter: time.Now().Add(365 * 24 * time.Hour), // 1 year validity |
| 892 | + KeyUsage: x509.KeyUsageKeyEncipherment | x509.KeyUsageDigitalSignature, |
| 893 | + ExtKeyUsage: []x509.ExtKeyUsage{ |
| 894 | + x509.ExtKeyUsageServerAuth, |
| 895 | + }, |
| 896 | + DNSNames: []string{"localhost"}, |
| 897 | + } |
| 898 | + |
| 899 | + // Self-sign the certificate |
| 900 | + certBytes, err := x509.CreateCertificate(rand.Reader, &template, &template, &privateKey.PublicKey, privateKey) |
| 901 | + if err != nil { |
| 902 | + return "", "", "", fmt.Errorf("error creating certificate: %w", err) |
| 903 | + } |
| 904 | + |
| 905 | + // Create temporary files |
| 906 | + tempDir := os.TempDir() |
| 907 | + |
| 908 | + caCertFile, err := os.CreateTemp(tempDir, "ca-cert-*.crt") |
| 909 | + if err != nil { |
| 910 | + return "", "", "", fmt.Errorf("error creating temp CA cert file: %w", err) |
| 911 | + } |
| 912 | + defer caCertFile.Close() |
| 913 | + |
| 914 | + caKeyFile, err := os.CreateTemp(tempDir, "ca-key-*.key") |
| 915 | + if err != nil { |
| 916 | + return "", "", "", fmt.Errorf("error creating temp CA key file: %w", err) |
| 917 | + } |
| 918 | + defer caKeyFile.Close() |
| 919 | + |
| 920 | + caBundleFile, err := os.CreateTemp(tempDir, "ca-bundle-*.crt") |
| 921 | + if err != nil { |
| 922 | + return "", "", "", fmt.Errorf("error creating temp CA bundle file: %w", err) |
| 923 | + } |
| 924 | + defer caBundleFile.Close() |
| 925 | + |
| 926 | + // Write the private key to the key file |
| 927 | + privateKeyBytes, err := x509.MarshalECPrivateKey(privateKey) |
| 928 | + if err != nil { |
| 929 | + return "", "", "", fmt.Errorf("error writing private key: %w", err) |
| 930 | + } |
| 931 | + err = pem.Encode(caKeyFile, &pem.Block{Type: "EC PRIVATE KEY", Bytes: privateKeyBytes}) |
| 932 | + if err != nil { |
| 933 | + return "", "", "", fmt.Errorf("error writing private key: %w", err) |
| 934 | + } |
| 935 | + |
| 936 | + // Write the certificate to the certificate and bundle files |
| 937 | + certPEM := &pem.Block{Type: "CERTIFICATE", Bytes: certBytes} |
| 938 | + if err = pem.Encode(caCertFile, certPEM); err != nil { |
| 939 | + return "", "", "", fmt.Errorf("error writing certificate: %w", err) |
| 940 | + } |
| 941 | + if err = pem.Encode(caBundleFile, certPEM); err != nil { |
| 942 | + return "", "", "", fmt.Errorf("error writing bundle certificate: %w", err) |
| 943 | + } |
| 944 | + |
| 945 | + // Return the file paths |
| 946 | + return caBundleFile.Name(), caCertFile.Name(), caKeyFile.Name(), nil |
| 947 | +} |
0 commit comments