Skip to content

Commit 7cb4923

Browse files
author
Divjot Arora
committed
Store connection description before authenticating.
GODRIVER-1311 Change-Id: I04c1dfe730ef483d675a4fa01ca8c118f3ef69f0
1 parent d2ca53b commit 7cb4923

File tree

7 files changed

+149
-91
lines changed

7 files changed

+149
-91
lines changed

x/mongo/driver/auth/auth.go

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -56,40 +56,59 @@ type HandshakeOptions struct {
5656
PerformAuthentication func(description.Server) bool
5757
}
5858

59-
// Handshaker creates a connection handshaker for the given authenticator.
60-
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
61-
return driver.HandshakerFunc(func(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) {
62-
desc, err := operation.NewIsMaster().
63-
AppName(options.AppName).
64-
Compressors(options.Compressors).
65-
SASLSupportedMechs(options.DBUser).
66-
Handshake(ctx, addr, conn)
59+
type authHandshaker struct {
60+
wrapped driver.Handshaker
61+
options *HandshakeOptions
62+
}
6763

68-
if err != nil {
69-
return description.Server{}, newAuthError("handshake failure", err)
70-
}
64+
// GetDescription performs an isMaster to retrieve the initial description for conn.
65+
func (ah *authHandshaker) GetDescription(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) {
66+
if ah.wrapped != nil {
67+
return ah.wrapped.GetDescription(ctx, addr, conn)
68+
}
7169

72-
performAuth := options.PerformAuthentication
73-
if performAuth == nil {
74-
performAuth = func(serv description.Server) bool {
75-
return serv.Kind == description.RSPrimary ||
76-
serv.Kind == description.RSSecondary ||
77-
serv.Kind == description.Mongos ||
78-
serv.Kind == description.Standalone
79-
}
80-
}
81-
if performAuth(desc) && options.Authenticator != nil {
82-
err = options.Authenticator.Auth(ctx, desc, conn)
83-
if err != nil {
84-
return description.Server{}, newAuthError("auth error", err)
85-
}
70+
desc, err := operation.NewIsMaster().
71+
AppName(ah.options.AppName).
72+
Compressors(ah.options.Compressors).
73+
SASLSupportedMechs(ah.options.DBUser).
74+
GetDescription(ctx, addr, conn)
75+
if err != nil {
76+
return description.Server{}, newAuthError("handshake failure", err)
77+
}
78+
return desc, nil
79+
}
8680

81+
// FinishHandshake performs authentication for conn if necessary.
82+
func (ah *authHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
83+
performAuth := ah.options.PerformAuthentication
84+
if performAuth == nil {
85+
performAuth = func(serv description.Server) bool {
86+
return serv.Kind == description.RSPrimary ||
87+
serv.Kind == description.RSSecondary ||
88+
serv.Kind == description.Mongos ||
89+
serv.Kind == description.Standalone
8790
}
88-
if h == nil {
89-
return desc, nil
91+
}
92+
desc := conn.Description()
93+
if performAuth(desc) && ah.options.Authenticator != nil {
94+
err := ah.options.Authenticator.Auth(ctx, desc, conn)
95+
if err != nil {
96+
return newAuthError("auth error", err)
9097
}
91-
return h.Handshake(ctx, addr, conn)
92-
})
98+
}
99+
100+
if ah.wrapped == nil {
101+
return nil
102+
}
103+
return ah.wrapped.FinishHandshake(ctx, conn)
104+
}
105+
106+
// Handshaker creates a connection handshaker for the given authenticator.
107+
func Handshaker(h driver.Handshaker, options *HandshakeOptions) driver.Handshaker {
108+
return &authHandshaker{
109+
wrapped: h,
110+
options: options,
111+
}
93112
}
94113

95114
// Authenticator handles authenticating a connection.

x/mongo/driver/driver.go

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,8 @@ type ErrorProcessor interface {
5959
// handshake over a provided driver.Connection. This is used during connection
6060
// initialization. Implementations must be goroutine safe.
6161
type Handshaker interface {
62-
Handshake(context.Context, address.Address, Connection) (description.Server, error)
63-
}
64-
65-
// HandshakerFunc is an adapter to allow the use of ordinary functions as
66-
// connection handshakers.
67-
type HandshakerFunc func(context.Context, address.Address, Connection) (description.Server, error)
68-
69-
// Handshake implements the Handshaker interface.
70-
func (hf HandshakerFunc) Handshake(ctx context.Context, addr address.Address, conn Connection) (description.Server, error) {
71-
return hf(ctx, addr, conn)
62+
GetDescription(context.Context, address.Address, Connection) (description.Server, error)
63+
FinishHandshake(context.Context, Connection) error
7264
}
7365

7466
// SingleServerDeployment is an implementation of Deployment that always returns a single server.

x/mongo/driver/operation/ismaster.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ type IsMaster struct {
2828
res bsoncore.Document
2929
}
3030

31+
var _ driver.Handshaker = (*IsMaster)(nil)
32+
3133
// NewIsMaster constructs an IsMaster.
3234
func NewIsMaster() *IsMaster { return &IsMaster{} }
3335

@@ -401,8 +403,9 @@ func (im *IsMaster) Execute(ctx context.Context) error {
401403
}.Execute(ctx, nil)
402404
}
403405

404-
// Handshake implements the Handshaker interface.
405-
func (im *IsMaster) Handshake(ctx context.Context, _ address.Address, c driver.Connection) (description.Server, error) {
406+
// GetDescription retrieves the server description for the given connection. This function implements the Handshaker
407+
// interface.
408+
func (im *IsMaster) GetDescription(ctx context.Context, _ address.Address, c driver.Connection) (description.Server, error) {
406409
err := driver.Operation{
407410
Clock: im.clock,
408411
CommandFn: im.handshakeCommand,
@@ -418,3 +421,9 @@ func (im *IsMaster) Handshake(ctx context.Context, _ address.Address, c driver.C
418421
}
419422
return im.Result(c.Address()), nil
420423
}
424+
425+
// FinishHandshake implements the Handshaker interface. This is a no-op function because a non-authenticated connection
426+
// does not do anything besides the initial isMaster for a handshake.
427+
func (im *IsMaster) FinishHandshake(context.Context, driver.Connection) error {
428+
return nil
429+
}

x/mongo/driver/topology/connection.go

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -114,39 +114,47 @@ func (c *connection) connect(ctx context.Context) {
114114
c.bumpIdleDeadline()
115115

116116
// running isMaster and authentication is handled by a handshaker on the configuration instance.
117-
if c.config.handshaker != nil {
118-
c.desc, err = c.config.handshaker.Handshake(ctx, c.addr, initConnection{c})
119-
if err != nil {
120-
if c.nc != nil {
121-
_ = c.nc.Close()
122-
}
123-
atomic.StoreInt32(&c.connected, disconnected)
124-
c.connectErr = ConnectionError{Wrapped: err, init: true}
125-
return
126-
}
127-
if c.config.descCallback != nil {
128-
c.config.descCallback(c.desc)
117+
handshaker := c.config.handshaker
118+
if handshaker == nil {
119+
return
120+
}
121+
122+
handshakeConn := initConnection{c}
123+
c.desc, err = handshaker.GetDescription(ctx, c.addr, handshakeConn)
124+
if err == nil {
125+
err = handshaker.FinishHandshake(ctx, handshakeConn)
126+
}
127+
if err != nil {
128+
if c.nc != nil {
129+
_ = c.nc.Close()
129130
}
130-
if len(c.desc.Compression) > 0 {
131-
clientMethodLoop:
132-
for _, method := range c.config.compressors {
133-
for _, serverMethod := range c.desc.Compression {
134-
if method != serverMethod {
135-
continue
136-
}
131+
atomic.StoreInt32(&c.connected, disconnected)
132+
c.connectErr = ConnectionError{Wrapped: err, init: true}
133+
return
134+
}
135+
136+
if c.config.descCallback != nil {
137+
c.config.descCallback(c.desc)
138+
}
139+
if len(c.desc.Compression) > 0 {
140+
clientMethodLoop:
141+
for _, method := range c.config.compressors {
142+
for _, serverMethod := range c.desc.Compression {
143+
if method != serverMethod {
144+
continue
145+
}
137146

138-
switch strings.ToLower(method) {
139-
case "snappy":
140-
c.compressor = wiremessage.CompressorSnappy
141-
case "zlib":
142-
c.compressor = wiremessage.CompressorZLib
143-
c.zliblevel = wiremessage.DefaultZlibLevel
144-
if c.config.zlibLevel != nil {
145-
c.zliblevel = *c.config.zlibLevel
146-
}
147+
switch strings.ToLower(method) {
148+
case "snappy":
149+
c.compressor = wiremessage.CompressorSnappy
150+
case "zlib":
151+
c.compressor = wiremessage.CompressorZLib
152+
c.zliblevel = wiremessage.DefaultZlibLevel
153+
if c.config.zlibLevel != nil {
154+
c.zliblevel = *c.config.zlibLevel
147155
}
148-
break clientMethodLoop
149156
}
157+
break clientMethodLoop
150158
}
151159
}
152160
}
@@ -300,10 +308,15 @@ type initConnection struct{ *connection }
300308

301309
var _ driver.Connection = initConnection{}
302310

303-
func (c initConnection) Description() description.Server { return description.Server{} }
304-
func (c initConnection) Close() error { return nil }
305-
func (c initConnection) ID() string { return c.id }
306-
func (c initConnection) Address() address.Address { return c.addr }
311+
func (c initConnection) Description() description.Server {
312+
if c.connection == nil {
313+
return description.Server{}
314+
}
315+
return c.connection.desc
316+
}
317+
func (c initConnection) Close() error { return nil }
318+
func (c initConnection) ID() string { return c.id }
319+
func (c initConnection) Address() address.Address { return c.addr }
307320
func (c initConnection) LocalAddress() address.Address {
308321
if c.connection == nil || c.nc == nil {
309322
return address.Address("0.0.0.0")

x/mongo/driver/topology/connection_options.go

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,6 @@ var DefaultDialer Dialer = &net.Dialer{}
3535
// initialization. Implementations must be goroutine safe.
3636
type Handshaker = driver.Handshaker
3737

38-
// HandshakerFunc is an adapter to allow the use of ordinary functions as
39-
// connection handshakers.
40-
type HandshakerFunc = driver.HandshakerFunc
41-
4238
type connectionConfig struct {
4339
appName string
4440
connectTimeout time.Duration

x/mongo/driver/topology/connection_test.go

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,11 @@ func TestConnection(t *testing.T) {
6767
var want error = ConnectionError{Wrapped: err}
6868
conn, got := newConnection(context.Background(), address.Address(""),
6969
WithHandshaker(func(Handshaker) Handshaker {
70-
return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) {
71-
return description.Server{}, err
72-
})
70+
return &testHandshaker{
71+
finishHandshake: func(context.Context, driver.Connection) error {
72+
return err
73+
},
74+
}
7375
}),
7476
WithDialer(func(Dialer) Dialer {
7577
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {
@@ -92,9 +94,11 @@ func TestConnection(t *testing.T) {
9294
conn, err := newConnection(context.Background(), address.Address(""),
9395
withServerDescriptionCallback(func(desc description.Server) { got = desc },
9496
WithHandshaker(func(Handshaker) Handshaker {
95-
return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) {
96-
return want, nil
97-
})
97+
return &testHandshaker{
98+
getDescription: func(context.Context, address.Address, driver.Connection) (description.Server, error) {
99+
return want, nil
100+
},
101+
}
98102
}),
99103
WithDialer(func(Dialer) Dialer {
100104
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {

x/mongo/driver/topology/server_test.go

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,29 @@ func (cncd *channelNetConnDialer) DialContext(_ context.Context, _, _ string) (n
4747
return cnc, nil
4848
}
4949

50+
type testHandshaker struct {
51+
getDescription func(context.Context, address.Address, driver.Connection) (description.Server, error)
52+
finishHandshake func(context.Context, driver.Connection) error
53+
}
54+
55+
// GetDescription implements the Handshaker interface.
56+
func (th *testHandshaker) GetDescription(ctx context.Context, addr address.Address, conn driver.Connection) (description.Server, error) {
57+
if th.getDescription != nil {
58+
return th.getDescription(ctx, addr, conn)
59+
}
60+
return description.Server{}, nil
61+
}
62+
63+
// FinishHandshake implements the Handshaker interface.
64+
func (th *testHandshaker) FinishHandshake(ctx context.Context, conn driver.Connection) error {
65+
if th.finishHandshake != nil {
66+
return th.finishHandshake(ctx, conn)
67+
}
68+
return nil
69+
}
70+
71+
var _ driver.Handshaker = &testHandshaker{}
72+
5073
func TestServer(t *testing.T) {
5174
var serverTestTable = []struct {
5275
name string
@@ -69,13 +92,15 @@ func TestServer(t *testing.T) {
6992
WithConnectionOptions(func(connOpts ...ConnectionOption) []ConnectionOption {
7093
return append(connOpts,
7194
WithHandshaker(func(Handshaker) Handshaker {
72-
return HandshakerFunc(func(context.Context, address.Address, driver.Connection) (description.Server, error) {
73-
var err error
74-
if tt.connectionError {
75-
err = authErr.Wrapped
76-
}
77-
return description.Server{}, err
78-
})
95+
return &testHandshaker{
96+
finishHandshake: func(context.Context, driver.Connection) error {
97+
var err error
98+
if tt.connectionError {
99+
err = authErr.Wrapped
100+
}
101+
return err
102+
},
103+
}
79104
}),
80105
WithDialer(func(Dialer) Dialer {
81106
return DialerFunc(func(context.Context, string, string) (net.Conn, error) {

0 commit comments

Comments
 (0)