Skip to content

Commit e2d5932

Browse files
author
Isabella Siu
committed
GODRIVER-378 Unconnected Client returns mongo.ErrClientDisconnected
Change-Id: I7391f8f56521b09cec4203082ce57c44d49b87fb
1 parent 4991574 commit e2d5932

File tree

7 files changed

+186
-26
lines changed

7 files changed

+186
-26
lines changed

mongo/client.go

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ func NewClientFromConnString(cs connstring.ConnString) (*Client, error) {
9090
func (c *Client) Connect(ctx context.Context) error {
9191
err := c.topology.Connect(ctx)
9292
if err != nil {
93-
return err
93+
return replaceTopologyErr(err)
9494
}
9595

9696
return nil
@@ -107,7 +107,7 @@ func (c *Client) Connect(ctx context.Context) error {
107107
// associated with this Client have been closed.
108108
func (c *Client) Disconnect(ctx context.Context) error {
109109
c.endSessions(ctx)
110-
return c.topology.Disconnect(ctx)
110+
return replaceTopologyErr(c.topology.Disconnect(ctx))
111111
}
112112

113113
// Ping verifies that the client can connect to the topology.
@@ -123,13 +123,13 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
123123
}
124124

125125
_, err := c.topology.SelectServer(ctx, description.ReadPrefSelector(rp))
126-
return err
126+
return replaceTopologyErr(err)
127127
}
128128

129129
// StartSession starts a new session.
130130
func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error) {
131131
if c.topology.SessionPool == nil {
132-
return nil, topology.ErrTopologyClosed
132+
return nil, ErrClientDisconnected
133133
}
134134

135135
sopts := options.MergeSessionOptions(opts...)
@@ -153,7 +153,7 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error)
153153

154154
sess, err := session.NewClientSession(c.topology.SessionPool, c.id, session.Explicit, coreOpts)
155155
if err != nil {
156-
return nil, err
156+
return nil, replaceTopologyErr(err)
157157
}
158158

159159
sess.RetryWrite = c.retryWrites
@@ -165,6 +165,9 @@ func (c *Client) StartSession(opts ...*options.SessionOptions) (Session, error)
165165
}
166166

167167
func (c *Client) endSessions(ctx context.Context) {
168+
if c.topology.SessionPool == nil {
169+
return
170+
}
168171
cmd := command.EndSessions{
169172
Clock: c.clock,
170173
SessionIDs: c.topology.SessionPool.IDSlice(),
@@ -200,7 +203,7 @@ func newClient(cs connstring.ConnString, opts ...*options.ClientOptions) (*Clien
200203
)
201204
topo, err := topology.New(topts...)
202205
if err != nil {
203-
return nil, err
206+
return nil, replaceTopologyErr(err)
204207
}
205208
client.topology = topo
206209
client.clock = &session.ClusterClock{}
@@ -360,7 +363,7 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...
360363
opts...,
361364
)
362365
if err != nil {
363-
return ListDatabasesResult{}, err
366+
return ListDatabasesResult{}, replaceTopologyErr(err)
364367
}
365368

366369
return (ListDatabasesResult{}).fromResult(res), nil

mongo/client_internal_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,32 @@ func TestClient_X509Auth(t *testing.T) {
186186
t.Error("unable to find authenticated user")
187187
}
188188

189+
func TestClient_ReplaceTopologyError(t *testing.T) {
190+
t.Parallel()
191+
192+
if testing.Short() {
193+
t.Skip()
194+
}
195+
196+
cs := testutil.ConnString(t)
197+
c, err := NewClient(cs.String())
198+
require.NoError(t, err)
199+
require.NotNil(t, c)
200+
201+
_, err = c.StartSession()
202+
require.Equal(t, err, ErrClientDisconnected)
203+
204+
_, err = c.ListDatabases(ctx, nil)
205+
require.Equal(t, err, ErrClientDisconnected)
206+
207+
err = c.Ping(ctx, nil)
208+
require.Equal(t, err, ErrClientDisconnected)
209+
210+
err = c.Disconnect(ctx)
211+
require.Equal(t, err, ErrClientDisconnected)
212+
213+
}
214+
189215
func TestClient_ListDatabases_noFilter(t *testing.T) {
190216
t.Parallel()
191217

mongo/collection.go

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ func (coll *Collection) BulkWrite(ctx context.Context, models []WriteModel,
191191
}
192192
}
193193

194-
return &BulkWriteResult{}, err
194+
return &BulkWriteResult{}, replaceTopologyErr(err)
195195
}
196196

