Skip to content

Commit ba87c35

Browse files
authored
Merge pull request #39 from nunnatsa/fix-ai
TLS flags refactoring
2 parents 39a076d + b75f7d6 commit ba87c35

File tree

4 files changed

+220
-37
lines changed

4 files changed

+220
-37
lines changed

.github/workflows/ci_checks.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,5 +22,5 @@ jobs:
2222
- name: Build
2323
run: go build .
2424
- name: Run Unit Tests
25-
run: go test .
25+
run: go test ./...
2626

config/config.go

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
package config
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"math"
7+
"strconv"
8+
"strings"
9+
)
10+
11+
type Config struct {
12+
TLS TLSConfig
13+
}
14+
15+
type TLSConfig struct {
16+
minTLSVersion uint16
17+
tlsCipherSuites []uint16
18+
}
19+
20+
var (
21+
minTLSVersionFlag = flag.Uint("tls-min-version", 0, "The minimum TLS version to use")
22+
tlsCipherSuitesFlag = flag.String("tls-cipher-suites", "", "A comma-separated list of cipher suites to use")
23+
)
24+
25+
func GetConfig() (*Config, error) {
26+
cfg := &Config{}
27+
28+
if *minTLSVersionFlag > 0 {
29+
if *minTLSVersionFlag > math.MaxUint16 {
30+
return nil, fmt.Errorf("the --tls-min-version flag is with a wrong value: %d is lager than the max allowed value of %d", *minTLSVersionFlag, math.MaxUint16)
31+
}
32+
33+
cfg.TLS.minTLSVersion = uint16(*minTLSVersionFlag)
34+
}
35+
36+
if *tlsCipherSuitesFlag != "" {
37+
ciphers := strings.Split(*tlsCipherSuitesFlag, ",")
38+
tlsCipherSuites := make([]uint16, 0, len(ciphers))
39+
40+
for _, cipherStr := range ciphers {
41+
cipherStr = strings.TrimSpace(cipherStr)
42+
cipher, err := strconv.ParseUint(cipherStr, 10, 16)
43+
if err != nil {
44+
return nil, fmt.Errorf("can't parse cipher %q; %w", cipherStr, err)
45+
}
46+
47+
tlsCipherSuites = append(tlsCipherSuites, uint16(cipher))
48+
}
49+
50+
cfg.TLS.tlsCipherSuites = tlsCipherSuites
51+
}
52+
53+
return cfg, nil
54+
}
55+
56+
func (cfg *Config) GetMinTLSVersion() uint16 {
57+
return cfg.TLS.minTLSVersion
58+
}
59+
60+
func (cfg *Config) GetTLSCipherSuites() []uint16 {
61+
return cfg.TLS.tlsCipherSuites
62+
}

config/config_test.go

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
package config
2+
3+
import (
4+
"flag"
5+
"fmt"
6+
"math"
7+
"reflect"
8+
"testing"
9+
)
10+
11+
func TestGetTLSCipherSuites(t *testing.T) {
12+
for _, tc := range []struct {
13+
name string
14+
flagVal string
15+
want []uint16
16+
}{
17+
{name: "valid input", flagVal: "1,2,3,4", want: []uint16{1, 2, 3, 4}},
18+
{name: "no flag", flagVal: "", want: nil},
19+
{name: "max val", flagVal: fmt.Sprintf("%d", math.MaxUint16), want: []uint16{math.MaxUint16}},
20+
} {
21+
t.Run(tc.flagVal, func(t *testing.T) {
22+
err := setFlags("0", tc.flagVal)
23+
if err != nil {
24+
t.Fatal(err)
25+
}
26+
27+
cfg, err := GetConfig()
28+
if err != nil {
29+
t.Fatal(err)
30+
}
31+
32+
if got, want := cfg.GetTLSCipherSuites(), tc.want; !reflect.DeepEqual(got, want) {
33+
t.Errorf("GetTLSCipherSuites() = %v, want %v", got, want)
34+
}
35+
})
36+
}
37+
}
38+
39+
func TestWrongGetTLSCipherSuites_tooLarge(t *testing.T) {
40+
err := setFlags("0", fmt.Sprintf("%d", math.MaxUint16+1))
41+
if err != nil {
42+
t.Fatal(err)
43+
}
44+
45+
_, err = GetConfig()
46+
if err == nil {
47+
t.Fatal("error should have errored")
48+
}
49+
50+
t.Logf("got expected error: %v", err)
51+
}
52+
53+
func TestWrongGetTLSCipherSuites_negative(t *testing.T) {
54+
err := setFlags("0", "-42")
55+
if err != nil {
56+
t.Fatal(err)
57+
}
58+
59+
_, err = GetConfig()
60+
if err == nil {
61+
t.Fatal("error should have errored")
62+
}
63+
64+
t.Logf("got expected error: %v", err)
65+
}
66+
67+
func TestWrongGetTLSCipherSuites_notNum(t *testing.T) {
68+
err := setFlags("0", "not a number")
69+
if err != nil {
70+
t.Fatal(err)
71+
}
72+
73+
_, err = GetConfig()
74+
if err == nil {
75+
t.Fatal("error should have errored")
76+
}
77+
78+
t.Logf("got expected error: %v", err)
79+
}
80+
81+
func TestWrongGetTLSCipherSuites_float(t *testing.T) {
82+
err := setFlags("0", "1.42")
83+
if err != nil {
84+
t.Fatal(err)
85+
}
86+
87+
_, err = GetConfig()
88+
if err == nil {
89+
t.Fatal("error should have errored")
90+
}
91+
92+
t.Logf("got expected error: %v", err)
93+
}
94+
95+
func TestGetMinTLSVersion(t *testing.T) {
96+
for _, tc := range []struct {
97+
name string
98+
flagVal string
99+
want uint16
100+
}{
101+
{name: "valid input", flagVal: "42", want: 42},
102+
{name: "no flag", flagVal: "0", want: 0},
103+
{name: "max val", flagVal: fmt.Sprintf("%d", math.MaxUint16), want: math.MaxUint16},
104+
} {
105+
t.Run(tc.flagVal, func(t *testing.T) {
106+
err := setFlags(tc.flagVal, "")
107+
if err != nil {
108+
t.Fatal(err)
109+
}
110+
111+
cfg, err := GetConfig()
112+
if err != nil {
113+
t.Fatal(err)
114+
}
115+
116+
if got, want := cfg.GetMinTLSVersion(), tc.want; !reflect.DeepEqual(got, want) {
117+
t.Errorf("GetMinTLSVersion() = %d, want %d", got, want)
118+
}
119+
})
120+
}
121+
}
122+
123+
func TestWrongGetMinTLSVersion_tooLarge(t *testing.T) {
124+
err := setFlags(fmt.Sprintf("%d", math.MaxUint16+1), "")
125+
if err != nil {
126+
t.Fatal(err)
127+
}
128+
129+
_, err = GetConfig()
130+
if err == nil {
131+
t.Fatal("error should have errored")
132+
}
133+
134+
t.Logf("got expected error: %v", err)
135+
}
136+
137+
func setFlags(minVer, ciphers string) error {
138+
err := flag.Set("tls-min-version", minVer)
139+
if err != nil {
140+
return err
141+
}
142+
err = flag.Set("tls-cipher-suites", ciphers)
143+
if err != nil {
144+
return err
145+
}
146+
147+
return nil
148+
}

