Skip to content

Commit b82cfc5

Browse files
Add support for repacking and merging tls.Config structs (#196)
Add support for the Unpack operation for tls.Config attributes that are specified as strings but stored as other values. The existing Unpack operations are limited to string only and these will cause issues when trying to merge previously unpacked structs.
1 parent ccbaaef commit b82cfc5

File tree

3 files changed

+110
-36
lines changed

3 files changed

+110
-36
lines changed

transport/tlscommon/types.go

Lines changed: 61 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -159,28 +159,37 @@ func (m TLSVerificationMode) MarshalText() ([]byte, error) {
159159
return nil, fmt.Errorf("could not marshal '%+v' to text", m)
160160
}
161161

162-
// Unpack unpacks the string into constants.
162+
// Unpack unpacks the input into a TLSVerificationMode.
163163
func (m *TLSVerificationMode) Unpack(in interface{}) error {
164164
if in == nil {
165165
*m = VerifyFull
166166
return nil
167167
}
168168

169-
s, ok := in.(string)
170-
if !ok {
171-
return fmt.Errorf("verification mode must be an identifier")
172-
}
173-
if s == "" {
174-
*m = VerifyFull
175-
return nil
169+
switch o := in.(type) {
170+
case string:
171+
if o == "" {
172+
*m = VerifyFull
173+
return nil
174+
}
175+
176+
mode, found := tlsVerificationModes[o]
177+
if !found {
178+
return fmt.Errorf("unknown verification mode '%v'", o)
179+
}
180+
*m = mode
181+
case uint64:
182+
*m = TLSVerificationMode(o)
183+
default:
184+
return fmt.Errorf("verification mode is an unknown type: %T", o)
176185
}
186+
return nil
187+
}
177188

178-
mode, found := tlsVerificationModes[s]
179-
if !found {
180-
return fmt.Errorf("unknown verification mode '%v'", s)
189+
func (m *TLSVerificationMode) Validate() error {
190+
if *m > VerifyStrict {
191+
return fmt.Errorf("unsupported verification mode: %v", m)
181192
}
182-
183-
*m = mode
184193
return nil
185194
}
186195

@@ -214,13 +223,20 @@ func (m *TLSClientAuth) Unpack(s string) error {
214223

215224
type CipherSuite uint16
216225

217-
func (cs *CipherSuite) Unpack(s string) error {
218-
suite, found := tlsCipherSuites[s]
219-
if !found {
220-
return fmt.Errorf("invalid tls cipher suite '%v'", s)
226+
func (cs *CipherSuite) Unpack(i interface{}) error {
227+
switch o := i.(type) {
228+
case string:
229+
suite, found := tlsCipherSuites[o]
230+
if !found {
231+
return fmt.Errorf("invalid tls cipher suite '%v'", o)
232+
}
233+
234+
*cs = suite
235+
case uint64:
236+
*cs = CipherSuite(o)
237+
default:
238+
return fmt.Errorf("cipher suite is an unknown type: %T", o)
221239
}
222-
223-
*cs = suite
224240
return nil
225241
}
226242

@@ -233,13 +249,20 @@ func (cs CipherSuite) String() string {
233249

234250
type tlsCurveType tls.CurveID
235251

236-
func (ct *tlsCurveType) Unpack(s string) error {
237-
t, found := tlsCurveTypes[s]
238-
if !found {
239-
return fmt.Errorf("invalid tls curve type '%v'", s)
252+
func (ct *tlsCurveType) Unpack(i interface{}) error {
253+
switch o := i.(type) {
254+
case string:
255+
t, found := tlsCurveTypes[o]
256+
if !found {
257+
return fmt.Errorf("invalid tls curve type '%v'", o)
258+
}
259+
260+
*ct = t
261+
case uint64:
262+
*ct = tlsCurveType(o)
263+
default:
264+
return fmt.Errorf("tls curve type is an unsupported input type: %T", o)
240265
}
241-
242-
*ct = t
243266
return nil
244267
}
245268

@@ -252,13 +275,20 @@ func (r TLSRenegotiationSupport) String() string {
252275
return "<" + unknownType + ">"
253276
}
254277

255-
func (r *TLSRenegotiationSupport) Unpack(s string) error {
256-
t, found := tlsRenegotiationSupportTypes[s]
257-
if !found {
258-
return fmt.Errorf("invalid tls renegotiation type '%v'", s)
278+
func (r *TLSRenegotiationSupport) Unpack(i interface{}) error {
279+
switch o := i.(type) {
280+
case string:
281+
t, found := tlsRenegotiationSupportTypes[o]
282+
if !found {
283+
return fmt.Errorf("invalid tls renegotiation type '%v'", o)
284+
}
285+
286+
*r = t
287+
case uint64:
288+
*r = TLSRenegotiationSupport(o)
289+
default:
290+
return fmt.Errorf("tls renegotation support is an unknown type: %T", o)
259291
}
260-
261-
*r = t
262292
return nil
263293
}
264294

transport/tlscommon/types_test.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"testing"
2323

2424
"github.com/elastic/elastic-agent-libs/config"
25+
"github.com/elastic/go-ucfg"
2526
"github.com/stretchr/testify/assert"
2627

2728
"github.com/stretchr/testify/require"
@@ -69,6 +70,36 @@ func TestLoadWithEmptyVerificationMode(t *testing.T) {
6970
assert.Equal(t, cfg.VerificationMode, VerifyFull)
7071
}
7172

73+
func TestRepackConfig(t *testing.T) {
74+
cfg, err := load(`
75+
enabled: true
76+
verification_mode: certificate
77+
supported_protocols: [TLSv1.1, TLSv1.2]
78+
cipher_suites:
79+
- RSA-AES-256-CBC-SHA
80+
certificate_authorities:
81+
- /path/to/ca.crt
82+
certificate: /path/to/cert.crt
83+
key: /path/to/key.crt
84+
curve_types:
85+
- P-521
86+
renegotiation: freely
87+
ca_sha256:
88+
- example
89+
ca_trusted_fingerprint: fingerprint
90+
`)
91+
92+
assert.NoError(t, err)
93+
assert.Equal(t, cfg.VerificationMode, VerifyCertificate)
94+
95+
tmp, err := ucfg.NewFrom(cfg)
96+
assert.NoError(t, err)
97+
98+
err = tmp.Unpack(cfg)
99+
assert.NoError(t, err)
100+
assert.Equal(t, cfg.VerificationMode, VerifyCertificate)
101+
}
102+
72103
func TestTLSClientAuthUnpack(t *testing.T) {
73104
tests := []struct {
74105
val string

transport/tlscommon/versions.go

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,25 @@ func (v TLSVersion) Details() *TLSVersionDetails {
3838
}
3939

4040
// Unpack transforms the string into a constant.
41-
func (v *TLSVersion) Unpack(s string) error {
42-
version, found := tlsProtocolVersions[s]
43-
if !found {
44-
return fmt.Errorf("invalid tls version '%v'", s)
41+
func (v *TLSVersion) Unpack(i interface{}) error {
42+
switch o := i.(type) {
43+
case string:
44+
version, found := tlsProtocolVersions[o]
45+
if !found {
46+
return fmt.Errorf("invalid tls version '%v'", o)
47+
}
48+
*v = version
49+
case uint64:
50+
*v = TLSVersion(o)
51+
default:
52+
return fmt.Errorf("tls version is an unknown type: %T", o)
4553
}
54+
return nil
55+
}
4656

47-
*v = version
57+
func (v *TLSVersion) Validate() error {
58+
if *v < TLSVersionMin || *v > TLSVersionMax {
59+
return fmt.Errorf("unsupported tls version: %v", v)
60+
}
4861
return nil
4962
}

0 commit comments

Comments
 (0)