diff --git a/mongo/client.go b/mongo/client.go index 04ebcb4eb2..c9859eff23 100644 --- a/mongo/client.go +++ b/mongo/client.go @@ -215,34 +215,13 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { } if args.Auth != nil { - var oidcMachineCallback auth.OIDCCallback - if args.Auth.OIDCMachineCallback != nil { - oidcMachineCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) { - cred, err := args.Auth.OIDCMachineCallback(ctx, convertOIDCArgs(oargs)) - return (*driver.OIDCCredential)(cred), err - } - } - - var oidcHumanCallback auth.OIDCCallback - if args.Auth.OIDCHumanCallback != nil { - oidcHumanCallback = func(ctx context.Context, oargs *driver.OIDCArgs) (*driver.OIDCCredential, error) { - cred, err := args.Auth.OIDCHumanCallback(ctx, convertOIDCArgs(oargs)) - return (*driver.OIDCCredential)(cred), err - } - } - - // Create an authenticator for the client - client.authenticator, err = auth.CreateAuthenticator(args.Auth.AuthMechanism, &auth.Cred{ - Source: args.Auth.AuthSource, - Username: args.Auth.Username, - Password: args.Auth.Password, - PasswordSet: args.Auth.PasswordSet, - Props: args.Auth.AuthMechanismProperties, - OIDCMachineCallback: oidcMachineCallback, - OIDCHumanCallback: oidcHumanCallback, - }, args.HTTPClient) + client.authenticator, err = auth.CreateAuthenticator( + args.Auth.AuthMechanism, + topology.ConvertCreds(args.Auth), + args.HTTPClient, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating authenticator: %w", err) } } @@ -274,20 +253,7 @@ func newClient(opts ...options.Lister[options.ClientOptions]) (*Client, error) { return client, nil } -// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent -// public type *options.OIDCArgs. -func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { - if args == nil { - return nil - } - return &options.OIDCArgs{ - Version: args.Version, - IDPInfo: (*options.IDPInfo)(args.IDPInfo), - RefreshToken: args.RefreshToken, - } -} - -// connect initializes the Client by starting background monitoring goroutines. +// Connect initializes the Client by starting background monitoring goroutines. // If the Client was created using the NewClient function, this method must be called before a Client can be used. // // Connect starts background goroutines to monitor the state of the deployment and does not do any I/O in the main diff --git a/mongo/client_test.go b/mongo/client_test.go index ee56449ce6..72e3ee0962 100644 --- a/mongo/client_test.go +++ b/mongo/client_test.go @@ -11,7 +11,6 @@ import ( "errors" "math" "os" - "reflect" "testing" "time" @@ -20,13 +19,11 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/integtest" "go.mongodb.org/mongo-driver/v2/internal/mongoutil" - "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" "go.mongodb.org/mongo-driver/v2/tag" - "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mongocrypt" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/topology" @@ -519,76 +516,3 @@ func TestClient(t *testing.T) { assert.Equal(t, errmsg, err.Error(), "expected error %v, got %v", errmsg, err.Error()) }) } - -// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs -// into an options.OIDCArgs. -func TestConvertOIDCArgs(t *testing.T) { - refreshToken := "test refresh token" - - testCases := []struct { - desc string - args *driver.OIDCArgs - }{ - { - desc: "populated args", - args: &driver.OIDCArgs{ - Version: 9, - IDPInfo: &driver.IDPInfo{ - Issuer: "test issuer", - ClientID: "test client ID", - RequestScopes: []string{"test scope 1", "test scope 2"}, - }, - RefreshToken: &refreshToken, - }, - }, - { - desc: "nil", - args: nil, - }, - { - desc: "nil IDPInfo and RefreshToken", - args: &driver.OIDCArgs{ - Version: 9, - IDPInfo: nil, - RefreshToken: nil, - }, - }, - } - - for _, tc := range testCases { - tc := tc // Capture range variable. - - t.Run(tc.desc, func(t *testing.T) { - t.Parallel() - - got := convertOIDCArgs(tc.args) - - if tc.args == nil { - assert.Nil(t, got, "expected nil when input is nil") - return - } - - require.Equal(t, - 3, - reflect.ValueOf(*tc.args).NumField(), - "expected the driver.OIDCArgs struct to have exactly 3 fields") - require.Equal(t, - 3, - reflect.ValueOf(*got).NumField(), - "expected the options.OIDCArgs struct to have exactly 3 fields") - - assert.Equal(t, - tc.args.Version, - got.Version, - "expected Version field to be equal") - assert.EqualValues(t, - tc.args.IDPInfo, - got.IDPInfo, - "expected IDPInfo field to be convertible to equal values") - assert.Equal(t, - tc.args.RefreshToken, - got.RefreshToken, - "expected RefreshToken field to be equal") - }) - } -} diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 9fe4e36c1f..40db452be3 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -90,9 +90,9 @@ type ContextDialer interface { // The SERVICE_HOST and CANONICALIZE_HOST_NAME properties must not be used at the same time on Linux and Darwin // systems. // -// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-X509, -// GSSAPI, and PLAIN and "admin" for all other mechanisms. This can also be set through the "authSource" URI option -// (e.g. "authSource=otherDb"). +// AuthSource: the name of the database to use for authentication. This defaults to "$external" for MONGODB-AWS, +// MONGODB-OIDC, MONGODB-X509, GSSAPI, and PLAIN. It defaults to "admin" for all other auth mechanisms. This can +// also be set through the "authSource" URI option (e.g. "authSource=otherDb"). // // Username: the username for authentication. This can also be set through the URI as a username:password pair before // the first @ character. For example, a URI for user "user", password "pwd", and host "localhost:27017" would be diff --git a/x/mongo/driver/auth/auth.go b/x/mongo/driver/auth/auth.go index 240fc22e3d..f5e4ee87f9 100644 --- a/x/mongo/driver/auth/auth.go +++ b/x/mongo/driver/auth/auth.go @@ -20,6 +20,8 @@ import ( "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" ) +const sourceExternal = "$external" + // AuthenticatorFactory constructs an authenticator. type AuthenticatorFactory func(*Cred, *http.Client) (Authenticator, error) diff --git a/x/mongo/driver/auth/gssapi.go b/x/mongo/driver/auth/gssapi.go index 9907eb4db4..0ae7571d23 100644 --- a/x/mongo/driver/auth/gssapi.go +++ b/x/mongo/driver/auth/gssapi.go @@ -24,7 +24,7 @@ import ( const GSSAPI = "GSSAPI" func newGSSAPIAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { - if cred.Source != "" && cred.Source != "$external" { + if cred.Source != "" && cred.Source != sourceExternal { return nil, newAuthError("GSSAPI source must be empty or $external", nil) } @@ -57,7 +57,7 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) if err != nil { return newAuthError("error creating gssapi", err) } - return ConductSaslConversation(ctx, cfg, "$external", client) + return ConductSaslConversation(ctx, cfg, sourceExternal, client) } // Reauth reauthenticates the connection. diff --git a/x/mongo/driver/auth/mongodbaws.go b/x/mongo/driver/auth/mongodbaws.go index 548fb9c92b..dd9661e1a9 100644 --- a/x/mongo/driver/auth/mongodbaws.go +++ b/x/mongo/driver/auth/mongodbaws.go @@ -21,17 +21,15 @@ import ( const MongoDBAWS = "MONGODB-AWS" func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { - if cred.Source != "" && cred.Source != "$external" { + if cred.Source != "" && cred.Source != sourceExternal { return nil, newAuthError("MONGODB-AWS source must be empty or $external", nil) } if httpClient == nil { return nil, errors.New("httpClient must not be nil") } return &MongoDBAWSAuthenticator{ - source: cred.Source, credentials: &credproviders.StaticProvider{ Value: credentials.Value{ - ProviderName: cred.Source, AccessKeyID: cred.Username, SecretAccessKey: cred.Password, SessionToken: cred.Props["AWS_SESSION_TOKEN"], @@ -43,7 +41,6 @@ func newMongoDBAWSAuthenticator(cred *Cred, httpClient *http.Client) (Authentica // MongoDBAWSAuthenticator uses AWS-IAM credentials over SASL to authenticate a connection. type MongoDBAWSAuthenticator struct { - source string credentials *credproviders.StaticProvider httpClient *http.Client } @@ -56,7 +53,7 @@ func (a *MongoDBAWSAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConf credentials: providers.Cred, }, } - err := ConductSaslConversation(ctx, cfg, a.source, adapter) + err := ConductSaslConversation(ctx, cfg, sourceExternal, adapter) if err != nil { return newAuthError("sasl conversation error", err) } diff --git a/x/mongo/driver/auth/mongodbcr.go b/x/mongo/driver/auth/mongodbcr.go index 55ec36fa7d..f8c1466df7 100644 --- a/x/mongo/driver/auth/mongodbcr.go +++ b/x/mongo/driver/auth/mongodbcr.go @@ -30,8 +30,12 @@ import ( const MONGODBCR = "MONGODB-CR" func newMongoDBCRAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "admin" + } return &MongoDBCRAuthenticator{ - DB: cred.Source, + DB: source, Username: cred.Username, Password: cred.Password, }, nil diff --git a/x/mongo/driver/auth/oidc.go b/x/mongo/driver/auth/oidc.go index 953f9f6ef7..c476ac86c4 100644 --- a/x/mongo/driver/auth/oidc.go +++ b/x/mongo/driver/auth/oidc.go @@ -110,6 +110,9 @@ func (oa *OIDCAuthenticator) SetAccessToken(accessToken string) { } func newOIDCAuthenticator(cred *Cred, httpClient *http.Client) (Authenticator, error) { + if cred.Source != "" && cred.Source != sourceExternal { + return nil, newAuthError("MONGODB-OIDC source must be empty or $external", nil) + } if cred.Password != "" { return nil, fmt.Errorf("password cannot be specified for %q", MongoDBOIDC) } @@ -446,7 +449,7 @@ func (oa *OIDCAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) e oa.mu.Unlock() if cachedAccessToken != "" { - err = ConductSaslConversation(ctx, cfg, "$external", &oidcOneStep{ + err = ConductSaslConversation(ctx, cfg, sourceExternal, &oidcOneStep{ userName: oa.userName, accessToken: cachedAccessToken, }) @@ -506,7 +509,7 @@ func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthCo return ConductSaslConversation( subCtx, cfg, - "$external", + sourceExternal, &oidcOneStep{accessToken: accessToken}, ) } @@ -515,7 +518,7 @@ func (oa *OIDCAuthenticator) doAuthHuman(ctx context.Context, cfg *driver.AuthCo conn: cfg.Connection, oa: oa, } - return ConductSaslConversation(subCtx, cfg, "$external", ots) + return ConductSaslConversation(subCtx, cfg, sourceExternal, ots) } func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.AuthConfig, machineCallback OIDCCallback) error { @@ -536,7 +539,7 @@ func (oa *OIDCAuthenticator) doAuthMachine(ctx context.Context, cfg *driver.Auth return ConductSaslConversation( ctx, cfg, - "$external", + sourceExternal, &oidcOneStep{accessToken: accessToken}, ) } @@ -550,5 +553,5 @@ func (oa *OIDCAuthenticator) CreateSpeculativeConversation() (SpeculativeConvers return nil, nil // Skip speculative auth. } - return newSaslConversation(&oidcOneStep{accessToken: accessToken}, "$external", true), nil + return newSaslConversation(&oidcOneStep{accessToken: accessToken}, sourceExternal, true), nil } diff --git a/x/mongo/driver/auth/plain.go b/x/mongo/driver/auth/plain.go index 43a992c339..69342bb1e5 100644 --- a/x/mongo/driver/auth/plain.go +++ b/x/mongo/driver/auth/plain.go @@ -17,6 +17,21 @@ import ( const PLAIN = "PLAIN" func newPlainAuthenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + // TODO(GODRIVER-3317): The PLAIN specification says about auth source: + // + // "MUST be specified. Defaults to the database name if supplied on the + // connection string or $external." + // + // We should actually pass through the auth source, not always pass + // $external. If it's empty, we should default to $external. + // + // For example: + // + // source := cred.Source + // if source == "" { + // source = "$external" + // } + // return &PlainAuthenticator{ Username: cred.Username, Password: cred.Password, @@ -31,7 +46,7 @@ type PlainAuthenticator struct { // Auth authenticates the connection. func (a *PlainAuthenticator) Auth(ctx context.Context, cfg *driver.AuthConfig) error { - return ConductSaslConversation(ctx, cfg, "$external", &plainSaslClient{ + return ConductSaslConversation(ctx, cfg, sourceExternal, &plainSaslClient{ username: a.Username, password: a.Password, }) diff --git a/x/mongo/driver/auth/plain_test.go b/x/mongo/driver/auth/plain_test.go index c83f0d7de3..251769caf7 100644 --- a/x/mongo/driver/auth/plain_test.go +++ b/x/mongo/driver/auth/plain_test.go @@ -8,11 +8,10 @@ package auth_test import ( "context" + "encoding/base64" "strings" "testing" - "encoding/base64" - "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" diff --git a/x/mongo/driver/auth/scram.go b/x/mongo/driver/auth/scram.go index f0a4c4af16..2896a8facc 100644 --- a/x/mongo/driver/auth/scram.go +++ b/x/mongo/driver/auth/scram.go @@ -38,6 +38,10 @@ var ( ) func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "admin" + } passdigest := mongoPasswordDigest(cred.Username, cred.Password) client, err := scram.SHA1.NewClientUnprepped(cred.Username, passdigest, "") if err != nil { @@ -46,12 +50,16 @@ func newScramSHA1Authenticator(cred *Cred, _ *http.Client) (Authenticator, error client.WithMinIterations(4096) return &ScramAuthenticator{ mechanism: SCRAMSHA1, - source: cred.Source, + source: source, client: client, }, nil } func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + source := cred.Source + if source == "" { + source = "admin" + } passprep, err := stringprep.SASLprep.Prepare(cred.Password) if err != nil { return nil, newAuthError("error SASLprepping password", err) @@ -63,7 +71,7 @@ func newScramSHA256Authenticator(cred *Cred, _ *http.Client) (Authenticator, err client.WithMinIterations(4096) return &ScramAuthenticator{ mechanism: SCRAMSHA256, - source: cred.Source, + source: source, client: client, }, nil } diff --git a/x/mongo/driver/auth/x509.go b/x/mongo/driver/auth/x509.go index 608b13dda8..f839023435 100644 --- a/x/mongo/driver/auth/x509.go +++ b/x/mongo/driver/auth/x509.go @@ -19,6 +19,9 @@ import ( const MongoDBX509 = "MONGODB-X509" func newMongoDBX509Authenticator(cred *Cred, _ *http.Client) (Authenticator, error) { + // TODO(GODRIVER-3309): Validate that cred.Source is either empty or + // "$external" to make validation uniform with other auth mechanisms that + // require Source to be "$external" (e.g. MONGODB-AWS, MONGODB-OIDC, etc). return &MongoDBX509Authenticator{User: cred.Username}, nil } @@ -66,7 +69,7 @@ func (a *MongoDBX509Authenticator) Auth(ctx context.Context, cfg *driver.AuthCon requestDoc := createFirstX509Message() authCmd := operation. NewCommand(requestDoc). - Database("$external"). + Database(sourceExternal). Deployment(driver.SingleConnectionDeployment{cfg.Connection}). ClusterClock(cfg.ClusterClock). ServerAPI(cfg.ServerAPI) diff --git a/x/mongo/driver/connstring/connstring.go b/x/mongo/driver/connstring/connstring.go index d550a68127..ece143a1e4 100644 --- a/x/mongo/driver/connstring/connstring.go +++ b/x/mongo/driver/connstring/connstring.go @@ -291,7 +291,7 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthMechanismProperties["SERVICE_NAME"] = "mongodb" } fallthrough - case "mongodb-aws", "mongodb-x509": + case "mongodb-aws", "mongodb-x509", "mongodb-oidc": if u.AuthSource == "" { u.AuthSource = "$external" } else if u.AuthSource != "$external" { @@ -308,13 +308,6 @@ func (u *ConnString) setDefaultAuthParams(dbName string) error { u.AuthSource = "admin" } } - case "mongodb-oidc": - if u.AuthSource == "" { - u.AuthSource = dbName - if u.AuthSource == "" { - u.AuthSource = "$external" - } - } case "": // Only set auth source if there is a request for authentication via non-empty credentials. if u.AuthSource == "" && (u.AuthMechanismProperties != nil || u.Username != "" || u.PasswordSet) { diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index 2dbfb55673..d98b47d5ef 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -7,10 +7,10 @@ package topology import ( + "context" "crypto/tls" "fmt" "net/http" - "strings" "time" "go.mongodb.org/mongo-driver/v2/event" @@ -81,6 +81,53 @@ func newLogger(opts options.Lister[options.LoggerOptions]) (*logger.Logger, erro return log, nil } +// convertOIDCArgs converts the internal *driver.OIDCArgs into the equivalent +// public type *options.OIDCArgs. +func convertOIDCArgs(args *driver.OIDCArgs) *options.OIDCArgs { + if args == nil { + return nil + } + return &options.OIDCArgs{ + Version: args.Version, + IDPInfo: (*options.IDPInfo)(args.IDPInfo), + RefreshToken: args.RefreshToken, + } +} + +// ConvertCreds takes an [options.Credential] and returns the equivalent +// [driver.Cred]. +func ConvertCreds(cred *options.Credential) *driver.Cred { + if cred == nil { + return nil + } + + var oidcMachineCallback auth.OIDCCallback + if cred.OIDCMachineCallback != nil { + oidcMachineCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := cred.OIDCMachineCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + var oidcHumanCallback auth.OIDCCallback + if cred.OIDCHumanCallback != nil { + oidcHumanCallback = func(ctx context.Context, args *driver.OIDCArgs) (*driver.OIDCCredential, error) { + cred, err := cred.OIDCHumanCallback(ctx, convertOIDCArgs(args)) + return (*driver.OIDCCredential)(cred), err + } + } + + return &auth.Cred{ + Source: cred.AuthSource, + Username: cred.Username, + Password: cred.Password, + PasswordSet: cred.PasswordSet, + Props: cred.AuthMechanismProperties, + OIDCMachineCallback: oidcMachineCallback, + OIDCHumanCallback: oidcHumanCallback, + } +} + // NewConfig behaves like NewConfigFromOptions by extracting arguments from a // list of ClientOptions setters. func NewConfig(opts *options.ClientOptionsBuilder, clock *session.ClusterClock) (*Config, error) { @@ -96,27 +143,24 @@ func NewConfig(opts *options.ClientOptionsBuilder, clock *session.ClusterClock) // config for building non-default deployments. Server and topology options are // not honored if a custom deployment is used. func NewConfigFromOptions(opts *options.ClientOptions, clock *session.ClusterClock) (*Config, error) { - // Auth & Database & Password & Username + var authenticator driver.Authenticator + var err error if opts.Auth != nil { - cred := &auth.Cred{ - Username: opts.Auth.Username, - Password: opts.Auth.Password, - PasswordSet: opts.Auth.PasswordSet, - Props: opts.Auth.AuthMechanismProperties, - Source: opts.Auth.AuthSource, - } - mechanism := opts.Auth.AuthMechanism - authenticator, err := auth.CreateAuthenticator(mechanism, cred, opts.HTTPClient) + authenticator, err = auth.CreateAuthenticator( + opts.Auth.AuthMechanism, + ConvertCreds(opts.Auth), + opts.HTTPClient, + ) if err != nil { - return nil, err + return nil, fmt.Errorf("error creating authenticator: %w", err) } - return NewConfigFromOptionsWithAuthenticator(opts, clock, authenticator) } - return NewConfigFromOptionsWithAuthenticator(opts, clock, nil) + return NewConfigFromOptionsWithAuthenticator(opts, clock, authenticator) } -// NewConfigFromOptionsWithAuthenticator will translate data from client options into a topology config for building non-default deployments. -// Server and topology options are not honored if a custom deployment is used. It uses a passed in +// NewConfigFromOptionsWithAuthenticator will translate data from client options into a +// topology config for building non-default deployments. Server and topology +// options are not honored if a custom deployment is used. It uses a passed in // authenticator to authenticate the connection. func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *session.ClusterClock, authenticator driver.Authenticator) (*Config, error) { @@ -217,30 +261,8 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s } // Handshaker - var handshaker = func(driver.Handshaker) driver.Handshaker { - return operation.NewHello().AppName(appName).Compressors(comps).ClusterClock(clock). - ServerAPI(serverAPI).LoadBalanced(loadBalanced) - } - // Auth & Database & Password & Username - if opts.Auth != nil { - cred := &auth.Cred{ - Username: opts.Auth.Username, - Password: opts.Auth.Password, - PasswordSet: opts.Auth.PasswordSet, - Props: opts.Auth.AuthMechanismProperties, - Source: opts.Auth.AuthSource, - } - mechanism := opts.Auth.AuthMechanism - - if len(cred.Source) == 0 { - switch strings.ToUpper(mechanism) { - case auth.MongoDBX509, auth.GSSAPI, auth.PLAIN: - cred.Source = "$external" - default: - cred.Source = "admin" - } - } - + var handshaker func(driver.Handshaker) driver.Handshaker + if authenticator != nil { handshakeOpts := &auth.HandshakeOptions{ AppName: appName, Authenticator: authenticator, @@ -250,15 +272,26 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s ClusterClock: clock, } - if mechanism == "" { + if opts.Auth.AuthMechanism == "" { // Required for SASL mechanism negotiation during handshake - handshakeOpts.DBUser = cred.Source + "." + cred.Username + handshakeOpts.DBUser = opts.Auth.AuthSource + "." + opts.Auth.Username } handshaker = func(driver.Handshaker) driver.Handshaker { return auth.Handshaker(nil, handshakeOpts) } + + } else { + handshaker = func(driver.Handshaker) driver.Handshaker { + return operation.NewHello(). + AppName(appName). + Compressors(comps). + ClusterClock(clock). + ServerAPI(serverAPI). + LoadBalanced(loadBalanced) + } } + connOpts = append(connOpts, WithHandshaker(handshaker)) // Dialer diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go index e31adff87c..759ab9aa4a 100644 --- a/x/mongo/driver/topology/topology_options_test.go +++ b/x/mongo/driver/topology/topology_options_test.go @@ -9,11 +9,14 @@ package topology import ( "fmt" "net/url" + "reflect" "testing" "time" "go.mongodb.org/mongo-driver/v2/internal/assert" + "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) func TestDirectConnectionFromConnString(t *testing.T) { @@ -104,3 +107,77 @@ func TestTopologyNewConfig(t *testing.T) { assert.Equal(t, []string{"localhost:27018"}, cfg.SeedList) }) } + +// Test that convertOIDCArgs exhaustively copies all fields of a driver.OIDCArgs +// into an options.OIDCArgs. +func TestConvertOIDCArgs(t *testing.T) { + t.Parallel() + refreshToken := "test refresh token" + + testCases := []struct { + desc string + args *driver.OIDCArgs + }{ + { + desc: "populated args", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: &driver.IDPInfo{ + Issuer: "test issuer", + ClientID: "test client ID", + RequestScopes: []string{"test scope 1", "test scope 2"}, + }, + RefreshToken: &refreshToken, + }, + }, + { + desc: "nil", + args: nil, + }, + { + desc: "nil IDPInfo and RefreshToken", + args: &driver.OIDCArgs{ + Version: 9, + IDPInfo: nil, + RefreshToken: nil, + }, + }, + } + + for _, tc := range testCases { + tc := tc // Capture range variable. + + t.Run(tc.desc, func(t *testing.T) { + t.Parallel() + + got := convertOIDCArgs(tc.args) + + if tc.args == nil { + assert.Nil(t, got, "expected nil when input is nil") + return + } + + require.Equal(t, + 3, + reflect.ValueOf(*tc.args).NumField(), + "expected the driver.OIDCArgs struct to have exactly 3 fields") + require.Equal(t, + 3, + reflect.ValueOf(*got).NumField(), + "expected the options.OIDCArgs struct to have exactly 3 fields") + + assert.Equal(t, + tc.args.Version, + got.Version, + "expected Version field to be equal") + assert.EqualValues(t, + tc.args.IDPInfo, + got.IDPInfo, + "expected IDPInfo field to be convertible to equal values") + assert.Equal(t, + tc.args.RefreshToken, + got.RefreshToken, + "expected RefreshToken field to be equal") + }) + } +}