Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 49 additions & 6 deletions cmd/webhook/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"flag"
"fmt"
"net/http"
"strings"

"github.com/open-policy-agent/cert-controller/pkg/rotator"
"k8s.io/apimachinery/pkg/runtime"
Expand Down Expand Up @@ -42,6 +43,7 @@ const (
var (
audience string
webhookCertDir string
tlsCipherSuites string
tlsMinVersion string
healthAddr string
metricsAddr string
Expand Down Expand Up @@ -73,6 +75,7 @@ func mainErr() error {
flag.StringVar(&audience, "audience", "", "Audience for service account token")
flag.StringVar(&webhookCertDir, "webhook-cert-dir", "/certs", "Webhook certificates dir to use. Defaults to /certs")
flag.BoolVar(&disableCertRotation, "disable-cert-rotation", false, "disable automatic generation and rotation of webhook TLS certificates/keys")
flag.StringVar(&tlsCipherSuites, "tls-cipher-suites", "", "Comma-separated list of TLS cipher suites")
flag.StringVar(&tlsMinVersion, "tls-min-version", "1.3", "Minimum TLS version")
flag.StringVar(&healthAddr, "health-addr", ":9440", "The address the health endpoint binds to")
flag.StringVar(&metricsAddr, "metrics-addr", ":8095", "The address the metrics endpoint binds to")
Expand Down Expand Up @@ -114,10 +117,20 @@ func mainErr() error {
if err != nil {
return fmt.Errorf("entrypoint: unable to parse TLS version: %w", err)
}
tlsOpts := []func(c *tls.Config){func(c *tls.Config) { c.MinVersion = tlsVersion }}

cipherSuites, err := parseTLSCipherSuites(tlsCipherSuites)
if err != nil {
return fmt.Errorf("entrypoint: unable to parse TLS cipher suites: %w", err)
}

if len(cipherSuites) > 0 {
tlsOpts = append(tlsOpts, func(c *tls.Config) { c.CipherSuites = cipherSuites })
}

serverOpts := webhook.Options{
CertDir: webhookCertDir,
TLSOpts: []func(c *tls.Config){func(c *tls.Config) { c.MinVersion = tlsVersion }},
TLSOpts: tlsOpts,
}
mgr, err := ctrl.NewManager(config, ctrl.Options{
Scheme: scheme,
Expand Down Expand Up @@ -207,15 +220,45 @@ func setupProbeEndpoints(mgr ctrl.Manager, setupFinished chan struct{}) {

func parseTLSVersion(tlsVersion string) (uint16, error) {
switch tlsVersion {
case "1.0":
case "1.0", "VersionTLS10":
return tls.VersionTLS10, nil
case "1.1":
case "1.1", "VersionTLS11":
return tls.VersionTLS11, nil
case "1.2":
case "1.2", "VersionTLS12":
return tls.VersionTLS12, nil
case "1.3":
case "1.3", "VersionTLS13":
return tls.VersionTLS13, nil
default:
return 0, fmt.Errorf("invalid TLS version. Must be one of: 1.0, 1.1, 1.2, 1.3")
return 0, fmt.Errorf("invalid TLS version. Must be one of: 1.0, 1.1, 1.2, 1.3, VersionTLS10, VersionTLS11, VersionTLS12, VersionTLS13")
}
}

func parseTLSCipherSuites(cipherSuites string) ([]uint16, error) {
if cipherSuites == "" {
return nil, nil
}

// Build a map of all available cipher suites
availableSuites := make(map[string]uint16)
for _, s := range tls.CipherSuites() {
availableSuites[s.Name] = s.ID
}
// Also include insecure suites just in case, though discouraged
for _, s := range tls.InsecureCipherSuites() {
availableSuites[s.Name] = s.ID
}

var ids []uint16
for _, name := range strings.Split(cipherSuites, ",") {
name = strings.TrimSpace(name)
if name == "" {
continue
}
id, ok := availableSuites[name]
if !ok {
return nil, fmt.Errorf("unsupported cipher suite: %s", name)
}
ids = append(ids, id)
}
return ids, nil
}
147 changes: 147 additions & 0 deletions cmd/webhook/main_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,147 @@
package main

import (
"crypto/tls"
"reflect"
"testing"
)

func TestParseTLSVersion(t *testing.T) {
tests := []struct {
name string
version string
want uint16
wantErr bool
}{
{
name: "TLS 1.0",
version: "1.0",
want: tls.VersionTLS10,
wantErr: false,
},
{
name: "VersionTLS10",
version: "VersionTLS10",
want: tls.VersionTLS10,
wantErr: false,
},
{
name: "TLS 1.1",
version: "1.1",
want: tls.VersionTLS11,
wantErr: false,
},
{
name: "VersionTLS11",
version: "VersionTLS11",
want: tls.VersionTLS11,
wantErr: false,
},
{
name: "TLS 1.2",
version: "1.2",
want: tls.VersionTLS12,
wantErr: false,
},
{
name: "VersionTLS12",
version: "VersionTLS12",
want: tls.VersionTLS12,
wantErr: false,
},
{
name: "TLS 1.3",
version: "1.3",
want: tls.VersionTLS13,
wantErr: false,
},
{
name: "VersionTLS13",
version: "VersionTLS13",
want: tls.VersionTLS13,
wantErr: false,
},
{
name: "Invalid version",
version: "1.4",
want: 0,
wantErr: true,
},
{
name: "Empty version",
version: "",
want: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseTLSVersion(tt.version)
if (err != nil) != tt.wantErr {
t.Errorf("parseTLSVersion() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("parseTLSVersion() = %v, want %v", got, tt.want)
}
})
}
}

func TestParseTLSCipherSuites(t *testing.T) {
tests := []struct {
name string
cipherSuites string
want []uint16
wantErr bool
}{
{
name: "Empty cipher suites",
cipherSuites: "",
want: nil,
wantErr: false,
},
{
name: "Valid cipher suite",
cipherSuites: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256},
wantErr: false,
},
{
name: "Multiple valid cipher suites",
cipherSuites: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
wantErr: false,
},
{
name: "Valid cipher suites with spaces",
cipherSuites: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
want: []uint16{tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384},
wantErr: false,
},
{
name: "Invalid cipher suite",
cipherSuites: "INVALID_CIPHER",
want: nil,
wantErr: true,
},
{
name: "Mixed valid and invalid cipher suites",
cipherSuites: "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,INVALID_CIPHER",
want: nil,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := parseTLSCipherSuites(tt.cipherSuites)
if (err != nil) != tt.wantErr {
t.Errorf("parseTLSCipherSuites() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("parseTLSCipherSuites() = %v, want %v", got, tt.want)
}
})
}
}