Skip to content

Commit de1513c

Browse files
Clean up code
1 parent 417a0ad commit de1513c

File tree

1 file changed

+46
-25
lines changed

1 file changed

+46
-25
lines changed

internal/cmd/testatlas/atlas_test.go

Lines changed: 46 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,20 @@ package main
88

99
import (
1010
"context"
11+
"crypto/tls"
1112
"encoding/base64"
1213
"errors"
1314
"fmt"
1415
"net/url"
1516
"os"
17+
"path/filepath"
1618
"testing"
1719
"time"
1820

1921
"github.com/stretchr/testify/assert"
22+
"github.com/stretchr/testify/require"
2023
"go.mongodb.org/mongo-driver/v2/bson"
2124
"go.mongodb.org/mongo-driver/v2/internal/handshake"
22-
"go.mongodb.org/mongo-driver/v2/internal/require"
2325
"go.mongodb.org/mongo-driver/v2/mongo"
2426
"go.mongodb.org/mongo-driver/v2/mongo/options"
2527
)
@@ -29,91 +31,91 @@ func TestAtlas(t *testing.T) {
2931
name string
3032
envVar string
3133
certKeyFile string
32-
wantErr string
34+
wantErrCode string
3335
}{
3436
{
3537
name: "Atlas with TLS",
3638
envVar: "ATLAS_REPL",
3739
certKeyFile: "",
38-
wantErr: "",
40+
wantErrCode: "",
3941
},
4042
{
4143
name: "Atlas with TLS and shared cluster",
4244
envVar: "ATLAS_SHRD",
4345
certKeyFile: "",
44-
wantErr: "",
46+
wantErrCode: "",
4547
},
4648
{
4749
name: "Atlas with free tier",
4850
envVar: "ATLAS_FREE",
4951
certKeyFile: "",
50-
wantErr: "",
52+
wantErrCode: "",
5153
},
5254
{
5355
name: "Atlas with TLS 1.1",
5456
envVar: "ATLAS_TLS11",
5557
certKeyFile: "",
56-
wantErr: "",
58+
wantErrCode: "",
5759
},
5860
{
5961
name: "Atlas with TLS 1.2",
6062
envVar: "ATLAS_TLS12",
6163
certKeyFile: "",
62-
wantErr: "",
64+
wantErrCode: "",
6365
},
6466
{
6567
name: "Atlas with serverless",
6668
envVar: "ATLAS_SERVERLESS",
6769
certKeyFile: "",
68-
wantErr: "",
70+
wantErrCode: "",
6971
},
7072
{
7173
name: "Atlas with srv file on replica set",
7274
envVar: "ATLAS_SRV_REPL",
7375
certKeyFile: "",
74-
wantErr: "",
76+
wantErrCode: "",
7577
},
7678
{
7779
name: "Atlas with srv file on shared cluster",
7880
envVar: "ATLAS_SRV_SHRD",
7981
certKeyFile: "",
80-
wantErr: "",
82+
wantErrCode: "",
8183
},
8284
{
8385
name: "Atlas with srv file on free tier",
8486
envVar: "ATLAS_SRV_FREE",
8587
certKeyFile: "",
86-
wantErr: "",
88+
wantErrCode: "",
8789
},
8890
{
8991
name: "Atlas with srv file on TLS 1.1",
9092
envVar: "ATLAS_SRV_TLS11",
9193
certKeyFile: "",
92-
wantErr: "",
94+
wantErrCode: "",
9395
},
9496
{
9597
name: "Atlas with srv file on TLS 1.2",
9698
envVar: "ATLAS_SRV_TLS12",
9799
certKeyFile: "",
98-
wantErr: "",
100+
wantErrCode: "",
99101
},
100102
{
101103
name: "Atlas with srv file on serverless",
102104
envVar: "ATLAS_SRV_SERVERLESS",
103105
certKeyFile: "",
104-
wantErr: "",
106+
wantErrCode: "",
105107
},
106108
{
107109
name: "Atlas with X509 Dev",
108110
envVar: "ATLAS_X509_DEV",
109111
certKeyFile: createAtlasX509DevCertKeyFile(t),
110-
wantErr: "",
112+
wantErrCode: "",
111113
},
112114
{
113115
name: "Atlas with X509 Dev no user",
114116
envVar: "ATLAS_X509_DEV",
115117
certKeyFile: createAtlasX509DevCertKeyFileNoUser(t),
116-
wantErr: "(UserNotFound) Could not find user",
118+
wantErrCode: "11", // UserNotFound
117119
},
118120
}
119121

