@@ -8,56 +8,159 @@ package main
88
99import (
1010 "context"
11+ "crypto/tls"
12+ "encoding/base64"
1113 "errors"
12- "flag"
1314 "fmt"
15+ "net/url"
1416 "os"
17+ "path/filepath"
1518 "testing"
1619 "time"
1720
21+ "github.com/stretchr/testify/assert"
22+ "github.com/stretchr/testify/require"
1823 "go.mongodb.org/mongo-driver/v2/bson"
1924 "go.mongodb.org/mongo-driver/v2/internal/handshake"
2025 "go.mongodb.org/mongo-driver/v2/mongo"
2126 "go.mongodb.org/mongo-driver/v2/mongo/options"
2227)
2328
24- func TestMain (m * testing.M ) {
25- flag .Parse ()
26- os .Exit (m .Run ())
27- }
28-
2929func TestAtlas (t * testing.T ) {
30- uris := flag .Args ()
31- ctx := context .Background ()
32-
33- t .Logf ("Running atlas tests for %d uris\n " , len (uris ))
34-
35- for idx , uri := range uris {
36- t .Logf ("Running test %d\n " , idx )
37-
38- // Set a low server selection timeout so we fail fast if there are errors.
39- clientOpts := options .Client ().
40- ApplyURI (uri ).
41- SetServerSelectionTimeout (1 * time .Second )
42-
43- // Run basic connectivity test.
44- if err := runTest (ctx , clientOpts ); err != nil {
45- t .Fatalf ("error running test with TLS at index %d: %v" , idx , err )
46- }
47-
48- tlsConfigSkipVerify := clientOpts .TLSConfig
49- tlsConfigSkipVerify .InsecureSkipVerify = true
50-
51- // Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is
52- // disabled.
53- clientOpts .SetTLSConfig (tlsConfigSkipVerify )
54-
55- if err := runTest (ctx , clientOpts ); err != nil {
56- t .Fatalf ("error running test with tlsInsecure at index %d: %v" , idx , err )
57- }
30+ cases := []struct {
31+ name string
32+ envVar string
33+ certKeyFile string
34+ wantErr string
35+ }{
36+ {
37+ name : "Atlas with TLS" ,
38+ envVar : "ATLAS_REPL" ,
39+ certKeyFile : "" ,
40+ wantErr : "" ,
41+ },
42+ {
43+ name : "Atlas with TLS and shared cluster" ,
44+ envVar : "ATLAS_SHRD" ,
45+ certKeyFile : "" ,
46+ wantErr : "" ,
47+ },
48+ {
49+ name : "Atlas with free tier" ,
50+ envVar : "ATLAS_FREE" ,
51+ certKeyFile : "" ,
52+ wantErr : "" ,
53+ },
54+ {
55+ name : "Atlas with TLS 1.1" ,
56+ envVar : "ATLAS_TLS11" ,
57+ certKeyFile : "" ,
58+ wantErr : "" ,
59+ },
60+ {
61+ name : "Atlas with TLS 1.2" ,
62+ envVar : "ATLAS_TLS12" ,
63+ certKeyFile : "" ,
64+ wantErr : "" ,
65+ },
66+ {
67+ name : "Atlas with serverless" ,
68+ envVar : "ATLAS_SERVERLESS" ,
69+ certKeyFile : "" ,
70+ wantErr : "" ,
71+ },
72+ {
73+ name : "Atlas with srv file on replica set" ,
74+ envVar : "ATLAS_SRV_REPL" ,
75+ certKeyFile : "" ,
76+ wantErr : "" ,
77+ },
78+ {
79+ name : "Atlas with srv file on shared cluster" ,
80+ envVar : "ATLAS_SRV_SHRD" ,
81+ certKeyFile : "" ,
82+ wantErr : "" ,
83+ },
84+ {
85+ name : "Atlas with srv file on free tier" ,
86+ envVar : "ATLAS_SRV_FREE" ,
87+ certKeyFile : "" ,
88+ wantErr : "" ,
89+ },
90+ {
91+ name : "Atlas with srv file on TLS 1.1" ,
92+ envVar : "ATLAS_SRV_TLS11" ,
93+ certKeyFile : "" ,
94+ wantErr : "" ,
95+ },
96+ {
97+ name : "Atlas with srv file on TLS 1.2" ,
98+ envVar : "ATLAS_SRV_TLS12" ,
99+ certKeyFile : "" ,
100+ wantErr : "" ,
101+ },
102+ {
103+ name : "Atlas with srv file on serverless" ,
104+ envVar : "ATLAS_SRV_SERVERLESS" ,
105+ certKeyFile : "" ,
106+ wantErr : "" ,
107+ },
108+ {
109+ name : "Atlas with X509 Dev" ,
110+ envVar : "ATLAS_X509_DEV" ,
111+ certKeyFile : createAtlasX509DevCertKeyFile (t ),
112+ wantErr : "" ,
113+ },
114+ {
115+ name : "Atlas with X509 Dev no user" ,
116+ envVar : "ATLAS_X509_DEV" ,
117+ certKeyFile : createAtlasX509DevCertKeyFileNoUser (t ),
118+ wantErr : "UserNotFound" ,
119+ },
58120 }
59121
60- t .Logf ("Finished!" )
122+ for _ , tc := range cases {
123+ t .Run (fmt .Sprintf ("%s (%s)" , tc .name , tc .envVar ), func (t * testing.T ) {
124+ uri := os .Getenv (tc .envVar )
125+ assert .NotEmpty (t , uri , fmt .Sprintf ("Environment variable %s is not set" , tc .envVar ))
126+
127+ if tc .certKeyFile != "" {
128+ uri = addTLSCertKeyFile (t , tc .certKeyFile , uri )
129+ }
130+
131+ // Set a low server selection timeout so we fail fast if there are errors.
132+ clientOpts := options .Client ().
133+ ApplyURI (uri ).
134+ SetServerSelectionTimeout (1 * time .Second )
135+
136+ // Run basic connectivity test.
137+ err := runTest (context .Background (), clientOpts )
138+ if tc .wantErr != "" {
139+ assert .ErrorContains (t , err , tc .wantErr , "expected error to contain %q" , tc .wantErr )
140+
141+ return
142+ }
143+ require .NoError (t , err , "error running test with TLS" )
144+
145+ orig := clientOpts .TLSConfig
146+ if orig == nil {
147+ orig = & tls.Config {}
148+ }
149+
150+ insecure := orig .Clone ()
151+ insecure .InsecureSkipVerify = true
152+
153+ // Run the connectivity test with InsecureSkipVerify to ensure SNI is done
154+ // correctly even if verification is disabled.
155+ insecureClientOpts := options .Client ().
156+ ApplyURI (uri ).
157+ SetServerSelectionTimeout (1 * time .Second ).
158+ SetTLSConfig (insecure )
159+
160+ err = runTest (context .Background (), insecureClientOpts )
161+ require .NoError (t , err , "error running test with tlsInsecure" )
162+ })
163+ }
61164}
62165
63166func runTest (ctx context.Context , clientOpts * options.ClientOptions ) error {
@@ -83,3 +186,51 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
83186 }
84187 return nil
85188}
189+
190+ func createAtlasX509DevCertKeyFile (t * testing.T ) string {
191+ t .Helper ()
192+
193+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_BASE64" )
194+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_BASE64 is not set" )
195+
196+ certBytes , err := base64 .StdEncoding .DecodeString (b64 )
197+ require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_BASE64" )
198+
199+ certFilePath := t .TempDir () + "/atlas_x509_dev_cert.pem"
200+
201+ err = os .WriteFile (certFilePath , certBytes , 0600 )
202+ require .NoError (t , err , "failed to write ATLAS_X509_DEV_CERT_BASE64 to file" )
203+
204+ return certFilePath
205+ }
206+
207+ func createAtlasX509DevCertKeyFileNoUser (t * testing.T ) string {
208+ t .Helper ()
209+
210+ b64 := os .Getenv ("ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
211+ assert .NotEmpty (t , b64 , "Environment variable ATLAS_X509_DEV_CERT_NOUSER_BASE64 is not set" )
212+
213+ keyBytes , err := base64 .StdEncoding .DecodeString (b64 )
214+ require .NoError (t , err , "failed to decode ATLAS_X509_DEV_CERT_NOUSER_BASE64" )
215+
216+ keyFilePath := t .TempDir () + "/atlas_x509_dev_cert_no_user.pem"
217+
218+ err = os .WriteFile (keyFilePath , keyBytes , 0600 )
219+ require .NoError (t , err , "failed to write ATLAS_X509_DEV_CERT_NOUSER_BASE64 to file" )
220+
221+ return keyFilePath
222+ }
223+
224+ func addTLSCertKeyFile (t * testing.T , certKeyFile , uri string ) string {
225+ t .Helper ()
226+
227+ u , err := url .Parse (uri )
228+ require .NoError (t , err , "failed to parse uri" )
229+
230+ q := u .Query ()
231+ q .Set ("tlsCertificateKeyFile" , filepath .ToSlash (certKeyFile ))
232+
233+ u .RawQuery = q .Encode ()
234+
235+ return u .String ()
236+ }
0 commit comments