Skip to content

Commit 49c1705

Browse files
pierreprinettimdbooth
authored andcommitted
CHERRY-PICK: Refactoring: never assign unacceptable TLS versions
This commit makes security linting easier by never setting a TLS version outside v1.2 or v1.3, even in case of an unacceptable user input. Upstream PR: kubernetes-sigs#2037 (cherry picked from commit 27526d5)
1 parent 9753c5c commit 49c1705

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

main.go

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -372,14 +372,19 @@ func concurrency(c int) controller.Options {
372372
func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) {
373373
var tlsOptions []func(config *tls.Config)
374374

375-
tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion)
376-
if err != nil {
377-
return nil, err
378-
}
379-
380-
tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion)
381-
if err != nil {
382-
return nil, err
375+
// To make a static analyzer happy, this block ensures there is no code
376+
// path that sets a TLS version outside the acceptable values, even in
377+
// case of unexpected user input.
378+
var tlsMinVersion, tlsMaxVersion uint16
379+
for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} {
380+
switch option {
381+
case TLSVersion12:
382+
*version = tls.VersionTLS12
383+
case TLSVersion13:
384+
*version = tls.VersionTLS13
385+
default:
386+
return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", "))
387+
}
383388
}
384389

385390
if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion {
@@ -421,18 +426,3 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error)
421426

422427
return tlsOptions, nil
423428
}
424-
425-
// GetTLSVersion returns the corresponding tls.Version or error.
426-
func GetTLSVersion(version string) (uint16, error) {
427-
var v uint16
428-
429-
switch version {
430-
case TLSVersion12:
431-
v = tls.VersionTLS12
432-
case TLSVersion13:
433-
v = tls.VersionTLS13
434-
default:
435-
return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", "))
436-
}
437-
return v, nil
438-
}

main_test.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package main
1818

1919
import (
2020
"bytes"
21+
"crypto/tls"
2122
"testing"
2223

2324
. "github.com/onsi/gomega"
@@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) {
7576
klog.SetOutput(bufWriter)
7677
klog.LogToStderr(false) // this is important, because klog by default logs to stderr only
7778
_, err := GetTLSOptionOverrideFuncs(tlsMockOptions)
78-
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
7979
g.Expect(err).Should(BeNil())
80+
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
8081
})
8182
}
8283

83-
func TestGetTLSVersion(t *testing.T) {
84-
t.Run("should error out when incorrect tls version passed", func(t *testing.T) {
84+
func TestGetTLSOverrideFuncs(t *testing.T) {
85+
t.Run("should error out when incorrect min tls version passed", func(t *testing.T) {
86+
g := NewWithT(t)
87+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
88+
TLSMinVersion: "TLS11",
89+
TLSMaxVersion: "TLS12",
90+
})
91+
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
92+
})
93+
t.Run("should error out when incorrect max tls version passed", func(t *testing.T) {
8594
g := NewWithT(t)
86-
tlsVersion := "TLS11"
87-
_, err := GetTLSVersion(tlsVersion)
95+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
96+
TLSMinVersion: "TLS12",
97+
TLSMaxVersion: "TLS11",
98+
})
8899
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
89100
})
90-
t.Run("should pass and output correct tls version", func(t *testing.T) {
91-
const VersionTLS12 uint16 = 771
101+
t.Run("should apply the requested TLS versions", func(t *testing.T) {
102+
g := NewWithT(t)
103+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
104+
TLSMinVersion: "TLS12",
105+
TLSMaxVersion: "TLS13",
106+
})
107+
108+
var tlsConfig tls.Config
109+
for _, apply := range tlsOptionOverrides {
110+
apply(&tlsConfig)
111+
}
112+
113+
g.Expect(err).Should(BeNil())
114+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
115+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
116+
})
117+
t.Run("should apply the requested non-default TLS versions", func(t *testing.T) {
92118
g := NewWithT(t)
93-
tlsVersion := "TLS12"
94-
version, err := GetTLSVersion(tlsVersion)
95-
g.Expect(version).To(Equal(VersionTLS12))
119+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
120+
TLSMinVersion: "TLS13",
121+
TLSMaxVersion: "TLS13",
122+
})
123+
124+
var tlsConfig tls.Config
125+
for _, apply := range tlsOptionOverrides {
126+
apply(&tlsConfig)
127+
}
128+
96129
g.Expect(err).Should(BeNil())
130+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
131+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
97132
})
98133
}
99134

0 commit comments

Comments
 (0)