197197
return &BulkWriteResult{
@@ -339,7 +339,7 @@ func (coll *Collection) InsertMany(ctx context.Context, documents []interface{},
339339
case command.ErrUnacknowledgedWrite:
340340
return &InsertManyResult{InsertedIDs: result}, ErrUnacknowledgedWrite
341341
default:
342-
return nil, err
342+
return nil, replaceTopologyErr(err)
343343
}
344344
if len(res.WriteErrors) > 0 || res.WriteConcernError != nil {
345345
bwErrors := make([]BulkWriteError, 0, len(res.WriteErrors))
@@ -523,7 +523,7 @@ func (coll *Collection) updateOrReplaceOne(ctx context.Context, filter,
523523
opts...,
524524
)
525525
if err != nil && err != command.ErrUnacknowledgedWrite {
526-
return nil, err
526+
return nil, replaceTopologyErr(err)
527527
}
528528

529529
res := &UpdateResult{
@@ -645,7 +645,7 @@ func (coll *Collection) UpdateMany(ctx context.Context, filter interface{}, upda
645645
opts...,
646646
)
647647
if err != nil && err != command.ErrUnacknowledgedWrite {
648-
return nil, err
648+
return nil, replaceTopologyErr(err)
649649
}
650650
res := &UpdateResult{
651651
MatchedCount: r.MatchedCount,
@@ -760,7 +760,7 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{},
760760
Clock: coll.client.clock,
761761
}
762762

763-
return dispatch.Aggregate(
763+
cursor, err := dispatch.Aggregate(
764764
ctx, cmd,
765765
coll.client.topology,
766766
coll.readSelector,
@@ -770,6 +770,8 @@ func (coll *Collection) Aggregate(ctx context.Context, pipeline interface{},
770770
coll.registry,
771771
aggOpts,
772772
)
773+
774+
return cursor, replaceTopologyErr(err)
773775
}
774776

775777
// Count gets the number of documents matching the filter. A user can supply a
@@ -812,7 +814,7 @@ func (coll *Collection) Count(ctx context.Context, filter interface{},
812814
Clock: coll.client.clock,
813815
}
814816

815-
return dispatch.Count(
817+
count, err := dispatch.Count(
816818
ctx, cmd,
817819
coll.client.topology,
818820
coll.readSelector,
@@ -821,6 +823,8 @@ func (coll *Collection) Count(ctx context.Context, filter interface{},
821823
coll.registry,
822824
opts...,
823825
)
826+
827+
return count, replaceTopologyErr(err)
824828
}
825829

826830
// CountDocuments gets the number of documents matching the filter. A user can supply a
@@ -864,7 +868,7 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{},
864868
Clock: coll.client.clock,
865869
}
866870

867-
return dispatch.CountDocuments(
871+
count, err := dispatch.CountDocuments(
868872
ctx, cmd,
869873
coll.client.topology,
870874
coll.readSelector,
@@ -873,6 +877,8 @@ func (coll *Collection) CountDocuments(ctx context.Context, filter interface{},
873877
coll.registry,
874878
countOpts,
875879
)
880+
881+
return count, replaceTopologyErr(err)
876882
}
877883

878884
// EstimatedDocumentCount gets an estimate of the count of documents in a collection using collection metadata.
@@ -910,7 +916,7 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context,
910916
countOpts = countOpts.SetMaxTime(*opts[len(opts)-1].MaxTime)
911917
}
912918

913-
return dispatch.Count(
919+
count, err := dispatch.Count(
914920
ctx, cmd,
915921
coll.client.topology,
916922
coll.readSelector,
@@ -919,6 +925,8 @@ func (coll *Collection) EstimatedDocumentCount(ctx context.Context,
919925
coll.registry,
920926
countOpts,
921927
)
928+
929+
return count, replaceTopologyErr(err)
922930
}
923931

924932
// Distinct finds the distinct values for a specified field across a single
@@ -976,7 +984,7 @@ func (coll *Collection) Distinct(ctx context.Context, fieldName string, filter i
976984
opts...,
977985
)
978986
if err != nil {
979-
return nil, err
987+
return nil, replaceTopologyErr(err)
980988
}
981989

982990
return res.Values, nil
@@ -1026,7 +1034,7 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
10261034
Clock: coll.client.clock,
10271035
}
10281036

1029-
return dispatch.Find(
1037+
cursor, err := dispatch.Find(
10301038
ctx, cmd,
10311039
coll.client.topology,
10321040
coll.readSelector,
@@ -1035,6 +1043,8 @@ func (coll *Collection) Find(ctx context.Context, filter interface{},
10351043
coll.registry,
10361044
opts...,
10371045
)
1046+
1047+
return cursor, replaceTopologyErr(err)
10381048
}
10391049

10401050
// FindOne returns up to one document that matches the model. A user can
@@ -1115,7 +1125,7 @@ func (coll *Collection) FindOne(ctx context.Context, filter interface{},
11151125
findOpts...,
11161126
)
11171127
if err != nil {
1118-
return &DocumentResult{err: err}
1128+
return &DocumentResult{err: replaceTopologyErr(err)}
11191129
}
11201130

11211131
return &DocumentResult{cur: cursor, reg: coll.registry}
@@ -1178,7 +1188,7 @@ func (coll *Collection) FindOneAndDelete(ctx context.Context, filter interface{}
11781188
opts...,
11791189
)
11801190
if err != nil {
1181-
return &DocumentResult{err: err}
1191+
return &DocumentResult{err: replaceTopologyErr(err)}
11821192
}
11831193

11841194
return &DocumentResult{rdr: res.Value, reg: coll.registry}
@@ -1247,7 +1257,7 @@ func (coll *Collection) FindOneAndReplace(ctx context.Context, filter interface{
12471257
opts...,
12481258
)
12491259
if err != nil {
1250-
return &DocumentResult{err: err}
1260+
return &DocumentResult{err: replaceTopologyErr(err)}
12511261
}
12521262

12531263
return &DocumentResult{rdr: res.Value, reg: coll.registry}
@@ -1316,7 +1326,7 @@ func (coll *Collection) FindOneAndUpdate(ctx context.Context, filter interface{}
13161326
opts...,
13171327
)
13181328
if err != nil {
1319-
return &DocumentResult{err: err}
1329+
return &DocumentResult{err: replaceTopologyErr(err)}
13201330
}
13211331

13221332
return &DocumentResult{rdr: res.Value, reg: coll.registry}
@@ -1368,7 +1378,7 @@ func (coll *Collection) Drop(ctx context.Context) error {
13681378
coll.client.topology.SessionPool,
13691379
)
13701380
if err != nil && !command.IsNotFound(err) {
1371-
return err
1381+
return replaceTopologyErr(err)
13721382
}
13731383
return nil
13741384
}

mongo/collection_internal_test.go

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,89 @@ func TestCollection_InheritOptions(t *testing.T) {
152152
}
153153
}
154154

155+
func TestCollection_ReplaceTopologyError(t *testing.T) {
156+
t.Parallel()
157+
158+
if testing.Short() {
159+
t.Skip()
160+
}
161+
162+
cs := testutil.ConnString(t)
163+
c, err := NewClient(cs.String())
164+
require.NoError(t, err)
165+
require.NotNil(t, c)
166+
167+
db := c.Database("TestCollection")
168+
coll := db.Collection("ReplaceTopologyError")
169+
170+
doc1 := bson.NewDocument(bson.EC.Int32("x", 1))
171+
doc2 := bson.NewDocument(bson.EC.Int32("x", 6))
172+
docs := []interface{}{doc1, doc2}
173+
update := bson.NewDocument(
174+
bson.EC.SubDocumentFromElements("$inc", bson.EC.Int32("x", 1)))
175+
176+
_, err = coll.InsertOne(context.Background(), doc1)
177+
require.Equal(t, err, ErrClientDisconnected)
178+
179+
_, err = coll.InsertMany(context.Background(), docs)
180+
require.Equal(t, err, ErrClientDisconnected)
181+
182+
_, err = coll.DeleteOne(context.Background(), doc1)
183+
require.Equal(t, err, ErrClientDisconnected)
184+
185+
_, err = coll.DeleteMany(context.Background(), doc1)
186+
require.Equal(t, err, ErrClientDisconnected)
187+
188+
_, err = coll.UpdateOne(context.Background(), doc1, update)
189+
require.Equal(t, err, ErrClientDisconnected)
190+
191+
_, err = coll.UpdateMany(context.Background(), doc1, update)
192+
require.Equal(t, err, ErrClientDisconnected)
193+
194+
_, err = coll.ReplaceOne(context.Background(), doc1, doc2)
195+
require.Equal(t, err, ErrClientDisconnected)
196+
197+
pipeline := bson.NewArray(
198+
bson.VC.DocumentFromElements(
199+
bson.EC.SubDocumentFromElements(
200+
"$match",
201+
bson.EC.SubDocumentFromElements(
202+
"x",
203+
bson.EC.Int32("$gte", 2),
204+
),
205+
),
206+
))
207+
_, err = coll.Aggregate(context.Background(), pipeline, options.Aggregate())
208+
require.Equal(t, err, ErrClientDisconnected)
209+
210+
_, err = coll.Count(context.Background(), nil)
211+
require.Equal(t, err, ErrClientDisconnected)
212+
213+
_, err = coll.CountDocuments(context.Background(), nil)
214+
require.Equal(t, err, ErrClientDisconnected)
215+
216+
_, err = coll.EstimatedDocumentCount(context.Background())
217+
require.Equal(t, err, ErrClientDisconnected)
218+
219+
_, err = coll.Distinct(context.Background(), "x", nil)
220+
require.Equal(t, err, ErrClientDisconnected)
221+
222+
_, err = coll.Find(context.Background(), doc1)
223+
require.Equal(t, err, ErrClientDisconnected)
224+
225+
result := coll.FindOne(context.Background(), doc1)
226+
require.Equal(t, result.err, ErrClientDisconnected)
227+
228+
result = coll.FindOneAndDelete(context.Background(), doc1)
229+
require.Equal(t, result.err, ErrClientDisconnected)
230+
231+
result = coll.FindOneAndReplace(context.Background(), doc1, doc2)
232+
require.Equal(t, result.err, ErrClientDisconnected)
233+
234+
result = coll.FindOneAndUpdate(context.Background(), doc1, update)
235+
require.Equal(t, result.err, ErrClientDisconnected)
236+
}
237+
155238
func TestCollection_namespace(t *testing.T) {
156239
t.Parallel()
157240

0 commit comments

Comments
 (0)