diff --git a/Taskfile.yml b/Taskfile.yml index a4c6f405bf..411a09366c 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -10,7 +10,7 @@ dotenv: ['.test.env'] tasks: ### Utility tasks. ### - default: + default: deps: [build, check-license, check-fmt, check-modules, lint, test-short] add-license: bash etc/check_license.sh -a @@ -36,7 +36,7 @@ tasks: build-aws-ecs-test: go test -c ./internal/test/aws -o aws.testbin - cross-compile: + cross-compile: - GOOS=linux GOARCH=386 go build ./... - GOOS=linux GOARCH=arm go build ./... - GOOS=linux GOARCH=arm64 go build ./... @@ -44,7 +44,7 @@ tasks: - GOOS=linux GOARCH=ppc64le go build ./... - GOOS=linux GOARCH=s390x go build ./... - check-fmt: + check-fmt: deps: [install-lll] cmds: - bash etc/check_fmt.sh @@ -57,9 +57,9 @@ tasks: api-report: bash etc/api_report.sh - install-libmongocrypt: + install-libmongocrypt: cmds: [bash etc/install-libmongocrypt.sh] - status: + status: - test -d install || test -d /cygdrive/c/libmongocrypt/bin run-docker: bash etc/run_docker.sh @@ -76,7 +76,7 @@ tasks: # specific operating systems or architectures. For example, staticcheck will only check for 64-bit # alignment of atomically accessed variables on 32-bit architectures (see # https://staticcheck.io/docs/checks#SA1027) - lint: + lint: cmds: - GOOS=linux GOARCH=386 etc/golangci-lint.sh - GOOS=linux GOARCH=arm etc/golangci-lint.sh @@ -104,8 +104,8 @@ tasks: test-oidc-remote: bash etc/run-oidc-remote-test.sh - test-atlas-connect: - - go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/v2/internal/cmd/testatlas -args "$ATLAS_REPL" "$ATLAS_SHRD" "$ATLAS_FREE" "$ATLAS_TLS11" "$ATLAS_TLS12" "$ATLAS_SERVERLESS" "$ATLAS_SRV_REPL" "$ATLAS_SRV_SHRD" "$ATLAS_SRV_FREE" "$ATLAS_SRV_TLS11" "$ATLAS_SRV_TLS12" "$ATLAS_SRV_SERVERLESS" >> test.suite + test-atlas-connect: + - go test -v -run ^TestAtlas$ go.mongodb.org/mongo-driver/v2/internal/cmd/testatlas -tags atlastest >> test.suite test-awskms: bash etc/run-awskms-test.sh @@ -117,9 +117,9 @@ tasks: ### Local FaaS tasks. ### build-faas-awslambda: - requires: + requires: vars: [MONGODB_URI] - cmds: + cmds: - make -c internal/cmd/faas/awslambda ### Evergreen specific tasks. ### @@ -134,7 +134,7 @@ tasks: - ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./internal/integration/unified -run TestUnifiedSpec/atlas-data-lake-testing >> spec_test.suite - ATLAS_DATA_LAKE_INTEGRATION_TEST=true go test -v ./internal/integration -run TestAtlasDataLake >> spec_test.suite - evg-test-enterprise-auth: + evg-test-enterprise-auth: - go run -tags gssapi ./internal/cmd/testentauth/main.go evg-test-oidc-auth: @@ -188,15 +188,15 @@ tasks: ### Benchmark specific tasks and support. ### benchmark: deps: [perf-files] - cmds: + cmds: - go test ${BUILD_TAGS} -benchmem -bench=. ./benchmark | test benchmark.suite - driver-benchmark: - cmds: + driver-benchmark: + cmds: - go test ./internal/cmd/benchmark -v --fullRun | tee perf.suite ### Internal tasks. ### - install-lll: + install-lll: internal: true cmds: - go install github.com/walle/lll/...@latest diff --git a/internal/assert/assertions.go b/internal/assert/assertions.go index c227d47c83..0754a411a1 100644 --- a/internal/assert/assertions.go +++ b/internal/assert/assertions.go @@ -1073,3 +1073,20 @@ func buildErrorChainString(err error) string { } return chain } + +// NotEmpty asserts that the specified object is NOT [Empty]. +// +// if assert.NotEmpty(t, obj) { +// assert.Equal(t, "two", obj[1]) +// } +func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) bool { + pass := !isEmpty(object) + if !pass { + if h, ok := t.(tHelper); ok { + h.Helper() + } + Fail(t, fmt.Sprintf("Should NOT be empty, but was %v", object), msgAndArgs...) + } + + return pass +} diff --git a/internal/cmd/testatlas/atlas_test.go b/internal/cmd/testatlas/atlas_test.go index cf4a84735f..74707276b0 100644 --- a/internal/cmd/testatlas/atlas_test.go +++ b/internal/cmd/testatlas/atlas_test.go @@ -4,60 +4,166 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 +//go:build atlastest +// +build atlastest + package main import ( "context" + "crypto/tls" + "encoding/base64" "errors" - "flag" "fmt" + "net/url" "os" + "path/filepath" "testing" "time" "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/handshake" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" ) -func TestMain(m *testing.M) { - flag.Parse() - os.Exit(m.Run()) -} - func TestAtlas(t *testing.T) { - uris := flag.Args() - ctx := context.Background() - - t.Logf("Running atlas tests for %d uris\n", len(uris)) - - for idx, uri := range uris { - t.Logf("Running test %d\n", idx) - - // Set a low server selection timeout so we fail fast if there are errors. - clientOpts := options.Client(). - ApplyURI(uri). - SetServerSelectionTimeout(1 * time.Second) - - // Run basic connectivity test. - if err := runTest(ctx, clientOpts); err != nil { - t.Fatalf("error running test with TLS at index %d: %v", idx, err) - } - - tlsConfigSkipVerify := clientOpts.TLSConfig - tlsConfigSkipVerify.InsecureSkipVerify = true - - // Run the connectivity test with InsecureSkipVerify to ensure SNI is done correctly even if verification is - // disabled. - clientOpts.SetTLSConfig(tlsConfigSkipVerify) - - if err := runTest(ctx, clientOpts); err != nil { - t.Fatalf("error running test with tlsInsecure at index %d: %v", idx, err) - } + cases := []struct { + name string + envVar string + certKeyFile string + wantErr string + }{ + { + name: "Atlas with TLS", + envVar: "ATLAS_REPL", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with TLS and shared cluster", + envVar: "ATLAS_SHRD", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with free tier", + envVar: "ATLAS_FREE", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with TLS 1.1", + envVar: "ATLAS_TLS11", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with TLS 1.2", + envVar: "ATLAS_TLS12", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with serverless", + envVar: "ATLAS_SERVERLESS", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with srv file on replica set", + envVar: "ATLAS_SRV_REPL", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with srv file on shared cluster", + envVar: "ATLAS_SRV_SHRD", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with srv file on free tier", + envVar: "ATLAS_SRV_FREE", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with srv file on TLS 1.1", + envVar: "ATLAS_SRV_TLS11", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with srv file on TLS 1.2", + envVar: "ATLAS_SRV_TLS12", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with srv file on serverless", + envVar: "ATLAS_SRV_SERVERLESS", + certKeyFile: "", + wantErr: "", + }, + { + name: "Atlas with X509 Dev", + envVar: "ATLAS_X509_DEV", + certKeyFile: createAtlasX509DevCertKeyFile(t), + wantErr: "", + }, + { + name: "Atlas with X509 Dev no user", + envVar: "ATLAS_X509_DEV", + certKeyFile: createAtlasX509DevCertKeyFileNoUser(t), + wantErr: "UserNotFound", + }, } - t.Logf("Finished!") + for _, tc := range cases { + t.Run(fmt.Sprintf("%s (%s)", tc.name, tc.envVar), func(t *testing.T) { + uri := os.Getenv(tc.envVar) + require.NotEmpty(t, uri, "Environment variable %s is not set", tc.envVar) + + if tc.certKeyFile != "" { + uri = addTLSCertKeyFile(t, tc.certKeyFile, uri) + } + + // Set a low server selection timeout so we fail fast if there are errors. + clientOpts := options.Client(). + ApplyURI(uri). + SetServerSelectionTimeout(1 * time.Second) + + // Run basic connectivity test. + err := runTest(context.Background(), clientOpts) + if tc.wantErr != "" { + assert.ErrorContains(t, err, tc.wantErr, "expected error to contain %q", tc.wantErr) + + return + } + require.NoError(t, err, "error running test with TLS") + + orig := clientOpts.TLSConfig + if orig == nil { + orig = &tls.Config{} + } + + insecure := orig.Clone() + insecure.InsecureSkipVerify = true + + // Run the connectivity test with InsecureSkipVerify to ensure SNI is done + // correctly even if verification is disabled. + insecureClientOpts := options.Client(). + ApplyURI(uri). + SetServerSelectionTimeout(1 * time.Second). + SetTLSConfig(insecure) + + err = runTest(context.Background(), insecureClientOpts) + require.NoError(t, err, "error running test with tlsInsecure") + }) + } } func runTest(ctx context.Context, clientOpts *options.ClientOptions) error { @@ -83,3 +189,51 @@ func runTest(ctx context.Context, clientOpts *options.ClientOptions) error { } return nil } + +func createAtlasX509DevCertKeyFile(t *testing.T) string { + t.Helper() + + b64 := os.Getenv("ATLAS_X509_DEV_CERT_BASE64") + assert.NotEmpty(t, b64, "Environment variable ATLAS_X509_DEV_CERT_BASE64 is not set") + + certBytes, err := base64.StdEncoding.DecodeString(b64) + require.NoError(t, err, "failed to decode ATLAS_X509_DEV_CERT_BASE64") + + certFilePath := filepath.Join(t.TempDir(), "atlas_x509_dev_cert.pem") + + err = os.WriteFile(certFilePath, certBytes, 0600) + require.NoError(t, err, "failed to write ATLAS_X509_DEV_CERT_BASE64 to file") + + return certFilePath +} + +func createAtlasX509DevCertKeyFileNoUser(t *testing.T) string { + t.Helper() + + b64 := os.Getenv("ATLAS_X509_DEV_CERT_NOUSER_BASE64") + assert.NotEmpty(t, b64, "Environment variable ATLAS_X509_DEV_CERT_NOUSER_BASE64 is not set") + + keyBytes, err := base64.StdEncoding.DecodeString(b64) + require.NoError(t, err, "failed to decode ATLAS_X509_DEV_CERT_NOUSER_BASE64") + + keyFilePath := filepath.Join(t.TempDir(), "atlas_x509_dev_cert_no_user.pem") + + err = os.WriteFile(keyFilePath, keyBytes, 0600) + require.NoError(t, err, "failed to write ATLAS_X509_DEV_CERT_NOUSER_BASE64 to file") + + return keyFilePath +} + +func addTLSCertKeyFile(t *testing.T, certKeyFile, uri string) string { + t.Helper() + + u, err := url.Parse(uri) + require.NoError(t, err, "failed to parse uri") + + q := u.Query() + q.Set("tlsCertificateKeyFile", filepath.ToSlash(certKeyFile)) + + u.RawQuery = q.Encode() + + return u.String() +} diff --git a/internal/require/require.go b/internal/require/require.go index 26d1885759..0b60613e3e 100644 --- a/internal/require/require.go +++ b/internal/require/require.go @@ -817,3 +817,18 @@ func WithinDurationf(t TestingT, expected time.Time, actual time.Time, delta tim } t.FailNow() } + +// NotEmpty asserts that the specified object is NOT [Empty]. +// +// if require.NotEmpty(t, obj) { +// require.Equal(t, "two", obj[1]) +// } +func NotEmpty(t TestingT, object interface{}, msgAndArgs ...interface{}) { + if h, ok := t.(tHelper); ok { + h.Helper() + } + if assert.NotEmpty(t, object, msgAndArgs...) { + return + } + t.FailNow() +}