Skip to content

Commit 2e6bf7d

Browse files
author
Divjot Arora
committed
Set SSL options if they're not empty strings or nil.
GODRIVER-688 Change-Id: I4c3244b0fc5314a8322f059b78c7dffb6db1e230
1 parent 55ea16d commit 2e6bf7d

File tree

2 files changed

+52
-7
lines changed

2 files changed

+52
-7
lines changed

mongo/client_options_test.go

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/mongodb/mongo-go-driver/x/mongo/driver/topology"
2525
"github.com/mongodb/mongo-go-driver/x/network/connstring"
2626
"github.com/stretchr/testify/require"
27+
"reflect"
2728
)
2829

2930
func TestClientOptions_simple(t *testing.T) {
@@ -144,7 +145,7 @@ func TestClientOptions_chainAll(t *testing.T) {
144145
SSLClientCertificateKeyFile: "client.pem",
145146
SSLClientCertificateKeyFileSet: true,
146147
SSLClientCertificateKeyPassword: nil,
147-
SSLClientCertificateKeyPasswordSet: true,
148+
SSLClientCertificateKeyPasswordSet: false, // will not be set if it's nil
148149
SSLInsecure: false,
149150
SSLInsecureSet: true,
150151
SSLCaFile: "ca.pem",
@@ -159,6 +160,44 @@ func TestClientOptions_chainAll(t *testing.T) {
159160
require.Equal(t, expectedClient, opts)
160161
}
161162

163+
func TestClientOptions_sslOptions(t *testing.T) {
164+
t.Parallel()
165+
166+
t.Run("TestEmptyOptionsNotSet", func(t *testing.T) {
167+
ssl := &options.SSLOpt{}
168+
c, err := NewClientWithOptions("mongodb://localhost", options.Client().SetSSL(ssl))
169+
require.NoError(t, err)
170+
171+
require.Equal(t, c.connString.SSLClientCertificateKeyFile, "")
172+
require.Equal(t, c.connString.SSLClientCertificateKeyFileSet, false)
173+
require.Nil(t, c.connString.SSLClientCertificateKeyPassword)
174+
require.Equal(t, c.connString.SSLClientCertificateKeyPasswordSet, false)
175+
require.Equal(t, c.connString.SSLCaFile, "")
176+
require.Equal(t, c.connString.SSLCaFileSet, false)
177+
})
178+
179+
t.Run("TestNonEmptyOptionsSet", func(t *testing.T) {
180+
f := func() string {
181+
return "KeyPassword"
182+
}
183+
184+
ssl := &options.SSLOpt{
185+
ClientCertificateKeyFile: "KeyFile",
186+
ClientCertificateKeyPassword: f,
187+
CaFile: "CaFile",
188+
}
189+
c, err := NewClientWithOptions("mongodb://localhost", options.Client().SetSSL(ssl))
190+
require.NoError(t, err)
191+
192+
require.Equal(t, c.connString.SSLClientCertificateKeyFile, "KeyFile")
193+
require.Equal(t, c.connString.SSLClientCertificateKeyFileSet, true)
194+
require.Equal(t, reflect.ValueOf(c.connString.SSLClientCertificateKeyPassword).Pointer(), reflect.ValueOf(f).Pointer())
195+
require.Equal(t, c.connString.SSLClientCertificateKeyPasswordSet, true)
196+
require.Equal(t, c.connString.SSLCaFile, "CaFile")
197+
require.Equal(t, c.connString.SSLCaFileSet, true)
198+
})
199+
}
200+
162201
func TestClientOptions_CustomDialer(t *testing.T) {
163202
td := &testDialer{d: &net.Dialer{}}
164203
opts := options.Client().SetDialer(td)

mongo/options/clientoptions.go

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -285,17 +285,23 @@ func (c *ClientOptions) SetSSL(ssl *SSLOpt) *ClientOptions {
285285
c.ConnString.SSL = ssl.Enabled
286286
c.ConnString.SSLSet = true
287287

288-
c.ConnString.SSLClientCertificateKeyFile = ssl.ClientCertificateKeyFile
289-
c.ConnString.SSLClientCertificateKeyFileSet = true
288+
if ssl.ClientCertificateKeyFile != "" {
289+
c.ConnString.SSLClientCertificateKeyFile = ssl.ClientCertificateKeyFile
290+
c.ConnString.SSLClientCertificateKeyFileSet = true
291+
}
290292

291-
c.ConnString.SSLClientCertificateKeyPassword = ssl.ClientCertificateKeyPassword
292-
c.ConnString.SSLClientCertificateKeyPasswordSet = true
293+
if ssl.ClientCertificateKeyPassword != nil {
294+
c.ConnString.SSLClientCertificateKeyPassword = ssl.ClientCertificateKeyPassword
295+
c.ConnString.SSLClientCertificateKeyPasswordSet = true
296+
}
293297

294298
c.ConnString.SSLInsecure = ssl.Insecure
295299
c.ConnString.SSLInsecureSet = true
296300

297-
c.ConnString.SSLCaFile = ssl.CaFile
298-
c.ConnString.SSLCaFileSet = true
301+
if ssl.CaFile != "" {
302+
c.ConnString.SSLCaFile = ssl.CaFile
303+
c.ConnString.SSLCaFileSet = true
304+
}
299305

300306
return c
301307
}

0 commit comments

Comments
 (0)