diff --git a/internal/assert/assertbsoncore/assertions_bsoncore.go b/internal/assert/assertbsoncore/assertions_bsoncore.go new file mode 100644 index 0000000000..872192d922 --- /dev/null +++ b/internal/assert/assertbsoncore/assertions_bsoncore.go @@ -0,0 +1,47 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package assertbsoncore + +import ( + "errors" + "testing" + + "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/handshake" + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" +) + +// HandshakeClientMetadata compares the client metadata in two wire messages. It +// extracts the client metadata document from each wire message and compares +// them. If the document is not found, it assumes the wire message is just the +// value of the client metadata document itself. +func HandshakeClientMetadata(t testing.TB, expectedWM, actualWM []byte) bool { + gotCommand, err := handshake.ParseClientMetadata(actualWM) + if err != nil { + if errors.Is(err, bsoncore.ErrElementNotFound) { + // If the element is not found, the actual wire message may just be the + // client metadata document itself. + gotCommand = bsoncore.Document(actualWM) + } else { + return assert.Fail(t, "error parsing actual wire message: %v", err) + } + } + + wantCommand, err := handshake.ParseClientMetadata(expectedWM) + if err != nil { + // If the element is not found, the expected wire message may just be the + // client metadata document itself. + if errors.Is(err, bsoncore.ErrElementNotFound) { + wantCommand = bsoncore.Document(expectedWM) + } else { + return assert.Fail(t, "error parsing expected wire message: %v", err) + } + } + + return assert.Equal(t, wantCommand, gotCommand, + "expected: %v, got: %v", bsoncore.Document(wantCommand), bsoncore.Document(gotCommand)) +} diff --git a/internal/handshake/handshake.go b/internal/handshake/handshake.go index c9537d3ef8..f66fd8d34f 100644 --- a/internal/handshake/handshake.go +++ b/internal/handshake/handshake.go @@ -6,8 +6,29 @@ package handshake +import ( + "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" +) + // LegacyHello is the legacy version of the hello command. var LegacyHello = "isMaster" // LegacyHelloLowercase is the lowercase, legacy version of the hello command. var LegacyHelloLowercase = "ismaster" + +func ParseClientMetadata(msg []byte) ([]byte, error) { + command := bsoncore.Document(msg) + + // Lookup the "client" field in the command document. + clientMetadataRaw, err := command.LookupErr("client") + if err != nil { + return nil, err + } + + clientMetadata, ok := clientMetadataRaw.DocumentOK() + if !ok { + return nil, err + } + + return clientMetadata, nil +} diff --git a/internal/integration/client_test.go b/internal/integration/client_test.go index 8b37f12b47..fca8e059c3 100644 --- a/internal/integration/client_test.go +++ b/internal/integration/client_test.go @@ -19,12 +19,13 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/assert/assertbsoncore" "go.mongodb.org/mongo-driver/v2/internal/eventtest" "go.mongodb.org/mongo-driver/v2/internal/failpoint" - "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/internal/test" "go.mongodb.org/mongo-driver/v2/mongo" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readpref" @@ -456,26 +457,12 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) assert.Nil(mt, err, "Ping error: %v", err) - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 2, "expected at least 2 events sent, got %v", len(msgPairs)) + want := test.EncodeClientMetadata(mt, test.WithClientMetadataAppName("foo")) + for i := 0; i < 2; i++ { + message := mt.GetProxyCapture().TryNext() + require.NotNil(mt, message, "expected handshake message, got nil") - // First two messages should be connection handshakes: one for the heartbeat connection and the other for the - // application connection. - for idx, pair := range msgPairs[:2] { - helloCommand := handshake.LegacyHello - // Expect "hello" command name with API version. - if os.Getenv("REQUIRE_API_VERSION") == "true" { - helloCommand = "hello" - } - assert.Equal(mt, pair.CommandName, helloCommand, "expected command name %s at index %d, got %s", helloCommand, idx, - pair.CommandName) - - sent := pair.Sent - appNameVal, err := sent.Command.LookupErr("client", "application", "name") - assert.Nil(mt, err, "expected command %s at index %d to contain app name", sent.Command, idx) - appName := appNameVal.StringValue() - assert.Equal(mt, testAppName, appName, "expected app name %v at index %d, got %v", testAppName, idx, - appName) + assertbsoncore.HandshakeClientMetadata(mt, want, message.Sent.Command) } }) @@ -604,24 +591,32 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) assert.Nil(mt, err, "Ping error: %v", err) - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs)) + proxyCapture := mt.GetProxyCapture() // The first message should be a connection handshake. - pair := msgPairs[0] - assert.Equal(mt, handshake.LegacyHello, pair.CommandName, "expected command name %s at index 0, got %s", - handshake.LegacyHello, pair.CommandName) - assert.Equal(mt, wiremessage.OpQuery, pair.Sent.OpCode, - "expected 'OP_QUERY' OpCode in wire message, got %q", pair.Sent.OpCode.String()) - - // Look for a saslContinue in the remaining proxied messages and assert that it uses the OP_MSG OpCode, as wire - // version is now known to be >= 6. + firstMessage := proxyCapture.TryNext() + require.NotNil(mt, firstMessage, "expected handshake message, got nil") + + assert.True(t, firstMessage.IsHandshake()) + + opCode := firstMessage.Sent.OpCode + assert.Equal(mt, wiremessage.OpQuery, opCode, + "expected 'OP_MSG' OpCode in wire message, got %q", opCode.String()) + + // Look for a saslContinue in the remaining proxied messages and assert that + // it uses the OP_MSG OpCode, as wire version is now known to be >= 6. var saslContinueFound bool - for _, pair := range msgPairs[1:] { - if pair.CommandName == "saslContinue" { + for { + message := proxyCapture.TryNext() + if message == nil { + break + } + + if message.CommandName == "saslContinue" { saslContinueFound = true - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %s", pair.Sent.OpCode.String()) + opCode := message.Sent.OpCode + assert.Equal(mt, wiremessage.OpMsg, opCode, + "expected 'OP_MSG' OpCode in wire message, got %q", opCode.String()) break } } @@ -634,18 +629,18 @@ func TestClient(t *testing.T) { err := mt.Client.Ping(context.Background(), mtest.PrimaryRp) assert.Nil(mt, err, "Ping error: %v", err) - msgPairs := mt.GetProxiedMessages() - assert.True(mt, len(msgPairs) >= 3, "expected at least 3 events, got %v", len(msgPairs)) - // First three messages should be connection handshakes: one for the heartbeat connection, another for the // application connection, and a final one for the RTT monitor connection. - for idx, pair := range msgPairs[:3] { - assert.Equal(mt, "hello", pair.CommandName, "expected command name 'hello' at index %d, got %s", idx, - pair.CommandName) + for idx := 0; idx < 3; idx++ { + message := mt.GetProxyCapture().TryNext() + require.NotNil(mt, message, "expected handshake message, got nil") + + assert.True(t, message.IsHandshake()) // Assert that appended OpCode is OP_MSG when API version is set. - assert.Equal(mt, wiremessage.OpMsg, pair.Sent.OpCode, - "expected 'OP_MSG' OpCode in wire message, got %q", pair.Sent.OpCode.String()) + opCode := message.Sent.OpCode + assert.Equal(mt, wiremessage.OpMsg, opCode, + "expected 'OP_MSG' OpCode in wire message, got %q", opCode.String()) } }) diff --git a/internal/integration/handshake_test.go b/internal/integration/handshake_test.go index f4c449e30e..65211ed9a8 100644 --- a/internal/integration/handshake_test.go +++ b/internal/integration/handshake_test.go @@ -9,18 +9,16 @@ package integration import ( "context" "os" - "reflect" - "runtime" "testing" + "time" - "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/internal/assert" - "go.mongodb.org/mongo-driver/v2/internal/handshake" + "go.mongodb.org/mongo-driver/v2/internal/assert/assertbsoncore" "go.mongodb.org/mongo-driver/v2/internal/integration/mtest" + "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/internal/test" "go.mongodb.org/mongo-driver/v2/mongo/options" - "go.mongodb.org/mongo-driver/v2/version" - "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" ) @@ -35,40 +33,6 @@ func TestHandshakeProse(t *testing.T) { CreateCollection(false). ClientType(mtest.Proxy) - clientMetadata := func(env bson.D, info *options.DriverInfo) bson.D { - var ( - driverName = "mongo-go-driver" - driverVersion = version.Driver - platform = runtime.Version() - ) - - if info != nil { - driverName = driverName + "|" + info.Name - driverVersion = driverVersion + "|" + info.Version - platform = platform + "|" + info.Platform - } - - elems := bson.D{ - {Key: "driver", Value: bson.D{ - {Key: "name", Value: driverName}, - {Key: "version", Value: driverVersion}, - }}, - {Key: "os", Value: bson.D{ - {Key: "type", Value: runtime.GOOS}, - {Key: "architecture", Value: runtime.GOARCH}, - }}, - } - - elems = append(elems, bson.E{Key: "platform", Value: platform}) - - // If env is empty, don't include it in the metadata. - if env != nil && !reflect.DeepEqual(env, bson.D{}) { - elems = append(elems, bson.E{Key: "env", Value: env}) - } - - return elems - } - driverInfo := &options.DriverInfo{ Name: "outer-library-name", Version: "outer-library-version", @@ -88,11 +52,11 @@ func TestHandshakeProse(t *testing.T) { t.Setenv("FUNCTION_REGION", "") t.Setenv("VERCEL_REGION", "") - for _, test := range []struct { + testCases := []struct { name string env map[string]string opts *options.ClientOptions - want bson.D + want []byte }{ { name: "1. valid AWS", @@ -102,11 +66,11 @@ func TestHandshakeProse(t *testing.T) { "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "1024", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "aws.lambda"}, - {Key: "memory_mb", Value: 1024}, - {Key: "region", Value: "us-east-2"}, - }, nil), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataEnvName("aws.lambda"), + test.WithClientMetadataEnvMemoryMB(ptrutil.Ptr(1024)), + test.WithClientMetadataEnvRegion("us-east-2"), + ), }, { name: "2. valid Azure", @@ -114,9 +78,9 @@ func TestHandshakeProse(t *testing.T) { "FUNCTIONS_WORKER_RUNTIME": "node", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "azure.func"}, - }, nil), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataEnvName("azure.func"), + ), }, { name: "3. valid GCP", @@ -127,12 +91,12 @@ func TestHandshakeProse(t *testing.T) { "FUNCTION_REGION": "us-central1", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "gcp.func"}, - {Key: "memory_mb", Value: 1024}, - {Key: "region", Value: "us-central1"}, - {Key: "timeout_sec", Value: 60}, - }, nil), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataEnvName("gcp.func"), + test.WithClientMetadataEnvMemoryMB(ptrutil.Ptr(1024)), + test.WithClientMetadataEnvRegion("us-central1"), + test.WithClientMetadataEnvTimeoutSec(ptrutil.Ptr(60)), + ), }, { name: "4. valid Vercel", @@ -141,10 +105,10 @@ func TestHandshakeProse(t *testing.T) { "VERCEL_REGION": "cdg1", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "vercel"}, - {Key: "region", Value: "cdg1"}, - }, nil), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataEnvName("vercel"), + test.WithClientMetadataEnvRegion("cdg1"), + ), }, { name: "5. invalid multiple providers", @@ -153,7 +117,7 @@ func TestHandshakeProse(t *testing.T) { "FUNCTIONS_WORKER_RUNTIME": "node", }, opts: nil, - want: clientMetadata(nil, nil), + want: test.EncodeClientMetadata(mt), }, { name: "6. invalid long string", @@ -168,9 +132,9 @@ func TestHandshakeProse(t *testing.T) { }(), }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "aws.lambda"}, - }, nil), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataEnvName("aws.lambda"), + ), }, { name: "7. invalid wrong types", @@ -179,9 +143,9 @@ func TestHandshakeProse(t *testing.T) { "AWS_LAMBDA_FUNCTION_MEMORY_SIZE": "big", }, opts: nil, - want: clientMetadata(bson.D{ - {Key: "name", Value: "aws.lambda"}, - }, nil), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataEnvName("aws.lambda"), + ), }, { name: "8. Invalid - AWS_EXECUTION_ENV does not start with \"AWS_Lambda_\"", @@ -189,51 +153,38 @@ func TestHandshakeProse(t *testing.T) { "AWS_EXECUTION_ENV": "EC2", }, opts: nil, - want: clientMetadata(nil, nil), + want: test.EncodeClientMetadata(mt), }, { name: "driver info included", opts: options.Client().SetDriverInfo(driverInfo), - want: clientMetadata(nil, driverInfo), + want: test.EncodeClientMetadata(mt, + test.WithClientMetadataDriverName("outer-library-name"), + test.WithClientMetadataDriverVersion("outer-library-version"), + test.WithClientMetadataDriverPlatform("outer-library-platform"), + ), }, - } { - test := test + } - mt.RunOpts(test.name, opts, func(mt *mtest.T) { - for k, v := range test.env { + for _, tc := range testCases { + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + for k, v := range tc.env { mt.Setenv(k, v) } - if test.opts != nil { - mt.ResetClient(test.opts) + if tc.opts != nil { + mt.ResetClient(tc.opts) } // Ping the server to ensure the handshake has completed. err := mt.Client.Ping(context.Background(), nil) require.NoError(mt, err, "Ping error: %v", err) - messages := mt.GetProxiedMessages() - handshakeMessage := messages[:1][0] - - hello := handshake.LegacyHello - if os.Getenv("REQUIRE_API_VERSION") == "true" { - hello = "hello" - } - - assert.Equal(mt, hello, handshakeMessage.CommandName) - - // Lookup the "client" field in the command document. - clientVal, err := handshakeMessage.Sent.Command.LookupErr("client") - require.NoError(mt, err, "expected command %s to contain client field", handshakeMessage.Sent.Command) - - got, ok := clientVal.DocumentOK() - require.True(mt, ok, "expected client field to be a document, got %s", clientVal.Type) + firstMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, firstMessage, "expected to capture a proxied message") - wantBytes, err := bson.Marshal(test.want) - require.NoError(mt, err, "error marshaling want document: %v", err) - - want := bsoncore.Document(wantBytes) - assert.Equal(mt, want, got, "want: %v, got: %v", want, got) + assert.True(mt, firstMessage.IsHandshake(), "expected first message to be a handshake") + assertbsoncore.HandshakeClientMetadata(mt, tc.want, firstMessage.Sent.Command) }) } } @@ -249,13 +200,13 @@ func TestLoadBalancedConnectionHandshake(t *testing.T) { err := mt.Client.Ping(context.Background(), nil) require.NoError(mt, err, "Ping error: %v", err) - messages := mt.GetProxiedMessages() - handshakeMessage := messages[:1][0] + firstMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, firstMessage, "expected to capture a proxied message") // Per the specifications, if loadBalanced=true, drivers MUST use the hello // command for the initial handshake and use the OP_MSG protocol. - assert.Equal(mt, "hello", handshakeMessage.CommandName) - assert.Equal(mt, wiremessage.OpMsg, handshakeMessage.Sent.OpCode) + assert.True(mt, firstMessage.IsHandshake(), "expected first message to be a handshake") + assert.Equal(mt, wiremessage.OpMsg, firstMessage.Sent.OpCode) }) opts := mtest.NewOptions().ClientType(mtest.Proxy).Topologies( @@ -269,21 +220,726 @@ func TestLoadBalancedConnectionHandshake(t *testing.T) { err := mt.Client.Ping(context.Background(), nil) require.NoError(mt, err, "Ping error: %v", err) - messages := mt.GetProxiedMessages() - handshakeMessage := messages[:1][0] + firstMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, firstMessage, "expected to capture a proxied message") want := wiremessage.OpQuery - - hello := handshake.LegacyHello if os.Getenv("REQUIRE_API_VERSION") == "true" { - hello = "hello" - // If the server API version is requested, then we should use OP_MSG // regardless of the topology want = wiremessage.OpMsg } - assert.Equal(mt, hello, handshakeMessage.CommandName) - assert.Equal(mt, want, handshakeMessage.Sent.OpCode) + assert.True(mt, firstMessage.IsHandshake(), "expected first message to be a handshake") + assert.Equal(mt, want, firstMessage.Sent.OpCode) + }) +} + +// Test 1: Test that the driver updates metadata +// Test 2: Multiple Successive Metadata Updates +// Test 3: Multiple Successive Metadata Updates with Duplicate Data +func TestHandshakeProse_AppendMetadata_Test1_Test2_Test3(t *testing.T) { + mt := mtest.New(t) + + initialDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + testCases := []struct { + name string + driverInfo options.DriverInfo + want options.DriverInfo + + // append initialDriverInfo using client.AppendDriverInfo instead of as a + // client-level constructor. + append bool + }{ + { + name: "test1.1: append new driver info", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2|2.0", + Platform: "Library Platform|Framework Platform", + }, + append: false, + }, + { + name: "test1.2: append with no platform", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2|2.0", + Platform: "Library Platform", + }, + append: false, + }, + { + name: "test1.3: append with no version", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2", + Platform: "Library Platform|Framework Platform", + }, + append: false, + }, + { + name: "test1.4: append with name only", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2", + Platform: "Library Platform", + }, + append: false, + }, + { + name: "test2.1: append new driver info after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2|2.0", + Platform: "Library Platform|Framework Platform", + }, + append: true, + }, + { + name: "test2.2: append with no platform after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2|2.0", + Platform: "Library Platform", + }, + append: true, + }, + { + name: "test2.3: append with no version after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2", + Platform: "Library Platform|Framework Platform", + }, + append: true, + }, + { + name: "test2.4: append with name only after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "", + Platform: "", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2", + Platform: "Library Platform", + }, + append: true, + }, + { + name: "test3.1: same driver info after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }, + want: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }, + append: true, + }, + { + name: "test3.2: same version and platform after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "1.2", + Platform: "Library Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2", + Platform: "Library Platform", + }, + append: true, + }, + { + name: "test3.3: same name and platform after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "2.0", + Platform: "Library Platform", + }, + want: options.DriverInfo{ + Name: "library", + Version: "1.2|2.0", + Platform: "Library Platform", + }, + append: true, + }, + { + name: "test3.4: same name and version after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform|Framework Platform", + }, + append: true, + }, + { + name: "test3.5: same platform after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Library Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2|2.0", + Platform: "Library Platform", + }, + append: true, + }, + { + name: "test3.6: same version after appending", + driverInfo: options.DriverInfo{ + Name: "framework", + Version: "1.2", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library|framework", + Version: "1.2", + Platform: "Library Platform|Framework Platform", + }, + append: true, + }, + { + name: "test3.7: same name after appending", + driverInfo: options.DriverInfo{ + Name: "library", + Version: "2.0", + Platform: "Framework Platform", + }, + want: options.DriverInfo{ + Name: "library", + Version: "1.2|2.0", + Platform: "Library Platform|Framework Platform", + }, + append: true, + }, + } + + for _, tc := range testCases { + // Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + opts := mtest.NewOptions().CreateClient(false).ClientType(mtest.Proxy) + + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond) + + if !tc.append { + clientOpts = clientOpts.SetDriverInfo(&initialDriverInfo) + } + + mt.ResetClient(clientOpts) + + if tc.append { + mt.Client.AppendDriverInfo(initialDriverInfo) + } + + // Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // Save intercepted `client` document as `initialClientMetadata`. + initialClientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, initialClientMetadata, "expected to capture a proxied message") + assert.True(mt, initialClientMetadata.IsHandshake(), "expected first message to be a handshake") + + // Wait 5ms for the connection to become idle. + time.Sleep(20 * time.Millisecond) + + mt.Client.AppendDriverInfo(tc.driverInfo) + + // Drain the proxy + mt.GetProxyCapture().Drain() + + // Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // Capture the first message sent after appending driver info. + gotMessage := mt.GetProxyCapture().TryNext() + require.NotNil(mt, gotMessage, "expected to capture a proxied message") + assert.True(mt, gotMessage.IsHandshake(), "expected first message to be a handshake") + + want := test.EncodeClientMetadata(mt, + test.WithClientMetadataDriverName(tc.want.Name), + test.WithClientMetadataDriverVersion(tc.want.Version), + test.WithClientMetadataDriverPlatform(tc.want.Platform), + ) + + assertbsoncore.HandshakeClientMetadata(mt, want, gotMessage.Sent.Command) + }) + } +} + +// Test 4: Multiple Metadata Updates with Duplicate Data. +func TestHandshakeProse_AppendMetadata_MultipleUpdatesWithDuplicateFields(t *testing.T) { + opts := mtest.NewOptions().ClientType(mtest.Proxy) + mt := mtest.New(t, opts) + + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond) + + // 1. Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + mt.ResetClient(clientOpts) + + originalDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + // 2. Append initial driver info using client.AppendDriverInfo. + mt.Client.AppendDriverInfo(originalDriverInfo) + + // 3. Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 4. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 5. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "framework", + Version: "2.0", + Platform: "Framework Platform", }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 6. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 7. Save intercepted `client` document as `clientMetadata`. + clientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, clientMetadata, "expected to capture a proxied message") + assert.True(mt, clientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 8. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 9. Append the original driver info again. + mt.Client.AppendDriverInfo(originalDriverInfo) + + // 10. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 11. Save intercepted `client` document as `clientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + assertbsoncore.HandshakeClientMetadata(mt, clientMetadata.Sent.Command, updatedClientMetadata.Sent.Command) +} + +// Test 5: Metadata is not appended if identical to initial metadata +func TestHandshakeProse_AppendMetadata_NotAppendedIfIdentical(t *testing.T) { + opts := mtest.NewOptions().ClientType(mtest.Proxy) + mt := mtest.New(t, opts) + + originalDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond). + SetDriverInfo(&originalDriverInfo) + + // 1. Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + mt.ResetClient(clientOpts) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 2. Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + clientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, clientMetadata, "expected to capture a proxied message") + assert.True(mt, clientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 3. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 5. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 6. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 7. Save intercepted `client` document as `updatedClientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + assertbsoncore.HandshakeClientMetadata(mt, clientMetadata.Sent.Command, updatedClientMetadata.Sent.Command) +} + +// Test 6: Metadata is not appended if identical to initial metadata (separated +// by non-identical metadata) +func TestHandshakeProse_AppendMetadata_NotAppendedIfIdentical_NonSequential(t *testing.T) { + opts := mtest.NewOptions().ClientType(mtest.Proxy) + mt := mtest.New(t, opts) + + originalDriverInfo := options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + } + + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond). + SetDriverInfo(&originalDriverInfo) + + // 1. Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + mt.ResetClient(clientOpts) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 2. Send a ping command to the server and verify that the command succeeded. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 3. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 4. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "framework", + Version: "1.2", + Platform: "Framework Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 5. Send a ping command to the server and verify that the command succeeded. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 6. Save intercepted `client` document as `clientMetadata`. + clientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, clientMetadata, "expected to capture a proxied message") + assert.True(mt, clientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 7. Wait 5ms for the connection to become idle. + time.Sleep(5 * time.Millisecond) + + // 8. Append new driver info. + mt.Client.AppendDriverInfo(options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "Library Platform", + }) + + // Drain the proxy to ensure we only capture messages after appending. + mt.GetProxyCapture().Drain() + + // 9. Send a `ping` command to the server and verify that the command + // succeeds. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 10. Save intercepted `client` document as `updatedClientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + assertbsoncore.HandshakeClientMetadata(mt, clientMetadata.Sent.Command, updatedClientMetadata.Sent.Command) +} + +// Test 7: Empty strings are considered unset when appending duplicate metadata. +func TestHandshakeProse_AppendMetadata_EmptyStrings(t *testing.T) { + mt := mtest.New(t) + + testCases := []struct { + name string + initialDriverInfo options.DriverInfo + toAppendDriverInfo options.DriverInfo + }{ + { + name: "name empty", + initialDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + }, + { + name: "version empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + }, + { + name: "platform empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + }, + } + + for _, tc := range testCases { + // Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + opts := mtest.NewOptions().CreateClient(false).ClientType(mtest.Proxy) + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + // 1. Create a `MongoClient` instance. + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond) + + mt.ResetClient(clientOpts) + + // 2. Append the `DriverInfoOptions` from the selected test case from + // the initial metadata section. + mt.Client.AppendDriverInfo(tc.initialDriverInfo) + + mt.GetProxyCapture().Drain() + + // 3. Send a `ping` command to the server and verify that the command + // succeeds. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 4. Save intercepted `client` document as `initialClientMetadata`. + initialClientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, initialClientMetadata, "expected to capture a proxied message") + assert.True(mt, initialClientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 5. Wait 5ms for the connection to become idle. + time.Sleep(20 * time.Millisecond) + + // 6. Append the `DriverInfoOptions` from the selected test case from + // the appended metadata section. + mt.Client.AppendDriverInfo(tc.toAppendDriverInfo) + + // Drain the proxy + mt.GetProxyCapture().Drain() + + // 7. Send a `ping` command to the server and verify the command + // succeeds. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // Capture the first message sent after appending driver info. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + assertbsoncore.HandshakeClientMetadata(mt, initialClientMetadata.Sent.Command, + updatedClientMetadata.Sent.Command) + }) + } +} + +// Test 8: Empty strings are considered unset when appending metadata identical +// to initial metadata +func TestHandshakeProse_AppendMetadata_EmptyStrings_InitializedClient(t *testing.T) { + mt := mtest.New(t) + + testCases := []struct { + name string + initialDriverInfo options.DriverInfo + toAppendDriverInfo options.DriverInfo + }{ + { + name: "name empty", + initialDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "", + Version: "1.2", + Platform: "Library Platform", + }, + }, + { + name: "version empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "", + Platform: "Library Platform", + }, + }, + { + name: "platform empty", + initialDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + toAppendDriverInfo: options.DriverInfo{ + Name: "library", + Version: "1.2", + Platform: "", + }, + }, + } + + for _, tc := range testCases { + tc := tc // Avoid implicit memory aliasing in for loop. + + // Create a top-level client that can be shared among sub-tests. This is + // necessary to test appending driver info to an existing client. + opts := mtest.NewOptions().CreateClient(false).ClientType(mtest.Proxy) + mt.RunOpts(tc.name, opts, func(mt *mtest.T) { + // 1. Create a `MongoClient` instance. + clientOpts := options.Client(). + // Set idle timeout to 1ms to force new connections to be created + // throughout the lifetime of the test. + SetMaxConnIdleTime(1 * time.Millisecond). + SetDriverInfo(&tc.initialDriverInfo) + + mt.ResetClient(clientOpts) + + // 2. Send a `ping` command to the server and verify that the command + // succeeds. + err := mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 3. Save intercepted `client` document as `initialClientMetadata`. + initialClientMetadata := mt.GetProxyCapture().TryNext() + + require.NotNil(mt, initialClientMetadata, "expected to capture a proxied message") + assert.True(mt, initialClientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 4. Wait 5ms for the connection to become idle. + time.Sleep(20 * time.Millisecond) + + // 5. Append the `DriverInfoOptions` from the selected test case from + // the appended metadata section. + mt.Client.AppendDriverInfo(tc.toAppendDriverInfo) + + // Drain the proxy + mt.GetProxyCapture().Drain() + + // 6. Send a `ping` command to the server and verify the command + // succeeds. + err = mt.Client.Ping(context.Background(), nil) + require.NoError(mt, err, "Ping error: %v", err) + + // 7. Store the response as `updatedClientMetadata`. + updatedClientMetadata := mt.GetProxyCapture().TryNext() + require.NotNil(mt, updatedClientMetadata, "expected to capture a proxied message") + assert.True(mt, updatedClientMetadata.IsHandshake(), "expected first message to be a handshake") + + // 8. Assert that `initialClientMetadata` is identical to `updatedClientMetadata`. + assertbsoncore.HandshakeClientMetadata(mt, initialClientMetadata.Sent.Command, + updatedClientMetadata.Sent.Command) + }) + } } diff --git a/internal/integration/mtest/mongotest.go b/internal/integration/mtest/mongotest.go index ce9823b89a..db8c664a41 100644 --- a/internal/integration/mtest/mongotest.go +++ b/internal/integration/mtest/mongotest.go @@ -338,13 +338,13 @@ func (t *T) FilterFailedEvents(filter func(*event.CommandFailedEvent) bool) { t.failed = newEvents } -// GetProxiedMessages returns the messages proxied to the server by the test. If the client type is not Proxy, this -// returns nil. -func (t *T) GetProxiedMessages() []*ProxyMessage { +// GetProxyCapture returns the ProxyCapture used by the test. If the client +// type is not Proxy, this returns nil. +func (t *T) GetProxyCapture() *ProxyCapture { if t.proxyDialer == nil { return nil } - return t.proxyDialer.Messages() + return t.proxyDialer.proxyCapture } // NumberConnectionsCheckedOut returns the number of connections checked out from the test Client. diff --git a/internal/integration/mtest/proxy_capture.go b/internal/integration/mtest/proxy_capture.go new file mode 100644 index 0000000000..5cd43cc984 --- /dev/null +++ b/internal/integration/mtest/proxy_capture.go @@ -0,0 +1,53 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package mtest + +import ( + "sync" +) + +// ProxyCapture provides a FIFO channel for handshake messages passed +// through the mtest proxyDialer. +type ProxyCapture struct { + messages chan *ProxyMessage + mu sync.Mutex +} + +func newProxyCapture(bufferSize int) *ProxyCapture { + return &ProxyCapture{ + messages: make(chan *ProxyMessage, bufferSize), + } +} + +func (hc *ProxyCapture) Capture(msg *ProxyMessage) { + hc.mu.Lock() + defer hc.mu.Unlock() + + hc.messages <- msg +} + +func (hc *ProxyCapture) TryNext() *ProxyMessage { + select { + case msg := <-hc.messages: + return msg + default: + return nil + } +} + +// Drain removes all messages from the channel and returns them as a slice. +func (hc *ProxyCapture) Drain() []*ProxyMessage { + messages := []*ProxyMessage{} + for { + select { + case msg := <-hc.messages: + messages = append(messages, msg) + default: + return messages + } + } +} diff --git a/internal/integration/mtest/proxy_dialer.go b/internal/integration/mtest/proxy_dialer.go index 7f17dbbdb1..0d980c406c 100644 --- a/internal/integration/mtest/proxy_dialer.go +++ b/internal/integration/mtest/proxy_dialer.go @@ -11,9 +11,11 @@ import ( "errors" "fmt" "net" + "os" "sync" "time" + "go.mongodb.org/mongo-driver/v2/internal/handshake" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" ) @@ -32,7 +34,6 @@ type proxyDialer struct { *net.Dialer sync.Mutex - messages []*ProxyMessage // sentMap temporarily stores the message sent to the server using the requestID so it can map requests to their // responses. sentMap sync.Map @@ -40,13 +41,16 @@ type proxyDialer struct { // differ. This can happen if a connection is dialed to a host name, in which case the reported remote address will // be the resolved IP address. addressTranslations sync.Map + + proxyCapture *ProxyCapture } var _ options.ContextDialer = (*proxyDialer)(nil) func newProxyDialer() *proxyDialer { return &proxyDialer{ - Dialer: &net.Dialer{Timeout: 30 * time.Second}, + Dialer: &net.Dialer{Timeout: 30 * time.Second}, + proxyCapture: newProxyCapture(100), } } @@ -121,21 +125,10 @@ func (p *proxyDialer) storeReceivedMessage(wm []byte, addr string) error { Sent: sent, Received: parsed, } - p.messages = append(p.messages, msgPair) + p.proxyCapture.Capture(msgPair) return nil } -// Messages returns a slice of proxied messages. This slice is a copy of the messages proxied so far and will not be -// updated for messages proxied after this call. -func (p *proxyDialer) Messages() []*ProxyMessage { - p.Lock() - defer p.Unlock() - - copiedMessages := make([]*ProxyMessage, len(p.messages)) - copy(copiedMessages, p.messages) - return copiedMessages -} - // proxyConn is a net.Conn that wraps a network connection. All messages sent/received through a proxyConn are stored // in the associated proxyDialer and are forwarded over the wrapped connection. Errors encountered when parsing and // storing wire messages are wrapped to add context, while errors returned from the underlying network connection are @@ -184,3 +177,12 @@ func (pc *proxyConn) Read(buffer []byte) (int, error) { return n, nil } + +func (msg *ProxyMessage) IsHandshake() bool { + hello := handshake.LegacyHello + if os.Getenv("REQUIRE_API_VERSION") == "true" { + hello = "hello" + } + + return hello == msg.CommandName +} diff --git a/internal/integration/sdam_prose_test.go b/internal/integration/sdam_prose_test.go index 274d6c0abb..ac9572ac02 100644 --- a/internal/integration/sdam_prose_test.go +++ b/internal/integration/sdam_prose_test.go @@ -69,7 +69,7 @@ func TestSDAMProse(t *testing.T) { } start := time.Now() time.Sleep(2 * time.Second) - messages := mt.GetProxiedMessages() + messages := mt.GetProxyCapture().Drain() duration := time.Since(start) hosts, err := mongoutil.HostsFromURI(mtest.ClusterURI()) diff --git a/internal/integration/unified/client_operation_execution.go b/internal/integration/unified/client_operation_execution.go index 86f161761d..1631d88ba7 100644 --- a/internal/integration/unified/client_operation_execution.go +++ b/internal/integration/unified/client_operation_execution.go @@ -307,6 +307,34 @@ func executeClientBulkWrite(ctx context.Context, operation *operation) (*operati return newDocumentResult(rawBuilder.Build(), err), nil } +func executeAppendMetadata(ctx context.Context, op *operation) (*operationResult, error) { + client, err := entities(ctx).client(op.Object) + if err != nil { + return nil, fmt.Errorf("error getting client entity: %w", err) + } + + elems, err := op.Arguments.Elements() + if err != nil { + return nil, fmt.Errorf("error getting appendMetadata arguments: %w", err) + } + + driverInfo := options.DriverInfo{} + for _, elem := range elems { + key := elem.Key() + val := elem.Value() + + if key == "driverInfoOptions" { + if err = bson.Unmarshal(val.Value, &driverInfo); err != nil { + return nil, fmt.Errorf("error unmarshaling driverInfoOptions: %w", err) + } + } + } + + client.AppendDriverInfo(driverInfo) + + return newEmptyResult(), nil +} + func createClientInsertOneModel(value bson.Raw) (*mongo.ClientBulkWrite, error) { var v struct { Namespace string diff --git a/internal/integration/unified/operation.go b/internal/integration/unified/operation.go index 9baf785dcb..1b591d66af 100644 --- a/internal/integration/unified/operation.go +++ b/internal/integration/unified/operation.go @@ -128,7 +128,9 @@ func (op *operation) run(ctx context.Context, loopDone <-chan struct{}) (*operat // executeWithTransaction internally verifies results/errors for each operation, so it doesn't return a result. return newEmptyResult(), executeWithTransaction(ctx, op, loopDone) - // Client operations + // Client operations + case "appendMetadata": + return executeAppendMetadata(ctx, op) case "createChangeStream": return executeCreateChangeStream(ctx, op) case "listDatabases": diff --git a/internal/integration/unified/unified_spec_test.go b/internal/integration/unified/unified_spec_test.go index 9021d03e75..03b4139f90 100644 --- a/internal/integration/unified/unified_spec_test.go +++ b/internal/integration/unified/unified_spec_test.go @@ -37,6 +37,7 @@ var ( "run-command/tests/unified", "index-management/tests", "atlas-data-lake-testing/tests/unified", + "mongodb-handshake/tests/unified", } failDirectories = []string{ "unified-test-format/tests/valid-fail", diff --git a/internal/test/client_metadata.go b/internal/test/client_metadata.go new file mode 100644 index 0000000000..fb70977a65 --- /dev/null +++ b/internal/test/client_metadata.go @@ -0,0 +1,199 @@ +// Copyright (C) MongoDB, Inc. 2025-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package test + +import ( + "runtime" + "testing" + + "go.mongodb.org/mongo-driver/v2/bson" + "go.mongodb.org/mongo-driver/v2/internal/require" + "go.mongodb.org/mongo-driver/v2/version" +) + +type clientMetadataOptions struct { + appName string + driverName string + driverVersion string + driverPlatform string + envName string + envTimeoutSec *int + envMemoryMB *int + envRegion string +} + +// ClientMetadataOption represents a configuration option for building client +// metadata. +type ClientMetadataOption func(*clientMetadataOptions) + +// WithClientMetadataAppName sets the application name included in client metadata. +func WithClientMetadataAppName(name string) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.appName = name + } +} + +// WithClientMetadataDriverName sets the driver name (e.g., "mongo-go-driver"). +func WithClientMetadataDriverName(name string) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.driverName = name + } +} + +// WithClientMetadataDriverVersion sets the driver version (e.g., "1.16.0"). +func WithClientMetadataDriverVersion(version string) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.driverVersion = version + } +} + +// WithClientMetadataDriverPlatform sets the driver platform string +// (e.g., "go1.22.5 gc linux/amd64"). +func WithClientMetadataDriverPlatform(platform string) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.driverPlatform = platform + } +} + +// WithClientMetadataEnvName sets the execution environment name +// (e.g., "AWS Lambda", "GCP Cloud Functions", "Kubernetes"). +func WithClientMetadataEnvName(name string) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.envName = name + } +} + +// WithClientMetadataEnvTimeoutSec sets the execution timeout in seconds. +// Pass nil to indicate "unspecified" or "not applicable". +func WithClientMetadataEnvTimeoutSec(timeoutSec *int) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.envTimeoutSec = timeoutSec + } +} + +// WithClientMetadataEnvMemoryMB sets the memory limit in megabytes. +// Pass nil to indicate "unspecified" or "not applicable". +func WithClientMetadataEnvMemoryMB(memoryMB *int) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.envMemoryMB = memoryMB + } +} + +// WithClientMetadataEnvRegion sets the deployment/region identifier +// (e.g., "us-east-1", "europe-west1"). +func WithClientMetadataEnvRegion(region string) ClientMetadataOption { + return func(o *clientMetadataOptions) { + o.envRegion = region + } +} + +// EncodeClientMetadata constructs the WM byte slice that represents the client +// metadata document for the given options with the intent of comparing to an +// actual handshake wire message: +// +// { +// application: { +// name: "" +// }, +// driver: { +// name: "", +// version: "" +// }, +// platform: "", +// os: { +// type: "", +// name: "", +// architecture: "", +// version: "" +// }, +// env: { +// name: "", +// timeout_sec: 42, +// memory_mb: 1024, +// region: "", +// container: { +// runtime: "", +// orchestrator: "" +// } +// } +// } +// +// This function was not put in mtest since it could be used in non-integration +// test conditions. +func EncodeClientMetadata(t testing.TB, opts ...ClientMetadataOption) []byte { + t.Helper() + + cfg := clientMetadataOptions{} + for _, apply := range opts { + apply(&cfg) + } + + var ( + driverName = "mongo-go-driver" // Default + driverVersion = version.Driver + platform = runtime.Version() // Default + ) + + if cfg.driverName != "" { + driverName = driverName + "|" + cfg.driverName + } + + if cfg.driverVersion != "" { + driverVersion = driverVersion + "|" + cfg.driverVersion + } + + if cfg.driverPlatform != "" { + platform = platform + "|" + cfg.driverPlatform + } + + elems := bson.D{} + + if cfg.appName != "" { + elems = append(elems, bson.E{Key: "application", Value: bson.D{ + {Key: "name", Value: cfg.appName}, + }}) + } + + elems = append(elems, bson.D{ + {Key: "driver", Value: bson.D{ + {Key: "name", Value: driverName}, + {Key: "version", Value: driverVersion}, + }}, + {Key: "os", Value: bson.D{ + {Key: "type", Value: runtime.GOOS}, + {Key: "architecture", Value: runtime.GOARCH}, + }}, + }...) + + elems = append(elems, bson.E{Key: "platform", Value: platform}) + + envElems := bson.D{} + if cfg.envName != "" { + envElems = append(envElems, bson.E{Key: "name", Value: cfg.envName}) + } + + if cfg.envMemoryMB != nil { + envElems = append(envElems, bson.E{Key: "memory_mb", Value: *cfg.envMemoryMB}) + } + + if cfg.envRegion != "" { + envElems = append(envElems, bson.E{Key: "region", Value: cfg.envRegion}) + } + + if cfg.envTimeoutSec != nil { + envElems = append(envElems, bson.E{Key: "timeout_sec", Value: *cfg.envTimeoutSec}) + } + + if len(envElems) > 0 { + elems = append(elems, bson.E{Key: "env", Value: envElems}) + } + + bytes, err := bson.Marshal(elems) + require.NoError(t, err) + + return bytes +} diff --git a/mongo/client.go b/mongo/client.go index f0480a0c72..3f5b278321 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -11,6 +11,8 @@ import ( "errors" "fmt" "net/http" + "sync" + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/bson" @@ -56,24 +58,26 @@ var ( // The Client type opens and closes connections automatically and maintains a pool of idle connections. For // connection pool configuration options, see documentation for the ClientOptions type in the mongo/options package. type Client struct { - id uuid.UUID - deployment driver.Deployment - localThreshold time.Duration - retryWrites bool - retryReads bool - clock *session.ClusterClock - readPreference *readpref.ReadPref - readConcern *readconcern.ReadConcern - writeConcern *writeconcern.WriteConcern - bsonOpts *options.BSONOptions - registry *bson.Registry - monitor *event.CommandMonitor - serverAPI *driver.ServerAPIOptions - serverMonitor *event.ServerMonitor - sessionPool *session.Pool - timeout *time.Duration - httpClient *http.Client - logger *logger.Logger + id uuid.UUID + deployment driver.Deployment + localThreshold time.Duration + retryWrites bool + retryReads bool + clock *session.ClusterClock + readPreference *readpref.ReadPref + readConcern *readconcern.ReadConcern + writeConcern *writeconcern.WriteConcern + bsonOpts *options.BSONOptions + registry *bson.Registry + monitor *event.CommandMonitor + serverAPI *driver.ServerAPIOptions + serverMonitor *event.ServerMonitor + sessionPool *session.Pool + timeout *time.Duration + httpClient *http.Client + logger *logger.Logger + currentDriverInfo *atomic.Pointer[options.DriverInfo] + seenDriverInfo sync.Map // in-use encryption fields isAutoEncryptionSet bool @@ -132,7 +136,11 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { if err != nil { return nil, err } - client := &Client{id: id} + + client := &Client{ + id: id, + currentDriverInfo: &atomic.Pointer[options.DriverInfo]{}, + } // ClusterClock client.clock = new(session.ClusterClock) @@ -217,7 +225,16 @@ func newClient(opts ...*options.ClientOptions) (*Client, error) { } } - cfg, err := topology.NewConfigFromOptionsWithAuthenticator(clientOpts, client.clock, client.authenticator) + if clientOpts.DriverInfo != nil { + client.AppendDriverInfo(*clientOpts.DriverInfo) + } + + cfg, err := topology.NewAuthenticatorConfig(client.authenticator, + topology.WithAuthConfigClock(client.clock), + topology.WithAuthConfigClientOptions(clientOpts), + topology.WithAuthConfigDriverInfo(client.currentDriverInfo), + ) + if err != nil { return nil, err } @@ -294,6 +311,50 @@ func (c *Client) connect() error { return nil } +// AppendDriverInfo appends the provided DriverInfo to the driver information +// that will be sent to the server in handshake requests when establishing new +// connections. The provided info will overwrite any existing values. +// +// AppendsDriverInfo appends the provided [options.DriverInfo] to the metadata +// (e.g. name, version, platform) that will be sent to the server in handshake +// requests when establishing new connections. The provided info will overwrite +// any existing values. +// +// Repeated calls to appendMetadata with equivalent DriverInfo is a no-op. +// +// Metadata is limited to 512 bytes; any excess will be truncated. +func (c *Client) AppendDriverInfo(info options.DriverInfo) { + if _, loaded := c.seenDriverInfo.LoadOrStore(info, struct{}{}); loaded { + return + } + + if old := c.currentDriverInfo.Load(); old != nil { + if old.Name != "" && info.Name != "" && old.Name != info.Name { + info.Name = old.Name + "|" + info.Name + } else if old.Name != "" { + info.Name = old.Name + } + + if old.Version != "" && info.Version != "" && old.Version != info.Version { + info.Version = old.Version + "|" + info.Version + } else if old.Version != "" { + info.Version = old.Version + } + + if old.Platform != "" && info.Platform != "" && old.Platform != info.Platform { + info.Platform = old.Platform + "|" + info.Platform + } else if old.Platform != "" { + info.Platform = old.Platform + } + } + + // Copy-on-write so that the info stored in the client is immutable. + infoCopy := new(options.DriverInfo) + *infoCopy = info + + c.currentDriverInfo.Store(infoCopy) +} + // Disconnect closes sockets to the topology referenced by this Client. It will // shut down any monitoring goroutines, close the idle connection pool, and will // wait until all the in use connections have been returned to the connection diff --git a/x/mongo/driver/topology/server.go b/x/mongo/driver/topology/server.go index b665387404..a408aee561 100644 --- a/x/mongo/driver/topology/server.go +++ b/x/mongo/driver/topology/server.go @@ -807,9 +807,18 @@ func (s *Server) createConnection() *connection { opts := copyConnectionOpts(s.cfg.connectionOpts) opts = append(opts, WithHandshaker(func(Handshaker) Handshaker { - return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts). - ServerAPI(s.cfg.serverAPI).OuterLibraryName(s.cfg.outerLibraryName). - OuterLibraryVersion(s.cfg.outerLibraryVersion).OuterLibraryPlatform(s.cfg.outerLibraryPlatform) + handshaker := operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts). + ServerAPI(s.cfg.serverAPI) + + if s.cfg.driverInfo != nil { + driverInfo := s.cfg.driverInfo.Load() + if driverInfo != nil { + handshaker = handshaker.OuterLibraryName(driverInfo.Name).OuterLibraryVersion(driverInfo.Version). + OuterLibraryPlatform(driverInfo.Platform) + } + } + + return handshaker }), // Override any monitors specified in options with nil to avoid monitoring heartbeats. WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }), diff --git a/x/mongo/driver/topology/server_options.go b/x/mongo/driver/topology/server_options.go index 490834cbef..297cafc701 100644 --- a/x/mongo/driver/topology/server_options.go +++ b/x/mongo/driver/topology/server_options.go @@ -7,11 +7,13 @@ package topology import ( + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/logger" + "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/connstring" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" @@ -32,6 +34,7 @@ type serverConfig struct { monitoringDisabled bool serverAPI *driver.ServerAPIOptions loadBalanced bool + driverInfo *atomic.Pointer[options.DriverInfo] // Connection pool options. maxConns uint64 @@ -41,11 +44,6 @@ type serverConfig struct { logger *logger.Logger poolMaxIdleTime time.Duration poolMaintainInterval time.Duration - - // Fields provided by a library that wraps the Go Driver. - outerLibraryName string - outerLibraryVersion string - outerLibraryPlatform string } func newServerConfig(connectTimeout time.Duration, opts ...ServerOption) *serverConfig { @@ -101,27 +99,12 @@ func WithServerAppName(fn func(string) string) ServerOption { } } -// WithOuterLibraryName configures the name for the outer library to include -// in the drivers section of the handshake metadata. -func WithOuterLibraryName(fn func(string) string) ServerOption { - return func(cfg *serverConfig) { - cfg.outerLibraryName = fn(cfg.outerLibraryName) - } -} - -// WithOuterLibraryVersion configures the version for the outer library to -// include in the drivers section of the handshake metadata. -func WithOuterLibraryVersion(fn func(string) string) ServerOption { - return func(cfg *serverConfig) { - cfg.outerLibraryVersion = fn(cfg.outerLibraryVersion) - } -} - -// WithOuterLibraryPlatform configures the platform for the outer library to -// include in the platform section of the handshake metadata. -func WithOuterLibraryPlatform(fn func(string) string) ServerOption { +// WithDriverInfo sets at atomic pointer to the server configuration, which will +// be used to create the "driver" section on handshake commands. An atomic +// pointer is used so that the driver info can be updated concurrently. +func WithDriverInfo(info *atomic.Pointer[options.DriverInfo]) ServerOption { return func(cfg *serverConfig) { - cfg.outerLibraryPlatform = fn(cfg.outerLibraryPlatform) + cfg.driverInfo = info } } diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 2ddc7434bd..2e46d2c7ef 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "fmt" "net/http" + "sync/atomic" "time" "go.mongodb.org/mongo-driver/v2/event" @@ -139,14 +140,55 @@ func NewConfig(opts *options.ClientOptions, clock *session.ClusterClock) (*Confi return nil, fmt.Errorf("error creating authenticator: %w", err) } } - return NewConfigFromOptionsWithAuthenticator(opts, clock, authenticator) + return NewAuthenticatorConfig(authenticator, + WithAuthConfigClock(clock), + WithAuthConfigClientOptions(opts), + ) +} + +type authConfigOptions struct { + clock *session.ClusterClock + opts *options.ClientOptions + driverInfo *atomic.Pointer[options.DriverInfo] +} + +// AuthConfigOption is a function that configures authConfigOptions. +type AuthConfigOption func(*authConfigOptions) + +// WithAuthConfigClock sets the cluster clock in authConfigOptions. +func WithAuthConfigClock(clock *session.ClusterClock) AuthConfigOption { + return func(co *authConfigOptions) { + co.clock = clock + } +} + +// WithAuthConfigClientOptions sets the client options in authConfigOptions. +func WithAuthConfigClientOptions(opts *options.ClientOptions) AuthConfigOption { + return func(co *authConfigOptions) { + co.opts = opts + } +} + +// WithAuthConfigDriverInfo sets the driver info in authConfigOptions. +func WithAuthConfigDriverInfo(driverInfo *atomic.Pointer[options.DriverInfo]) AuthConfigOption { + return func(co *authConfigOptions) { + co.driverInfo = driverInfo + } } -// NewConfigFromOptionsWithAuthenticator will translate data from client options into a +// NewAuthenticatorConfig will translate data from client options into a // topology config for building non-default deployments. Server and topology // options are not honored if a custom deployment is used. It uses a passed in // authenticator to authenticate the connection. -func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { +func NewAuthenticatorConfig(authenticator driver.Authenticator, clientOpts ...AuthConfigOption) (*Config, error) { + settings := authConfigOptions{} + for _, apply := range clientOpts { + apply(&settings) + } + + opts := settings.opts + clock := settings.clock + var serverAPI *driver.ServerAPIOptions if err := opts.Validate(); err != nil { @@ -200,23 +242,8 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s })) } - var outerLibraryName, outerLibraryVersion, outerLibraryPlatform string - if opts.DriverInfo != nil { - outerLibraryName = opts.DriverInfo.Name - outerLibraryVersion = opts.DriverInfo.Version - outerLibraryPlatform = opts.DriverInfo.Platform - - serverOpts = append(serverOpts, WithOuterLibraryName(func(string) string { - return outerLibraryName - })) - - serverOpts = append(serverOpts, WithOuterLibraryVersion(func(string) string { - return outerLibraryVersion - })) - - serverOpts = append(serverOpts, WithOuterLibraryPlatform(func(string) string { - return outerLibraryPlatform - })) + if settings.driverInfo != nil { + serverOpts = append(serverOpts, WithDriverInfo(settings.driverInfo)) } // Compressors & ZlibLevel @@ -257,15 +284,18 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s var handshaker func(driver.Handshaker) driver.Handshaker if authenticator != nil { handshakeOpts := &auth.HandshakeOptions{ - AppName: appName, - Authenticator: authenticator, - Compressors: comps, - ServerAPI: serverAPI, - LoadBalanced: loadBalanced, - ClusterClock: clock, - OuterLibraryName: outerLibraryName, - OuterLibraryVersion: outerLibraryVersion, - OuterLibraryPlatform: outerLibraryPlatform, + AppName: appName, + Authenticator: authenticator, + Compressors: comps, + ServerAPI: serverAPI, + LoadBalanced: loadBalanced, + ClusterClock: clock, + } + + if driverInfo := settings.driverInfo; driverInfo != nil && driverInfo.Load() != nil { + handshakeOpts.OuterLibraryName = driverInfo.Load().Name + handshakeOpts.OuterLibraryVersion = driverInfo.Load().Version + handshakeOpts.OuterLibraryPlatform = driverInfo.Load().Platform } if opts.Auth.AuthMechanism == "" { @@ -287,6 +317,13 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s } else { handshaker = func(driver.Handshaker) driver.Handshaker { + var outerLibraryName, outerLibraryVersion, outerLibraryPlatform string + if driverInfo := settings.driverInfo; driverInfo != nil && driverInfo.Load() != nil { + outerLibraryName = driverInfo.Load().Name + outerLibraryVersion = driverInfo.Load().Version + outerLibraryPlatform = driverInfo.Load().Platform + } + return operation.NewHello(). AppName(appName). Compressors(comps). diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go index 680aa638a7..402503b300 100644 --- a/x/mongo/driver/topology/topology_options_test.go +++ b/x/mongo/driver/topology/topology_options_test.go @@ -149,7 +149,7 @@ func TestAuthenticateToAnything(t *testing.T) { opt := options.Client().SetAuth(options.Credential{Username: "foo", Password: "bar"}) err := tc.set(opt) require.NoError(t, err, "error setting authenticateToAnything: %v", err) - cfg, err := NewConfigFromOptionsWithAuthenticator(opt, nil, &testAuthenticator{}) + cfg, err := NewAuthenticatorConfig(nil, WithAuthConfigClientOptions(opt)) require.NoError(t, err, "error constructing topology config: %v", err) srvrCfg := newServerConfig(defaultConnectionTimeout, cfg.ServerOpts...)