main.go

Lines changed: 9 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,16 @@ package main
33
import (
44
"crypto/tls"
55
"flag"
6-
"fmt"
76
"log"
87
"net/http"
98
"os"
10-
"strconv"
11-
"strings"
129
"time"
1310

1411
cache "github.com/chenyahui/gin-cache"
1512
"github.com/chenyahui/gin-cache/persist"
1613
"github.com/gin-contrib/gzip"
1714
"github.com/gin-gonic/gin"
15+
"github.com/kubevirt-ui/kubevirt-apiserver-proxy/config"
1816
"github.com/kubevirt-ui/kubevirt-apiserver-proxy/handlers"
1917
)
2018

@@ -23,37 +21,13 @@ const (
2321
apiCacheTime = 15 * time.Second
2422
)
2523

26-
var (
27-
minTLSVersion uint16
28-
tlsCipherSuites []uint16
29-
30-
minTLSVersionFlag = flag.Uint("tls-min-version", 0, "The minimum TLS version to use")
31-
tlsCipherSuitesFlag = flag.String("tls-cipher-suites", "", "A comma-separated list of cipher suites to use")
32-
)
33-
34-
func init() {
24+
func main() {
3525
flag.Parse()
3626

37-
if *minTLSVersionFlag > 0 {
38-
minTLSVersion = uint16(*minTLSVersionFlag)
39-
}
40-
41-
if *tlsCipherSuitesFlag != "" {
42-
ciphers := strings.Split(*tlsCipherSuitesFlag, ",")
43-
tlsCipherSuites = make([]uint16, 0, len(ciphers))
44-
45-
for _, cipherStr := range ciphers {
46-
cipher, err := strconv.ParseUint(cipherStr, 10, 16)
47-
if err != nil {
48-
panic(fmt.Errorf("can't parse cipher %q; %w", cipherStr, err))
49-
}
50-
51-
tlsCipherSuites = append(tlsCipherSuites, uint16(cipher))
52-
}
27+
cfg, err := config.GetConfig()
28+
if err != nil {
29+
log.Fatal(err)
5330
}
54-
}
55-
56-
func main() {
5731

5832
router := gin.Default()
5933

@@ -72,17 +46,16 @@ func main() {
7246
},
7347
}
7448

75-
if minTLSVersion != 0 {
76-
server.TLSConfig.MinVersion = minTLSVersion
49+
if minTLSVer := cfg.GetMinTLSVersion(); minTLSVer != 0 {
50+
server.TLSConfig.MinVersion = minTLSVer
7751
}
7852

79-
if len(tlsCipherSuites) > 0 {
80-
server.TLSConfig.CipherSuites = tlsCipherSuites
53+
if ciphers := cfg.GetTLSCipherSuites(); len(ciphers) > 0 {
54+
server.TLSConfig.CipherSuites = ciphers
8155
}
8256

8357
log.Printf("listening for server 8080 - v0.0.10 - API cache time: %v", apiCacheTime)
8458

85-
var err error
8659
if os.Getenv("APP_ENV") == "dev" {
8760
err = server.ListenAndServe()
8861
} else {

0 commit comments

Comments
 (0)