Skip to content

Commit 95b9288

Browse files
author
Divjot Arora
committed
Drain the connection pool if an auth error occurs
GODRIVER-442 Change-Id: I9f304471a89d4a770e43db7488311933b0b38960
1 parent 4a66187 commit 95b9288

File tree

13 files changed

+139
-57
lines changed

13 files changed

+139
-57
lines changed

core/auth/auth.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ func CreateAuthenticator(name string, cred *Cred) (Authenticator, error) {
3737
return f(cred)
3838
}
3939

40-
return nil, fmt.Errorf("unknown authenticator: %s", name)
40+
return nil, newAuthError(fmt.Sprintf("unknown authenticator: %s", name), nil)
4141
}
4242

4343
// RegisterAuthenticatorFactory registers the authenticator factory.
@@ -96,12 +96,12 @@ func Handshaker(appName string, h connection.Handshaker, authenticator Authentic
9696
return connection.HandshakerFunc(func(ctx context.Context, addr address.Address, rw wiremessage.ReadWriter) (description.Server, error) {
9797
desc, err := (&command.Handshake{Client: command.ClientDoc(appName)}).Handshake(ctx, addr, rw)
9898
if err != nil {
99-
return description.Server{}, err
99+
return description.Server{}, newAuthError("handshake failure", err)
100100
}
101101

102102
err = authenticator.Auth(ctx, desc, rw)
103103
if err != nil {
104-
return description.Server{}, err
104+
return description.Server{}, newAuthError("auth error", err)
105105
}
106106
if h == nil {
107107
return desc, nil
@@ -116,6 +116,13 @@ type Authenticator interface {
116116
Auth(context.Context, description.Server, wiremessage.ReadWriter) error
117117
}
118118

119+
func newAuthError(msg string, inner error) error {
120+
return &Error{
121+
message: msg,
122+
inner: inner,
123+
}
124+
}
125+
119126
func newError(err error, mech string) error {
120127
return &Error{
121128
message: fmt.Sprintf("unable to authenticate using mechanism \"%s\"", mech),
@@ -130,6 +137,9 @@ type Error struct {
130137
}
131138

132139
func (e *Error) Error() string {
140+
if e.inner == nil {
141+
return e.message
142+
}
133143
return fmt.Sprintf("%s: %s", e.message, e.inner)
134144
}
135145

core/auth/default.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ func (a *DefaultAuthenticator) Auth(ctx context.Context, desc description.Server
3636
}
3737

3838
if err != nil {
39-
return err
39+
return newAuthError("error creating authenticator", err)
4040
}
4141

4242
return actual.Auth(ctx, desc, rw)

core/auth/gssapi.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ package auth
1111

1212
import (
1313
"context"
14-
"fmt"
15-
1614
"github.com/mongodb/mongo-go-driver/core/auth/internal/gssapi"
1715
"github.com/mongodb/mongo-go-driver/core/description"
1816
"github.com/mongodb/mongo-go-driver/core/wiremessage"
@@ -23,7 +21,7 @@ const GSSAPI = "GSSAPI"
2321

2422
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
2523
if cred.Source != "" && cred.Source != "$external" {
26-
return nil, fmt.Errorf("GSSAPI source must be empty or $external")
24+
return nil, newAuthError("GSSAPI source must be empty or $external", nil)
2725
}
2826

2927
return &GSSAPIAuthenticator{
@@ -47,7 +45,7 @@ func (a *GSSAPIAuthenticator) Auth(ctx context.Context, desc description.Server,
4745
client, err := gssapi.New(desc.Addr.String(), a.Username, a.Password, a.PasswordSet, a.Props)
4846

4947
if err != nil {
50-
return err
48+
return newAuthError("error creating gssapi", err)
5149
}
5250
return ConductSaslConversation(ctx, desc, rw, "$external", client)
5351
}

core/auth/gssapi_not_enabled.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,9 @@
88

99
package auth
1010

11-
import "fmt"
12-
1311
// GSSAPI is the mechanism name for GSSAPI.
1412
const GSSAPI = "GSSAPI"
1513

1614
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
17-
return nil, fmt.Errorf("GSSAPI support not enabled during build (-tags gssapi)")
15+
return nil, newAuthError("GSSAPI support not enabled during build (-tags gssapi)", nil)
1816
}

core/auth/gssapi_not_supported.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ import (
1717
const GSSAPI = "GSSAPI"
1818

1919
func newGSSAPIAuthenticator(cred *Cred) (Authenticator, error) {
20-
return nil, fmt.Errorf("GSSAPI is not supported on %s", runtime.GOOS)
20+
return nil, newAuthError(fmt.Sprintf("GSSAPI is not supported on %s", runtime.GOOS), nil)
2121
}

core/auth/mongodbcr.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ func (a *MongoDBCRAuthenticator) Auth(ctx context.Context, desc description.Serv
6969

7070
err = bson.Unmarshal(rdr, &getNonceResult)
7171
if err != nil {
72-
return err
72+
return newAuthError("unmarshal error", err)
7373
}
7474

7575
cmd = command.Command{

core/auth/plain.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,6 @@ package auth
88

99
import (
1010
"context"
11-
"fmt"
12-
1311
"github.com/mongodb/mongo-go-driver/core/description"
1412
"github.com/mongodb/mongo-go-driver/core/wiremessage"
1513
)
@@ -19,7 +17,7 @@ const PLAIN = "PLAIN"
1917

2018
func newPlainAuthenticator(cred *Cred) (Authenticator, error) {
2119
if cred.Source != "" && cred.Source != "$external" {
22-
return nil, fmt.Errorf("PLAIN source must be empty or $external")
20+
return nil, newAuthError("PLAIN source must be empty or $external", nil)
2321
}
2422

2523
return &PlainAuthenticator{
@@ -53,7 +51,7 @@ func (c *plainSaslClient) Start() (string, []byte, error) {
5351
}
5452

5553
func (c *plainSaslClient) Next(challenge []byte) ([]byte, error) {
56-
return nil, fmt.Errorf("unexpected server challenge")
54+
return nil, newAuthError("unexpected server challenge", nil)
5755
}
5856

5957
func (c *plainSaslClient) Completed() bool {

core/auth/sasl.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi
7575

7676
err = bson.Unmarshal(rdr, &saslResp)
7777
if err != nil {
78-
return err
78+
return newAuthError("unmarshall error", err)
7979
}
8080

8181
cid := saslResp.ConversationID
@@ -114,7 +114,7 @@ func ConductSaslConversation(ctx context.Context, desc description.Server, rw wi
114114

115115
err = bson.Unmarshal(rdr, &saslResp)
116116
if err != nil {
117-
return err
117+
return newAuthError("unmarshal error", err)
118118
}
119119
}
120120
}

core/auth/scramsha1.go

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ import (
1111
"context"
1212
"crypto/hmac"
1313
"crypto/sha1"
14-
"fmt"
1514
"io"
1615
"math/rand"
1716
"strconv"
@@ -60,7 +59,7 @@ func (a *ScramSHA1Authenticator) Auth(ctx context.Context, desc description.Serv
6059

6160
err := ConductSaslConversation(ctx, desc, rw, a.DB, client)
6261
if err != nil {
63-
return err
62+
return newAuthError("sasl conversation error", err)
6463
}
6564

6665
a.clientKey = client.clientKey
@@ -82,7 +81,7 @@ type scramSaslClient struct {
8281

8382
func (c *scramSaslClient) Start() (string, []byte, error) {
8483
if err := c.generateClientNonce(scramSHA1NonceLen); err != nil {
85-
return SCRAMSHA1, nil, err
84+
return SCRAMSHA1, nil, newAuthError("generate nonce error", err)
8685
}
8786

8887
c.clientFirstMessageBare = "n=" + usernameSanitizer.Replace(c.username) + ",r=" + string(c.clientNonce)
@@ -98,7 +97,7 @@ func (c *scramSaslClient) Next(challenge []byte) ([]byte, error) {
9897
case 2:
9998
return c.step2(challenge)
10099
default:
101-
return nil, fmt.Errorf("unexpected server challenge")
100+
return nil, newAuthError("unexpected server challenge", nil)
102101
}
103102
}
104103

@@ -109,33 +108,33 @@ func (c *scramSaslClient) Completed() bool {
109108
func (c *scramSaslClient) step1(challenge []byte) ([]byte, error) {
110109
fields := bytes.Split(challenge, []byte{','})
111110
if len(fields) != 3 {
112-
return nil, fmt.Errorf("invalid server response")
111+
return nil, newAuthError("invalid server response", nil)
113112
}
114113

115114
if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 {
116-
return nil, fmt.Errorf("invalid nonce")
115+
return nil, newAuthError("invalid nonce", nil)
117116
}
118117
r := fields[0][2:]
119118
if !bytes.HasPrefix(r, c.clientNonce) {
120-
return nil, fmt.Errorf("invalid nonce")
119+
return nil, newAuthError("invalid nonce", nil)
121120
}
122121

123122
if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 {
124-
return nil, fmt.Errorf("invalid salt")
123+
return nil, newAuthError("invalid salt", nil)
125124
}
126125
s := make([]byte, base64.StdEncoding.DecodedLen(len(fields[1][2:])))
127126
n, err := base64.StdEncoding.Decode(s, fields[1][2:])
128127
if err != nil {
129-
return nil, fmt.Errorf("invalid salt")
128+
return nil, newAuthError("invalid salt", nil)
130129
}
131130
s = s[:n]
132131

133132
if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 3 {
134-
return nil, fmt.Errorf("invalid iteration count")
133+
return nil, newAuthError("invalid iteration count", nil)
135134
}
136135
i, err := strconv.Atoi(string(fields[2][2:]))
137136
if err != nil {
138-
return nil, fmt.Errorf("invalid iteration count")
137+
return nil, newAuthError("invalid iteration count", nil)
139138
}
140139

141140
clientFinalMessageWithoutProof := "c=biws,r=" + string(r)
@@ -167,21 +166,21 @@ func (c *scramSaslClient) step2(challenge []byte) ([]byte, error) {
167166
hasE = bytes.HasPrefix(fields[0], []byte("e="))
168167
}
169168
if hasE {
170-
return nil, fmt.Errorf(string(fields[0][2:]))
169+
return nil, newAuthError(string(fields[0][2:]), nil)
171170
}
172171
if !hasV {
173-
return nil, fmt.Errorf("invalid final message")
172+
return nil, newAuthError("invalid final message", nil)
174173
}
175174

176175
v := make([]byte, base64.StdEncoding.DecodedLen(len(fields[0][2:])))
177176
n, err := base64.StdEncoding.Decode(v, fields[0][2:])
178177
if err != nil {
179-
return nil, fmt.Errorf("invalid server verification")
178+
return nil, newAuthError("invalid server verification", nil)
180179
}
181180
v = v[:n]
182181

183182
if !bytes.Equal(c.serverSignature, v) {
184-
return nil, fmt.Errorf("invalid server signature")
183+
return nil, newAuthError("invalid server signature", nil)
185184
}
186185

187186
return nil, nil

core/auth/scramsha1_test.go

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ func TestScramSHA1Authenticator_Fails(t *testing.T) {
4747
err := authenticator.Auth(context.Background(), description.Server{}, c)
4848
require.Error(t, err)
4949

50-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\""
51-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
50+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\""
51+
require.True(t, strings.Contains(err.Error(), errSubstring))
5252
require.True(t, authenticator.IsClientKeyNil())
5353
}
5454

@@ -81,8 +81,8 @@ func TestScramSHA1Authenticator_Missing_challenge_fields(t *testing.T) {
8181
err := authenticator.Auth(context.Background(), description.Server{}, c)
8282
require.Error(t, err)
8383

84-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid server response"
85-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
84+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid server response"
85+
require.True(t, strings.Contains(err.Error(), errSubstring))
8686

8787
require.True(t, authenticator.IsClientKeyNil())
8888
}
@@ -115,8 +115,8 @@ func TestScramSHA1Authenticator_Invalid_server_nonce1(t *testing.T) {
115115

116116
err := authenticator.Auth(context.Background(), description.Server{}, c)
117117

118-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid nonce"
119-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
118+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid nonce"
119+
require.True(t, strings.Contains(err.Error(), errSubstring))
120120

121121
require.True(t, authenticator.IsClientKeyNil())
122122
}
@@ -150,8 +150,8 @@ func TestScramSHA1Authenticator_Invalid_server_nonce2(t *testing.T) {
150150
err := authenticator.Auth(context.Background(), description.Server{}, c)
151151
require.Error(t, err)
152152

153-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid nonce"
154-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
153+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid nonce"
154+
require.True(t, strings.Contains(err.Error(), errSubstring))
155155

156156
require.True(t, authenticator.IsClientKeyNil())
157157
}
@@ -185,8 +185,8 @@ func TestScramSHA1Authenticator_No_salt(t *testing.T) {
185185
err := authenticator.Auth(context.Background(), description.Server{}, c)
186186
require.Error(t, err)
187187

188-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid salt"
189-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
188+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid salt"
189+
require.True(t, strings.Contains(err.Error(), errSubstring))
190190

191191
require.True(t, authenticator.IsClientKeyNil())
192192
}
@@ -220,8 +220,8 @@ func TestScramSHA1Authenticator_No_iteration_count(t *testing.T) {
220220
err := authenticator.Auth(context.Background(), description.Server{}, c)
221221
require.Error(t, err)
222222

223-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid iteration count"
224-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
223+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid iteration count"
224+
require.True(t, strings.Contains(err.Error(), errSubstring))
225225

226226
require.True(t, authenticator.IsClientKeyNil())
227227
}
@@ -255,8 +255,8 @@ func TestScramSHA1Authenticator_Invalid_iteration_count(t *testing.T) {
255255
err := authenticator.Auth(context.Background(), description.Server{}, c)
256256
require.Error(t, err)
257257

258-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid iteration count"
259-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
258+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid iteration count"
259+
require.True(t, strings.Contains(err.Error(), errSubstring))
260260

261261
require.True(t, authenticator.IsClientKeyNil())
262262
}
@@ -297,8 +297,8 @@ func TestScramSHA1Authenticator_Invalid_server_signature(t *testing.T) {
297297
err := authenticator.Auth(context.Background(), description.Server{}, c)
298298
require.Error(t, err)
299299

300-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid server signature"
301-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
300+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid server signature"
301+
require.True(t, strings.Contains(err.Error(), errSubstring))
302302

303303
require.True(t, authenticator.IsClientKeyNil())
304304
}
@@ -339,8 +339,8 @@ func TestScramSHA1Authenticator_Server_provided_error(t *testing.T) {
339339
err := authenticator.Auth(context.Background(), description.Server{}, c)
340340
require.Error(t, err)
341341

342-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": server passed error"
343-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
342+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": server passed error"
343+
require.True(t, strings.Contains(err.Error(), errSubstring))
344344

345345
require.True(t, authenticator.IsClientKeyNil())
346346
}
@@ -381,8 +381,8 @@ func TestScramSHA1Authenticator_Invalid_final_message(t *testing.T) {
381381
err := authenticator.Auth(context.Background(), description.Server{}, c)
382382
require.Error(t, err)
383383

384-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid final message"
385-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
384+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": invalid final message"
385+
require.True(t, strings.Contains(err.Error(), errSubstring))
386386

387387
require.True(t, authenticator.IsClientKeyNil())
388388
}
@@ -429,8 +429,8 @@ func TestScramSHA1Authenticator_Extra_message(t *testing.T) {
429429
err := authenticator.Auth(context.Background(), description.Server{}, c)
430430
require.Error(t, err)
431431

432-
errPrefix := "unable to authenticate using mechanism \"SCRAM-SHA-1\": unexpected server challenge"
433-
require.True(t, strings.HasPrefix(err.Error(), errPrefix))
432+
errSubstring := "unable to authenticate using mechanism \"SCRAM-SHA-1\": unexpected server challenge"
433+
require.True(t, strings.Contains(err.Error(), errSubstring))
434434

435435
require.True(t, authenticator.IsClientKeyNil())
436436
}

0 commit comments

Comments
 (0)