Skip to content

Commit 1960fbb

Browse files
authored
Merge pull request kubernetes-sigs#2062 from k8s-infra-cherrypick-robot/cherry-pick-2037-to-release-0.10
[release-0.10] 🌱 Refactoring: never assign unacceptable TLS versions
2 parents f049a35 + 7731959 commit 1960fbb

File tree

2 files changed

+58
-33
lines changed

2 files changed

+58
-33
lines changed

main.go

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -373,14 +373,19 @@ func concurrency(c int) controller.Options {
373373
func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error) {
374374
var tlsOptions []func(config *tls.Config)
375375

376-
tlsMinVersion, err := GetTLSVersion(options.TLSMinVersion)
377-
if err != nil {
378-
return nil, err
379-
}
380-
381-
tlsMaxVersion, err := GetTLSVersion(options.TLSMaxVersion)
382-
if err != nil {
383-
return nil, err
376+
// To make a static analyzer happy, this block ensures there is no code
377+
// path that sets a TLS version outside the acceptable values, even in
378+
// case of unexpected user input.
379+
var tlsMinVersion, tlsMaxVersion uint16
380+
for version, option := range map[*uint16]string{&tlsMinVersion: options.TLSMinVersion, &tlsMaxVersion: options.TLSMaxVersion} {
381+
switch option {
382+
case TLSVersion12:
383+
*version = tls.VersionTLS12
384+
case TLSVersion13:
385+
*version = tls.VersionTLS13
386+
default:
387+
return nil, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", option, strings.Join(tlsSupportedVersions, ", "))
388+
}
384389
}
385390

386391
if tlsMaxVersion != 0 && tlsMinVersion > tlsMaxVersion {
@@ -422,18 +427,3 @@ func GetTLSOptionOverrideFuncs(options TLSOptions) ([]func(*tls.Config), error)
422427

423428
return tlsOptions, nil
424429
}
425-
426-
// GetTLSVersion returns the corresponding tls.Version or error.
427-
func GetTLSVersion(version string) (uint16, error) {
428-
var v uint16
429-
430-
switch version {
431-
case TLSVersion12:
432-
v = tls.VersionTLS12
433-
case TLSVersion13:
434-
v = tls.VersionTLS13
435-
default:
436-
return 0, fmt.Errorf("unexpected TLS version %q (must be one of: %s)", version, strings.Join(tlsSupportedVersions, ", "))
437-
}
438-
return v, nil
439-
}

main_test.go

Lines changed: 45 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ package main
1818

1919
import (
2020
"bytes"
21+
"crypto/tls"
2122
"testing"
2223

2324
. "github.com/onsi/gomega"
@@ -75,25 +76,59 @@ func Test13CipherSuite(t *testing.T) {
7576
klog.SetOutput(bufWriter)
7677
klog.LogToStderr(false) // this is important, because klog by default logs to stderr only
7778
_, err := GetTLSOptionOverrideFuncs(tlsMockOptions)
78-
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
7979
g.Expect(err).Should(BeNil())
80+
g.Expect(bufWriter.String()).Should(ContainSubstring("warning: Cipher suites should not be set for TLS version 1.3. Ignoring ciphers"))
8081
})
8182
}
8283

83-
func TestGetTLSVersion(t *testing.T) {
84-
t.Run("should error out when incorrect tls version passed", func(t *testing.T) {
84+
func TestGetTLSOverrideFuncs(t *testing.T) {
85+
t.Run("should error out when incorrect min tls version passed", func(t *testing.T) {
86+
g := NewWithT(t)
87+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
88+
TLSMinVersion: "TLS11",
89+
TLSMaxVersion: "TLS12",
90+
})
91+
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
92+
})
93+
t.Run("should error out when incorrect max tls version passed", func(t *testing.T) {
8594
g := NewWithT(t)
86-
tlsVersion := "TLS11"
87-
_, err := GetTLSVersion(tlsVersion)
95+
_, err := GetTLSOptionOverrideFuncs(TLSOptions{
96+
TLSMinVersion: "TLS12",
97+
TLSMaxVersion: "TLS11",
98+
})
8899
g.Expect(err.Error()).Should(Equal("unexpected TLS version \"TLS11\" (must be one of: TLS12, TLS13)"))
89100
})
90-
t.Run("should pass and output correct tls version", func(t *testing.T) {
91-
const VersionTLS12 uint16 = 771
101+
t.Run("should apply the requested TLS versions", func(t *testing.T) {
102+
g := NewWithT(t)
103+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
104+
TLSMinVersion: "TLS12",
105+
TLSMaxVersion: "TLS13",
106+
})
107+
108+
var tlsConfig tls.Config
109+
for _, apply := range tlsOptionOverrides {
110+
apply(&tlsConfig)
111+
}
112+
113+
g.Expect(err).Should(BeNil())
114+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS12)))
115+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
116+
})
117+
t.Run("should apply the requested non-default TLS versions", func(t *testing.T) {
92118
g := NewWithT(t)
93-
tlsVersion := "TLS12"
94-
version, err := GetTLSVersion(tlsVersion)
95-
g.Expect(version).To(Equal(VersionTLS12))
119+
tlsOptionOverrides, err := GetTLSOptionOverrideFuncs(TLSOptions{
120+
TLSMinVersion: "TLS13",
121+
TLSMaxVersion: "TLS13",
122+
})
123+
124+
var tlsConfig tls.Config
125+
for _, apply := range tlsOptionOverrides {
126+
apply(&tlsConfig)
127+
}
128+
96129
g.Expect(err).Should(BeNil())
130+
g.Expect(tlsConfig.MinVersion).To(Equal(uint16(tls.VersionTLS13)))
131+
g.Expect(tlsConfig.MaxVersion).To(Equal(uint16(tls.VersionTLS13)))
97132
})
98133
}
99134

0 commit comments

Comments
 (0)