@@ -8,56 +8,161 @@ package main
88
99import (
1010 "context"
11+ "crypto/tls"
12+ "encoding/base64"
1113 "errors"
12- "flag"
1314 "fmt"
15+ "net/url"
1416 "os"
17+ "path/filepath"
1518 "testing"
1619 "time"
1720
1821 "go.mongodb.org/mongo-driver/v2/bson"
22+ "go.mongodb.org/mongo-driver/v2/internal/assert"
1923 "go.mongodb.org/mongo-driver/v2/internal/handshake"
24+ "go.mongodb.org/mongo-driver/v2/internal/require"
2025 "go.mongodb.org/mongo-driver/v2/mongo"
2126 "go.mongodb.org/mongo-driver/v2/mongo/options"
2227)
2328
24- func TestMain (m * testing.M ) {
25- flag .Parse ()
26- os .Exit (m .Run ())
27- }
28-
2929func TestAtlas (t * testing.T ) {
30- uris := flag .Args ()
31- ctx := context .Background ()
32-
33- t .Logf ("Running atlas tests for %d uris\n " , len (uris ))
34-
35- for idx , uri := range uris {
36- t .Logf ("Running test %d\n " , idx )
37-
38- // Set a low server selection timeout so we fail fast if there are errors.
39- clientOpts := options .Client ().
40- ApplyURI (uri ).
41- SetServerSelectionTimeout (1 * time .Second )
42-
43- // Run basic connectivity test.
44- if err := runTest (ctx , clientOpts ); err != nil {
45- t .Fatalf ("error running test with TLS at index %d: %v" , idx , err )
46- }
47-
48- tlsConfigSkipVerify := clientOpts .TLSConfig
49- tlsConfigSkipVerify .InsecureSkipVerify = true
50-
51- // Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
52- // disabled.
53- clientOpts .SetTLSConfig (tlsConfigSkipVerify )
54-
55- if err := runTest (ctx , clientOpts ); err != nil {
56- t .Fatalf ("error running test with tlsInsecure at index %d: %v" , idx , err )
57- }
30+ cases := []struct {
31+ name string
32+ envVar string
33+ certKeyFile string
34+ wantErr string
35+ }{
36+ {
37+ name : "Atlas with TLS" ,
38+ envVar : "ATLAS_REPL" ,
39+ certKeyFile : "" ,
40+ wantErr : "" ,
41+ },
42+ {
43+ name : "Atlas with TLS and shared cluster" ,
44+ envVar : "ATLAS_SHRD" ,
45+ certKeyFile : "" ,
46+ wantErr : "" ,
47+ },
48+ {
49+ name : "Atlas with free tier" ,
50+ envVar : "ATLAS_FREE" ,
51+ certKeyFile : "" ,
52+ wantErr : "" ,
53+ },
54+ {
55+ name : "Atlas with TLS 1.1" ,
56+ envVar : "ATLAS_TLS11" ,
57+ certKeyFile : "" ,
58+ wantErr : "" ,
59+ },
60+ {
61+ name : "Atlas with TLS 1.2" ,
62+ envVar : "ATLAS_TLS12" ,
63+ certKeyFile : "" ,
64+ wantErr : "" ,
65+ },
66+ {
67+ name : "Atlas with serverless" ,
68+ envVar : "ATLAS_SERVERLESS" ,
69+ certKeyFile : "" ,
70+ wantErr : "" ,
71+ },
72+ {
73+ name : "Atlas with srv file on replica set" ,
74+ envVar : "ATLAS_SRV_REPL" ,
75+ certKeyFile : "" ,
76+ wantErr : "" ,
77+ },
78+ {
79+ name : "Atlas with srv file on shared cluster" ,
80+ envVar : "ATLAS_SRV_SHRD" ,
81+ certKeyFile : "" ,
82+ wantErr : "" ,
83+ },
84+ {
85+ name : "Atlas with srv file on free tier" ,
86+ envVar : "ATLAS_SRV_FREE" ,
87+ certKeyFile : "" ,
88+ wantErr : "" ,
89+ },
90+ {
91+ name : "Atlas with srv file on TLS 1.1" ,
92+ envVar : "ATLAS_SRV_TLS11" ,
93+ certKeyFile : "" ,
94+ wantErr : "" ,
95+ },
96+ {
97+ name : "Atlas with srv file on TLS 1.2" ,
98+ envVar : "ATLAS_SRV_TLS12" ,
99+ certKeyFile : "" ,
100+ wantErr : "" ,
101+ },
102+ {
103+ name : "Atlas with srv file on serverless" ,
104+ envVar : "ATLAS_SRV_SERVERLESS" ,
105+ certKeyFile : "" ,
106+ wantErr : "" ,
107+ },
108+ {
109+ name : "Atlas with X509 Dev" ,
110+ envVar : "ATLAS_X509_DEV" ,
111+ certKeyFile : createAtlasX509DevCertKeyFile (t ),
112+ wantErr : "" ,
113+ },
114+ {
115+ name : "Atlas with X509 Dev no user" ,
116+ envVar : "ATLAS_X509_DEV" ,
117+ certKeyFile : createAtlasX509DevCertKeyFileNoUser (t ),
118+ wantErr : "UserNotFound" ,
119+ },
58120 }
59121
60- t .Logf ("Finished!" )
122+ for _ , tc := range cases {
123+ t .Run (fmt .Sprintf ("%s (%s)" , tc .name , tc .envVar ), func (t * testing.T ) {
124+ uri := os .Getenv (tc .envVar )
125+ if uri == "" {
126+ t .Skipf ("Environment variable %q is not set" , tc .envVar )
127+ }
128+
129+ if tc .certKeyFile != "" {
130+ uri = addTLSCertKeyFile (t , tc .certKeyFile , uri )
131+ }
132+
133+ // Set a low server selection timeout so we fail fast if there are errors.
134+ clientOpts := options .Client ().
135+ ApplyURI (uri ).
136+ SetServerSelectionTimeout (1 * time .Second )
137+
138+ // Run basic connectivity test.
139+ err := runTest (context .Background (), clientOpts )
140+ if tc .wantErr != "" {
141+ assert .ErrorContains (t , err , tc .wantErr , "expected error to contain %q" , tc .wantErr )
142+
143+ return
144+ }
145+ require .NoError (t , err , "error running test with TLS" )
146+
147+ orig := clientOpts .TLSConfig
148+ if orig == nil {
149+ orig = & tls.Config {}
150+ }
151+
152+ insecure := orig .Clone ()
153+ insecure .InsecureSkipVerify = true
154+
155+ // Run the connectivity test with InsecureSkipVerify to ensure SNI is done
156+ // correctly even if verification is disabled.
157+ insecureClientOpts := options .Client ().
158+ ApplyURI (uri ).
159+ SetServerSelectionTimeout (1 * time .Second ).
160+ SetTLSConfig (insecure )
161+
162+ err = runTest (context .Background (), insecureClientOpts )
163+ require .NoError (t , err , "error running test with tlsInsecure" )
164+ })
165+ }
61166}
62167
63168func runTest (ctx context.Context , clientOpts * options.ClientOptions ) error {
@@ -83,3 +188,51 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
83188 }
84189 return nil
85190}
191+
192+ func createAtlasX509DevCertKeyFile (t * testing.T ) string {
193+ t .Helper ()
194+
195+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_BASE64" )
196+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_BASE64 is not set" )
197+
198+ certBytes , err := base64 .StdEncoding .DecodeString (b64 )
199+ require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_BASE64" )
200+
201+ certFilePath := t .TempDir () + "/atlas_x509_dev_cert.pem"
202+
203+ err = os .WriteFile (certFilePath , certBytes , 0600 )
204+ require .NoError (t , err , "failed to write ATLAS_X509_DEV_CERT_BASE64 to file" )
205+
206+ return certFilePath
207+ }
208+
209+ func createAtlasX509DevCertKeyFileNoUser (t * testing.T ) string {
210+ t .Helper ()
211+
212+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
213+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_NOUSER_BASE64 is not set" )
214+
215+ keyBytes , err := base64 .StdEncoding .DecodeString (b64 )
216+ require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
217+
218+ keyFilePath := t .TempDir () + "/atlas_x509_dev_cert_no_user.pem"
219+
220+ err = os .WriteFile (keyFilePath , keyBytes , 0600 )
221+ require .NoError (t , err , "failed to write ATLAS_X509_DEV_CERT_NOUSER_BASE64 to file" )
222+
223+ return keyFilePath
224+ }
225+
226+ func addTLSCertKeyFile (t * testing.T , certKeyFile , uri string ) string {
227+ t .Helper ()
228+
229+ u , err := url .Parse (uri )
230+ require .NoError (t , err , "failed to parse uri" )
231+
232+ q := u .Query ()
233+ q .Set ("tlsCertificateKeyFile" , filepath .ToSlash (certKeyFile ))
234+
235+ u .RawQuery = q .Encode ()
236+
237+ return u .String ()
238+ }
0 commit comments