diff --git a/internal/options/options.go b/internal/optionsutil/options.go similarity index 98% rename from internal/options/options.go rename to internal/optionsutil/options.go index 8d5f47f422..5e7527c99b 100644 --- a/internal/options/options.go +++ b/internal/optionsutil/options.go @@ -4,7 +4,7 @@ // not use this file except in compliance with the License. You may obtain // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 -package options +package optionsutil // Options stores internal options. type Options struct { diff --git a/mongo/options/clientoptions.go b/mongo/options/clientoptions.go index 649d6a8e3d..6cfb3dc2f1 100644 --- a/mongo/options/clientoptions.go +++ b/mongo/options/clientoptions.go @@ -26,7 +26,7 @@ import ( "go.mongodb.org/mongo-driver/v2/bson" "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/httputil" - "go.mongodb.org/mongo-driver/v2/internal/options" + "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" "go.mongodb.org/mongo-driver/v2/mongo/writeconcern" @@ -302,7 +302,7 @@ type ClientOptions struct { // // Deprecated: This option is for internal use only and should not be set. It may be changed or removed in any // release. - Custom options.Options + Custom optionsutil.Options connString *connstring.ConnString err error diff --git a/mongo/options/clientoptions_test.go b/mongo/options/clientoptions_test.go index a8a00122e0..2a2471adb5 100644 --- a/mongo/options/clientoptions_test.go +++ b/mongo/options/clientoptions_test.go @@ -27,7 +27,7 @@ import ( "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/assert" "go.mongodb.org/mongo-driver/v2/internal/httputil" - "go.mongodb.org/mongo-driver/v2/internal/options" + "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/internal/ptrutil" "go.mongodb.org/mongo-driver/v2/mongo/readconcern" "go.mongodb.org/mongo-driver/v2/mongo/readpref" @@ -157,7 +157,7 @@ func TestClientOptions(t *testing.T) { cmp.Comparer(func(r1, r2 *bson.Registry) bool { return r1 == r2 }), cmp.Comparer(func(cfg1, cfg2 *tls.Config) bool { return cfg1 == cfg2 }), cmp.Comparer(func(fp1, fp2 *event.PoolMonitor) bool { return fp1 == fp2 }), - cmp.Comparer(options.Equal), + cmp.Comparer(optionsutil.Equal), cmp.AllowUnexported(ClientOptions{}), cmpopts.IgnoreFields(http.Client{}, "Transport"), ); diff != "" { @@ -1255,7 +1255,7 @@ func TestApplyURI(t *testing.T) { cmp.Comparer(func(r1, r2 *bson.Registry) bool { return r1 == r2 }), cmp.Comparer(compareTLSConfig), cmp.Comparer(compareErrors), - cmp.Comparer(options.Equal), + cmp.Comparer(optionsutil.Equal), cmpopts.SortSlices(stringLess), cmpopts.IgnoreFields(connstring.ConnString{}, "SSLClientCertificateKeyPassword"), cmpopts.IgnoreFields(http.Client{}, "Transport"), diff --git a/x/mongo/driver/auth/auth_test.go b/x/mongo/driver/auth/auth_test.go index 4736144e59..2edf872a4d 100644 --- a/x/mongo/driver/auth/auth_test.go +++ b/x/mongo/driver/auth/auth_test.go @@ -7,13 +7,19 @@ package auth_test import ( + "context" + "fmt" "net/http" "testing" "github.com/google/go-cmp/cmp" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage" ) @@ -101,3 +107,56 @@ func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload) } } + +type testAuthenticator struct{} + +func (a *testAuthenticator) Auth(context.Context, *driver.AuthConfig) error { + return fmt.Errorf("test error") +} + +func (a *testAuthenticator) Reauth(context.Context, *driver.AuthConfig) error { + return nil +} + +func TestPerformAuthentication(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + authenticateToAnything bool + require func(*testing.T, error) + }{ + { + name: "positive", + authenticateToAnything: true, + require: func(t *testing.T, err error) { + require.EqualError(t, err, "auth error: test error") + }, + }, + { + name: "negative", + authenticateToAnything: false, + require: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + } + mnetconn := mnet.NewConnection(&drivertest.ChannelConn{}) + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + handshaker := auth.Handshaker(nil, &auth.HandshakeOptions{ + Authenticator: &testAuthenticator{}, + PerformAuthentication: func(description.Server) bool { + return tc.authenticateToAnything + }, + }) + + err := handshaker.FinishHandshake(context.Background(), mnetconn) + tc.require(t, err) + }) + } +} diff --git a/x/mongo/driver/topology/topology_options.go b/x/mongo/driver/topology/topology_options.go index aefa74c56a..2ddc7434bd 100644 --- a/x/mongo/driver/topology/topology_options.go +++ b/x/mongo/driver/topology/topology_options.go @@ -15,9 +15,11 @@ import ( "go.mongodb.org/mongo-driver/v2/event" "go.mongodb.org/mongo-driver/v2/internal/logger" + "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/ocsp" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation" "go.mongodb.org/mongo-driver/v2/x/mongo/driver/session" @@ -270,6 +272,14 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s // Required for SASL mechanism negotiation during handshake handshakeOpts.DBUser = opts.Auth.AuthSource + "." + opts.Auth.Username } + if a := optionsutil.Value(opts.Custom, "authenticateToAnything"); a != nil { + if v, ok := a.(bool); ok && v { + // Authenticate arbiters + handshakeOpts.PerformAuthentication = func(_ description.Server) bool { + return true + } + } + } handshaker = func(driver.Handshaker) driver.Handshaker { return auth.Handshaker(nil, handshakeOpts) diff --git a/x/mongo/driver/topology/topology_options_test.go b/x/mongo/driver/topology/topology_options_test.go index 759ab9aa4a..680aa638a7 100644 --- a/x/mongo/driver/topology/topology_options_test.go +++ b/x/mongo/driver/topology/topology_options_test.go @@ -7,6 +7,7 @@ package topology import ( + "context" "fmt" "net/url" "reflect" @@ -17,6 +18,10 @@ import ( "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/description" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet" + "go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions" ) func TestDirectConnectionFromConnString(t *testing.T) { @@ -85,6 +90,76 @@ func TestLoadBalancedFromConnString(t *testing.T) { } } +type testAuthenticator struct{} + +var _ driver.Authenticator = &testAuthenticator{} + +func (a *testAuthenticator) Auth(context.Context, *driver.AuthConfig) error { + return fmt.Errorf("test error") +} + +func (a *testAuthenticator) Reauth(context.Context, *driver.AuthConfig) error { + return nil +} + +func TestAuthenticateToAnything(t *testing.T) { + t.Parallel() + + cases := []struct { + name string + set func(*options.ClientOptions) error + require func(*testing.T, error) + }{ + { + name: "default", + set: func(*options.ClientOptions) error { return nil }, + require: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + { + name: "positive", + set: func(opt *options.ClientOptions) error { + return xoptions.SetInternalClientOptions(opt, "authenticateToAnything", true) + }, + require: func(t *testing.T, err error) { + require.EqualError(t, err, "auth error: test error") + }, + }, + { + name: "negative", + set: func(opt *options.ClientOptions) error { + return xoptions.SetInternalClientOptions(opt, "authenticateToAnything", false) + }, + require: func(t *testing.T, err error) { + require.NoError(t, err) + }, + }, + } + + describer := &drivertest.ChannelConn{ + Desc: description.Server{Kind: description.ServerKindRSArbiter}, + } + for _, tc := range cases { + tc := tc + + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + opt := options.Client().SetAuth(options.Credential{Username: "foo", Password: "bar"}) + err := tc.set(opt) + require.NoError(t, err, "error setting authenticateToAnything: %v", err) + cfg, err := NewConfigFromOptionsWithAuthenticator(opt, nil, &testAuthenticator{}) + require.NoError(t, err, "error constructing topology config: %v", err) + + srvrCfg := newServerConfig(defaultConnectionTimeout, cfg.ServerOpts...) + connCfg := newConnectionConfig(srvrCfg.connectionOpts...) + err = connCfg.handshaker.FinishHandshake(context.TODO(), &mnet.Connection{Describer: describer}) + tc.require(t, err) + }) + } +} + func TestTopologyNewConfig(t *testing.T) { t.Run("default ServerSelectionTimeout", func(t *testing.T) { cfg, err := NewConfig(options.Client(), nil) diff --git a/x/mongo/driver/xoptions/options.go b/x/mongo/driver/xoptions/options.go index 6eb1bd0dc2..68e28e7cd8 100644 --- a/x/mongo/driver/xoptions/options.go +++ b/x/mongo/driver/xoptions/options.go @@ -9,6 +9,7 @@ package xoptions import ( "fmt" + "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" ) @@ -17,20 +18,28 @@ import ( // // Deprecated: This function is for internal use only. It may be changed or removed in any release. func SetInternalClientOptions(opts *options.ClientOptions, key string, option any) error { - const typeErr = "unexpected type for %s" + typeErrFunc := func(t string) error { + return fmt.Errorf("unexpected type for %s: %T is not %s", key, option, t) + } switch key { case "crypt": c, ok := option.(driver.Crypt) if !ok { - return fmt.Errorf(typeErr, key) + return typeErrFunc("driver.Crypt") } opts.Crypt = c case "deployment": d, ok := option.(driver.Deployment) if !ok { - return fmt.Errorf(typeErr, key) + return typeErrFunc("driver.Deployment") } opts.Deployment = d + case "authenticateToAnything": + b, ok := option.(bool) + if !ok { + return typeErrFunc("bool") + } + opts.Custom = optionsutil.WithValue(opts.Custom, key, b) default: return fmt.Errorf("unsupported option: %s", key) } diff --git a/x/mongo/driver/xoptions/options_test.go b/x/mongo/driver/xoptions/options_test.go index b459ec8ada..284fe914a1 100644 --- a/x/mongo/driver/xoptions/options_test.go +++ b/x/mongo/driver/xoptions/options_test.go @@ -7,8 +7,10 @@ package xoptions import ( + "fmt" "testing" + "go.mongodb.org/mongo-driver/v2/internal/optionsutil" "go.mongodb.org/mongo-driver/v2/internal/require" "go.mongodb.org/mongo-driver/v2/mongo/options" "go.mongodb.org/mongo-driver/v2/x/mongo/driver" @@ -18,6 +20,29 @@ import ( func TestSetInternalClientOptions(t *testing.T) { t.Parallel() + cases := []struct { + key string + value any + }{ + { + key: "authenticateToAnything", + value: true, + }, + } + for _, tc := range cases { + tc := tc + + t.Run(fmt.Sprintf("set %s", tc.key), func(t *testing.T) { + t.Parallel() + + opts := options.Client() + err := SetInternalClientOptions(opts, tc.key, tc.value) + require.NoError(t, err, "error setting %s: %v", tc.key, err) + v := optionsutil.Value(opts.Custom, tc.key) + require.Equal(t, tc.value, v, "expected %v, got %v", tc.value, v) + }) + } + t.Run("set crypt", func(t *testing.T) { t.Parallel() @@ -33,7 +58,7 @@ func TestSetInternalClientOptions(t *testing.T) { opts := options.Client() err := SetInternalClientOptions(opts, "crypt", &drivertest.MockDeployment{}) - require.EqualError(t, err, "unexpected type for crypt") + require.EqualError(t, err, "unexpected type for crypt: *drivertest.MockDeployment is not driver.Crypt") }) t.Run("set deployment", func(t *testing.T) { @@ -51,7 +76,7 @@ func TestSetInternalClientOptions(t *testing.T) { opts := options.Client() err := SetInternalClientOptions(opts, "deployment", driver.NewCrypt(&driver.CryptOptions{})) - require.EqualError(t, err, "unexpected type for deployment") + require.EqualError(t, err, "unexpected type for deployment: *driver.crypt is not driver.Deployment") }) t.Run("set unsupported option", func(t *testing.T) {