@@ -8,18 +8,20 @@ package main
88
99import (
1010 "context"
11+ "crypto/tls"
1112 "encoding/base64"
1213 "errors"
1314 "fmt"
1415 "net/url"
1516 "os"
17+ "path/filepath"
1618 "testing"
1719 "time"
1820
1921 "github.com/stretchr/testify/assert"
22+ "github.com/stretchr/testify/require"
2023 "go.mongodb.org/mongo-driver/v2/bson"
2124 "go.mongodb.org/mongo-driver/v2/internal/handshake"
22- "go.mongodb.org/mongo-driver/v2/internal/require"
2325 "go.mongodb.org/mongo-driver/v2/mongo"
2426 "go.mongodb.org/mongo-driver/v2/mongo/options"
2527)
@@ -29,91 +31,91 @@ func TestAtlas(t *testing.T) {
2931 name string
3032 envVar string
3133 certKeyFile string
32- wantErr string
34+ wantErrCode string
3335 }{
3436 {
3537 name : "Atlas with TLS" ,
3638 envVar : "ATLAS_REPL" ,
3739 certKeyFile : "" ,
38- wantErr : "" ,
40+ wantErrCode : "" ,
3941 },
4042 {
4143 name : "Atlas with TLS and shared cluster" ,
4244 envVar : "ATLAS_SHRD" ,
4345 certKeyFile : "" ,
44- wantErr : "" ,
46+ wantErrCode : "" ,
4547 },
4648 {
4749 name : "Atlas with free tier" ,
4850 envVar : "ATLAS_FREE" ,
4951 certKeyFile : "" ,
50- wantErr : "" ,
52+ wantErrCode : "" ,
5153 },
5254 {
5355 name : "Atlas with TLS 1.1" ,
5456 envVar : "ATLAS_TLS11" ,
5557 certKeyFile : "" ,
56- wantErr : "" ,
58+ wantErrCode : "" ,
5759 },
5860 {
5961 name : "Atlas with TLS 1.2" ,
6062 envVar : "ATLAS_TLS12" ,
6163 certKeyFile : "" ,
62- wantErr : "" ,
64+ wantErrCode : "" ,
6365 },
6466 {
6567 name : "Atlas with serverless" ,
6668 envVar : "ATLAS_SERVERLESS" ,
6769 certKeyFile : "" ,
68- wantErr : "" ,
70+ wantErrCode : "" ,
6971 },
7072 {
7173 name : "Atlas with srv file on replica set" ,
7274 envVar : "ATLAS_SRV_REPL" ,
7375 certKeyFile : "" ,
74- wantErr : "" ,
76+ wantErrCode : "" ,
7577 },
7678 {
7779 name : "Atlas with srv file on shared cluster" ,
7880 envVar : "ATLAS_SRV_SHRD" ,
7981 certKeyFile : "" ,
80- wantErr : "" ,
82+ wantErrCode : "" ,
8183 },
8284 {
8385 name : "Atlas with srv file on free tier" ,
8486 envVar : "ATLAS_SRV_FREE" ,
8587 certKeyFile : "" ,
86- wantErr : "" ,
88+ wantErrCode : "" ,
8789 },
8890 {
8991 name : "Atlas with srv file on TLS 1.1" ,
9092 envVar : "ATLAS_SRV_TLS11" ,
9193 certKeyFile : "" ,
92- wantErr : "" ,
94+ wantErrCode : "" ,
9395 },
9496 {
9597 name : "Atlas with srv file on TLS 1.2" ,
9698 envVar : "ATLAS_SRV_TLS12" ,
9799 certKeyFile : "" ,
98- wantErr : "" ,
100+ wantErrCode : "" ,
99101 },
100102 {
101103 name : "Atlas with srv file on serverless" ,
102104 envVar : "ATLAS_SRV_SERVERLESS" ,
103105 certKeyFile : "" ,
104- wantErr : "" ,
106+ wantErrCode : "" ,
105107 },
106108 {
107109 name : "Atlas with X509 Dev" ,
108110 envVar : "ATLAS_X509_DEV" ,
109111 certKeyFile : createAtlasX509DevCertKeyFile (t ),
110- wantErr : "" ,
112+ wantErrCode : "" ,
111113 },
112114 {
113115 name : "Atlas with X509 Dev no user" ,
114116 envVar : "ATLAS_X509_DEV" ,
115117 certKeyFile : createAtlasX509DevCertKeyFileNoUser (t ),
116- wantErr : "(UserNotFound) Could not find user" ,
118+ wantErrCode : "11" , // UserNotFound
117119 },
118120 }
119121
@@ -133,21 +135,34 @@ func TestAtlas(t *testing.T) {
133135
134136 // Run basic connectivity test.
135137 err := runTest (context .Background (), clientOpts )
136- if tc .wantErr != "" {
137- assert .ErrorContains (t , err , tc .wantErr , "error running test with TLS" )
138+ if tc .wantErrCode != "" {
139+ var cmdErr * mongo.CommandError
140+ if errors .As (err , & cmdErr ) {
141+ assert .Equal (t , cmdErr .Code , tc .wantErrCode )
142+ } else {
143+ t .Fatalf ("expected error to be a command error, got: %v" , err )
144+ }
138145
139146 return
140147 }
141148 require .NoError (t , err , "error running test with TLS" )
142149
143- tlsConfigSkipVerify := clientOpts .TLSConfig
144- tlsConfigSkipVerify .InsecureSkipVerify = true
150+ orig := clientOpts .TLSConfig
151+ if orig == nil {
152+ orig = & tls.Config {}
153+ }
154+
155+ insecure := orig .Clone ()
156+ insecure .InsecureSkipVerify = true
145157
146158 // Run the connectivity test with InsecureSkipVerify to ensure SNI is done
147159 // correctly even if verification is disabled.
148- clientOpts .SetTLSConfig (tlsConfigSkipVerify )
160+ insecureClientOpts := options .Client ().
161+ ApplyURI (uri ).
162+ SetServerSelectionTimeout (1 * time .Second ).
163+ SetTLSConfig (insecure )
149164
150- err = runTest (context .Background (), clientOpts )
165+ err = runTest (context .Background (), insecureClientOpts )
151166 require .NoError (t , err , "error running test with tlsInsecure" )
152167 })
153168 }
@@ -180,7 +195,10 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
180195func createAtlasX509DevCertKeyFile (t * testing.T ) string {
181196 t .Helper ()
182197
183- certBytes , err := base64 .StdEncoding .DecodeString (os .Getenv ("ATLAS_X509_DEV_CERT_BASE64" ))
198+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_BASE64" )
199+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_BASE64 is not set" )
200+
201+ certBytes , err := base64 .StdEncoding .DecodeString (b64 )
184202 require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_BASE64" )
185203
186204 certFilePath := t .TempDir () + "/atlas_x509_dev_cert.pem"
@@ -194,7 +212,10 @@ func createAtlasX509DevCertKeyFile(t *testing.T) string {
194212func createAtlasX509DevCertKeyFileNoUser (t * testing.T ) string {
195213 t .Helper ()
196214
197- keyBytes , err := base64 .StdEncoding .DecodeString (os .Getenv ("ATLAS_X509_DEV_CERT_NOUSER_BASE64" ))
215+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
216+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_NOUSER_BASE64 is not set" )
217+
218+ keyBytes , err := base64 .StdEncoding .DecodeString (b64 )
198219 require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
199220
200221 keyFilePath := t .TempDir () + "/atlas_x509_dev_cert_no_user.pem"
@@ -212,7 +233,7 @@ func addTLSCertKeyFile(t *testing.T, certKeyFile, uri string) string {
212233 require .NoError (t , err , "failed to parse uri" )
213234
214235 q := u .Query ()
215- q .Set ("tlsCertificateKeyFile" , certKeyFile )
236+ q .Set ("tlsCertificateKeyFile" , filepath . ToSlash ( certKeyFile ) )
216237
217238 u .RawQuery = q .Encode ()
218239
0 commit comments