@@ -133,21 +135,34 @@ func TestAtlas(t *testing.T) {
133135

134136
// Run basic connectivity test.
135137
err := runTest(context.Background(), clientOpts)
136-
if tc.wantErr != "" {
137-
assert.ErrorContains(t, err, tc.wantErr, "error running test with TLS")
138+
if tc.wantErrCode != "" {
139+
var cmdErr *mongo.CommandError
140+
if errors.As(err, &cmdErr) {
141+
assert.Equal(t, cmdErr.Code, tc.wantErrCode)
142+
} else {
143+
t.Fatalf("expected error to be a command error, got: %v", err)
144+
}
138145

139146
return
140147
}
141148
require.NoError(t, err, "error running test with TLS")
142149

143-
tlsConfigSkipVerify := clientOpts.TLSConfig
144-
tlsConfigSkipVerify.InsecureSkipVerify = true
150+
orig := clientOpts.TLSConfig
151+
if orig == nil {
152+
orig = &tls.Config{}
153+
}
154+
155+
insecure := orig.Clone()
156+
insecure.InsecureSkipVerify = true
145157

146158
// Run the connectivity test with InsecureSkipVerify to ensure SNI is done
147159
// correctly even if verification is disabled.
148-
clientOpts.SetTLSConfig(tlsConfigSkipVerify)
160+
insecureClientOpts := options.Client().
161+
ApplyURI(uri).
162+
SetServerSelectionTimeout(1 * time.Second).
163+
SetTLSConfig(insecure)
149164

150-
err = runTest(context.Background(), clientOpts)
165+
err = runTest(context.Background(), insecureClientOpts)
151166
require.NoError(t, err, "error running test with tlsInsecure")
152167
})
153168
}
@@ -180,7 +195,10 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error {
180195
func createAtlasX509DevCertKeyFile(t *testing.T) string {
181196
t.Helper()
182197

183-
certBytes, err := base64.StdEncoding.DecodeString(os.Getenv("ATLAS_X509_DEV_CERT_BASE64"))
198+
b64 := os.Getenv("ATLAS_X509_DEV_CERT_BASE64")
199+
assert.NotEmpty(t, b64, "Environment variable ATLAS_X509_DEV_CERT_BASE64 is not set")
200+
201+
certBytes, err := base64.StdEncoding.DecodeString(b64)
184202
require.NoError(t, err, "failed to decode ATLAS_X509_DEV_CERT_BASE64")
185203

186204
certFilePath := t.TempDir() + "/atlas_x509_dev_cert.pem"
@@ -194,7 +212,10 @@ func createAtlasX509DevCertKeyFile(t *testing.T) string {
194212
func createAtlasX509DevCertKeyFileNoUser(t *testing.T) string {
195213
t.Helper()
196214

197-
keyBytes, err := base64.StdEncoding.DecodeString(os.Getenv("ATLAS_X509_DEV_CERT_NOUSER_BASE64"))
215+
b64 := os.Getenv("ATLAS_X509_DEV_CERT_NOUSER_BASE64")
216+
assert.NotEmpty(t, b64, "Environment variable ATLAS_X509_DEV_CERT_NOUSER_BASE64 is not set")
217+
218+
keyBytes, err := base64.StdEncoding.DecodeString(b64)
198219
require.NoError(t, err, "failed to decode ATLAS_X509_DEV_CERT_NOUSER_BASE64")
199220

200221
keyFilePath := t.TempDir() + "/atlas_x509_dev_cert_no_user.pem"
@@ -212,7 +233,7 @@ func addTLSCertKeyFile(t *testing.T, certKeyFile, uri string) string {
212233
require.NoError(t, err, "failed to parse uri")
213234

214235
q := u.Query()
215-
q.Set("tlsCertificateKeyFile", certKeyFile)
236+
q.Set("tlsCertificateKeyFile", filepath.ToSlash(certKeyFile))
216237

217238
u.RawQuery = q.Encode()
218239

0 commit comments

Comments
 (0)