@@ -8,56 +8,164 @@ package main
88
99import (
1010 "context"
11+ "crypto/tls"
12+ "encoding/base64"
1113 "errors"
12- "flag"
1314 "fmt"
15+ "net/url"
1416 "os"
17+ "path/filepath"
1518 "testing"
1619 "time"
1720
21+ "github.com/stretchr/testify/assert"
22+ "github.com/stretchr/testify/require"
1823 "go.mongodb.org/mongo-driver/v2/bson"
1924 "go.mongodb.org/mongo-driver/v2/internal/handshake"
2025 "go.mongodb.org/mongo-driver/v2/mongo"
2126 "go.mongodb.org/mongo-driver/v2/mongo/options"
2227)
2328
24- func TestMain (m * testing.M ) {
25- flag .Parse ()
26- os .Exit (m .Run ())
27- }
28-
2929func TestAtlas (t * testing.T ) {
30- uris := flag .Args ()
31- ctx := context .Background ()
32-
33- t .Logf ("Running atlas tests for %d uris\n " , len (uris ))
34-
35- for idx , uri := range uris {
36- t .Logf ("Running test %d\n " , idx )
37-
38- // Set a low server selection timeout so we fail fast if there are errors.
39- clientOpts := options .Client ().
40- ApplyURI (uri ).
41- SetServerSelectionTimeout (1 * time .Second )
42-
43- // Run basic connectivity test.
44- if err := runTest (ctx , clientOpts ); err != nil {
45- t .Fatalf ("error running test with TLS at index %d: %v" , idx , err )
46- }
47-
48- tlsConfigSkipVerify := clientOpts .TLSConfig
49- tlsConfigSkipVerify .InsecureSkipVerify = true
50-
51- // Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
52- // disabled.
53- clientOpts .SetTLSConfig (tlsConfigSkipVerify )
54-
55- if err := runTest (ctx , clientOpts ); err != nil {
56- t .Fatalf ("error running test with tlsInsecure at index %d: %v" , idx , err )
57- }
30+ cases := []struct {
31+ name string
32+ envVar string
33+ certKeyFile string
34+ wantErrCode string
35+ }{
36+ {
37+ name : "Atlas with TLS" ,
38+ envVar : "ATLAS_REPL" ,
39+ certKeyFile : "" ,
40+ wantErrCode : "" ,
41+ },
42+ {
43+ name : "Atlas with TLS and shared cluster" ,
44+ envVar : "ATLAS_SHRD" ,
45+ certKeyFile : "" ,
46+ wantErrCode : "" ,
47+ },
48+ {
49+ name : "Atlas with free tier" ,
50+ envVar : "ATLAS_FREE" ,
51+ certKeyFile : "" ,
52+ wantErrCode : "" ,
53+ },
54+ {
55+ name : "Atlas with TLS 1.1" ,
56+ envVar : "ATLAS_TLS11" ,
57+ certKeyFile : "" ,
58+ wantErrCode : "" ,
59+ },
60+ {
61+ name : "Atlas with TLS 1.2" ,
62+ envVar : "ATLAS_TLS12" ,
63+ certKeyFile : "" ,
64+ wantErrCode : "" ,
65+ },
66+ {
67+ name : "Atlas with serverless" ,
68+ envVar : "ATLAS_SERVERLESS" ,
69+ certKeyFile : "" ,
70+ wantErrCode : "" ,
71+ },
72+ {
73+ name : "Atlas with srv file on replica set" ,
74+ envVar : "ATLAS_SRV_REPL" ,
75+ certKeyFile : "" ,
76+ wantErrCode : "" ,
77+ },
78+ {
79+ name : "Atlas with srv file on shared cluster" ,
80+ envVar : "ATLAS_SRV_SHRD" ,
81+ certKeyFile : "" ,
82+ wantErrCode : "" ,
83+ },
84+ {
85+ name : "Atlas with srv file on free tier" ,
86+ envVar : "ATLAS_SRV_FREE" ,
87+ certKeyFile : "" ,
88+ wantErrCode : "" ,
89+ },
90+ {
91+ name : "Atlas with srv file on TLS 1.1" ,
92+ envVar : "ATLAS_SRV_TLS11" ,
93+ certKeyFile : "" ,
94+ wantErrCode : "" ,
95+ },
96+ {
97+ name : "Atlas with srv file on TLS 1.2" ,
98+ envVar : "ATLAS_SRV_TLS12" ,
99+ certKeyFile : "" ,
100+ wantErrCode : "" ,
101+ },
102+ {
103+ name : "Atlas with srv file on serverless" ,
104+ envVar : "ATLAS_SRV_SERVERLESS" ,
105+ certKeyFile : "" ,
106+ wantErrCode : "" ,
107+ },
108+ {
109+ name : "Atlas with X509 Dev" ,
110+ envVar : "ATLAS_X509_DEV" ,
111+ certKeyFile : createAtlasX509DevCertKeyFile (t ),
112+ wantErrCode : "" ,
113+ },
114+ {
115+ name : "Atlas with X509 Dev no user" ,
116+ envVar : "ATLAS_X509_DEV" ,
117+ certKeyFile : createAtlasX509DevCertKeyFileNoUser (t ),
118+ wantErrCode : "11" , // UserNotFound
119+ },
58120 }
59121
60- t .Logf ("Finished!" )
122+ for _ , tc := range cases {
123+ t .Run (fmt .Sprintf ("%s (%s)" , tc .name , tc .envVar ), func (t * testing.T ) {
124+ uri := os .Getenv (tc .envVar )
125+ assert .NotEmpty (t , uri , fmt .Sprintf ("Environment variable %s is not set" , tc .envVar ))
126+
127+ if tc .certKeyFile != "" {
128+ uri = addTLSCertKeyFile (t , tc .certKeyFile , uri )
129+ }
130+
131+ // Set a low server selection timeout so we fail fast if there are errors.
132+ clientOpts := options .Client ().
133+ ApplyURI (uri ).
134+ SetServerSelectionTimeout (1 * time .Second )
135+
136+ // Run basic connectivity test.
137+ err := runTest (context .Background (), clientOpts )
138+ if tc .wantErrCode != "" {
139+ var cmdErr * mongo.CommandError
140+ if errors .As (err , & cmdErr ) {
141+ assert .Equal (t , cmdErr .Code , tc .wantErrCode )
142+ } else {
143+ t .Fatalf ("expected error to be a command error, got: %v" , err )
144+ }
145+
146+ return
147+ }
148+ require .NoError (t , err , "error running test with TLS" )
149+
150+ orig := clientOpts .TLSConfig
151+ if orig == nil {
152+ orig = & tls.Config {}
153+ }
154+
155+ insecure := orig .Clone ()
156+ insecure .InsecureSkipVerify = true
157+
158+ // Run the connectivity test with InsecureSkipVerify to ensure SNI is done
159+ // correctly even if verification is disabled.
160+ insecureClientOpts := options .Client ().
161+ ApplyURI (uri ).
162+ SetServerSelectionTimeout (1 * time .Second ).
163+ SetTLSConfig (insecure )
164+
165+ err = runTest (context .Background (), insecureClientOpts )
166+ require .NoError (t , err , "error running test with tlsInsecure" )
167+ })
168+ }
61169}
62170
63171func runTest (ctx context.Context , clientOpts * options.ClientOptions ) error {
@@ -83,3 +191,51 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
83191 }
84192 return nil
85193}
194+
195+ func createAtlasX509DevCertKeyFile (t * testing.T ) string {
196+ t .Helper ()
197+
198+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_BASE64" )
199+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_BASE64 is not set" )
200+
201+ certBytes , err := base64 .StdEncoding .DecodeString (b64 )
202+ require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_BASE64" )
203+
204+ certFilePath := t .TempDir () + "/atlas_x509_dev_cert.pem"
205+
206+ err = os .WriteFile (certFilePath , certBytes , 0600 )
207+ require .NoError (t , err , "failed to write ATLAS_X509_DEV_CERT_BASE64 to file" )
208+
209+ return certFilePath
210+ }
211+
212+ func createAtlasX509DevCertKeyFileNoUser (t * testing.T ) string {
213+ t .Helper ()
214+
215+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
216+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_NOUSER_BASE64 is not set" )
217+
218+ keyBytes , err := base64 .StdEncoding .DecodeString (b64 )
219+ require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
220+
221+ keyFilePath := t .TempDir () + "/atlas_x509_dev_cert_no_user.pem"
222+
223+ err = os .WriteFile (keyFilePath , keyBytes , 0600 )
224+ require .NoError (t , err , "failed to write ATLAS_X509_DEV_CERT_NOUSER_BASE64 to file" )
225+
226+ return keyFilePath
227+ }
228+
229+ func addTLSCertKeyFile (t * testing.T , certKeyFile , uri string ) string {
230+ t .Helper ()
231+
232+ u , err := url .Parse (uri )
233+ require .NoError (t , err , "failed to parse uri" )
234+
235+ q := u .Query ()
236+ q .Set ("tlsCertificateKeyFile" , filepath .ToSlash (certKeyFile ))
237+
238+ u .RawQuery = q .Encode ()
239+
240+ return u .String ()
241+ }
0 commit comments