Skip to content
59 changes: 59 additions & 0 deletions x/mongo/driver/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,19 @@
package auth_test

import (
"context"
"fmt"
"net/http"
"testing"

"github.com/google/go-cmp/cmp"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/wiremessage"
)

Expand Down Expand Up @@ -101,3 +107,56 @@ func compareResponses(t *testing.T, wm []byte, expectedPayload bsoncore.Document
t.Errorf("Payloads don't match. got %v; want %v", actualPayload, expectedPayload)
}
}

type testAuthenticator struct{}

func (a *testAuthenticator) Auth(context.Context, *driver.AuthConfig) error {
return fmt.Errorf("test error")
}

func (a *testAuthenticator) Reauth(context.Context, *driver.AuthConfig) error {
return nil
}

func TestPerformAuthentication(t *testing.T) {
t.Parallel()

cases := []struct {
name string
needToPerform bool
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] Suggest calling this bool authenticateToAnything.

assert func(*testing.T, error)
}{
{
name: "positive",
needToPerform: true,
assert: func(t *testing.T, err error) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[optional] Suggest renaming this to require.

require.EqualError(t, err, "auth error: test error")
},
},
{
name: "negative",
needToPerform: false,
assert: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
}
mnetconn := mnet.NewConnection(&drivertest.ChannelConn{})
for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

handshaker := auth.Handshaker(nil, &auth.HandshakeOptions{
Authenticator: &testAuthenticator{},
PerformAuthentication: func(description.Server) bool {
return tc.needToPerform
},
})

err := handshaker.FinishHandshake(context.Background(), mnetconn)
tc.assert(t, err)
})
}
}
10 changes: 10 additions & 0 deletions x/mongo/driver/topology/topology_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,11 @@ import (

"go.mongodb.org/mongo-driver/v2/event"
"go.mongodb.org/mongo-driver/v2/internal/logger"
internalOptions "go.mongodb.org/mongo-driver/v2/internal/options"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/auth"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/ocsp"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/operation"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/session"
Expand Down Expand Up @@ -270,6 +272,14 @@ func NewConfigFromOptionsWithAuthenticator(opts *options.ClientOptions, clock *s
// Required for SASL mechanism negotiation during handshake
handshakeOpts.DBUser = opts.Auth.AuthSource + "." + opts.Auth.Username
}
if a := internalOptions.Value(opts.Custom, "authenticateToAnything"); a != nil {
if v, ok := a.(bool); ok && v {
// Authenticate arbiters
handshakeOpts.PerformAuthentication = func(_ description.Server) bool {
return true
}
}
}

handshaker = func(driver.Handshaker) driver.Handshaker {
return auth.Handshaker(nil, handshakeOpts)
Expand Down
73 changes: 73 additions & 0 deletions x/mongo/driver/topology/topology_options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
package topology

import (
"context"
"fmt"
"net/url"
"reflect"
Expand All @@ -17,6 +18,10 @@ import (
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/description"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/drivertest"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/mnet"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver/xoptions"
)

func TestDirectConnectionFromConnString(t *testing.T) {
Expand Down Expand Up @@ -85,6 +90,74 @@ func TestLoadBalancedFromConnString(t *testing.T) {
}
}

type testAuthenticator struct{}

func (a *testAuthenticator) Auth(context.Context, *driver.AuthConfig) error {
return fmt.Errorf("test error")
}

func (a *testAuthenticator) Reauth(context.Context, *driver.AuthConfig) error {
return nil
}

func TestAuthenticateToAnything(t *testing.T) {
t.Parallel()

cases := []struct {
name string
set func(*options.ClientOptions) error
assert func(*testing.T, error)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Optional] Suggest renaming this to require.

}{
{
name: "default",
set: func(*options.ClientOptions) error { return nil },
assert: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
{
name: "positive",
set: func(opt *options.ClientOptions) error {
return xoptions.SetInternalClientOptions(opt, "authenticateToAnything", true)
},
assert: func(t *testing.T, err error) {
require.EqualError(t, err, "auth error: test error")
},
},
{
name: "negative",
set: func(opt *options.ClientOptions) error {
return xoptions.SetInternalClientOptions(opt, "authenticateToAnything", false)
},
assert: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
}

describer := &drivertest.ChannelConn{
Desc: description.Server{Kind: description.ServerKindRSArbiter},
}
for _, tc := range cases {
tc := tc

t.Run(tc.name, func(t *testing.T) {
t.Parallel()

opt := options.Client().SetAuth(options.Credential{Username: "foo", Password: "bar"})
err := tc.set(opt)
require.NoError(t, err, "error setting authenticateToAnything: %v", err)
cfg, err := NewConfigFromOptionsWithAuthenticator(opt, nil, &testAuthenticator{})
require.NoError(t, err, "error constructing topology config: %v", err)

srvrCfg := newServerConfig(defaultConnectionTimeout, cfg.ServerOpts...)
connCfg := newConnectionConfig(srvrCfg.connectionOpts...)
err = connCfg.handshaker.FinishHandshake(context.TODO(), &mnet.Connection{Describer: describer})
tc.assert(t, err)
})
}
}

func TestTopologyNewConfig(t *testing.T) {
t.Run("default ServerSelectionTimeout", func(t *testing.T) {
cfg, err := NewConfig(options.Client(), nil)
Expand Down
7 changes: 7 additions & 0 deletions x/mongo/driver/xoptions/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package xoptions
import (
"fmt"

internalOptions "go.mongodb.org/mongo-driver/v2/internal/options"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we rename the internal/options package to internal/optionsutil? That seems to follow the existing pattern. Then we can avoid aliasing imports.

"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
)
Expand All @@ -31,6 +32,12 @@ func SetInternalClientOptions(opts *options.ClientOptions, key string, option an
return fmt.Errorf(typeErr, key)
}
opts.Deployment = d
case "authenticateToAnything":
b, ok := option.(bool)
if !ok {
return fmt.Errorf(typeErr, key)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we update these errors to describe the type that we want? Possible using some kind of composing function:

typeErrFuc := func() string {
	return fmt.Sprintf("unexecpted type for %q, wanted %T, got %T", key, option, option)
}

}
opts.Custom = internalOptions.WithValue(opts.Custom, key, b)
default:
return fmt.Errorf("unsupported option: %s", key)
}
Expand Down
25 changes: 25 additions & 0 deletions x/mongo/driver/xoptions/options_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,10 @@
package xoptions

import (
"fmt"
"testing"

internalOptions "go.mongodb.org/mongo-driver/v2/internal/options"
"go.mongodb.org/mongo-driver/v2/internal/require"
"go.mongodb.org/mongo-driver/v2/mongo/options"
"go.mongodb.org/mongo-driver/v2/x/mongo/driver"
Expand All @@ -18,6 +20,29 @@ import (
func TestSetInternalClientOptions(t *testing.T) {
t.Parallel()

cases := []struct {
key string
value any
}{
{
key: "authenticateToAnything",
value: true,
},
}
for _, tc := range cases {
tc := tc

t.Run(fmt.Sprintf("set %s", tc.key), func(t *testing.T) {
t.Parallel()

opts := options.Client()
err := SetInternalClientOptions(opts, tc.key, tc.value)
require.NoError(t, err, "error setting %s: %v", tc.key, err)
v := internalOptions.Value(opts.Custom, tc.key)
require.Equal(t, tc.value, v, "expected %v, got %v", tc.value, v)
})
}

t.Run("set crypt", func(t *testing.T) {
t.Parallel()

Expand Down
Loading