Skip to content

Commit b5a8fcf

Browse files
Implement improved sessions API
GODRIVER-587 Change-Id: Id99fa7a48ffb30b2d822227ac1b95dd36720b2a7
1 parent b7c48f5 commit b5a8fcf

11 files changed

+555
-463
lines changed

mongo/causal_consistency_test.go

Lines changed: 88 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ package mongo
99
import (
1010
"context"
1111
"os"
12-
"reflect"
1312
"testing"
1413

1514
"github.com/mongodb/mongo-go-driver/bson"
@@ -82,11 +81,11 @@ func createReadFuncMap(t *testing.T, dbName string, collName string) (*Client, *
8281
coll.writeConcern = writeconcern.New(writeconcern.WMajority())
8382

8483
functions := []CollFunction{
85-
{"Aggregate", reflect.ValueOf(coll.Aggregate), []interface{}{ctx, emptyDoc}},
86-
{"Count", reflect.ValueOf(coll.Count), []interface{}{ctx, emptyDoc}},
87-
{"Distinct", reflect.ValueOf(coll.Distinct), []interface{}{ctx, "field", emptyDoc}},
88-
{"Find", reflect.ValueOf(coll.Find), []interface{}{ctx, emptyDoc}},
89-
{"FindOne", reflect.ValueOf(coll.FindOne), []interface{}{ctx, emptyDoc}},
84+
{"Aggregate", coll, nil, func(mctx SessionContext) error { _, err := coll.Aggregate(mctx, emptyDoc); return err }},
85+
{"Count", coll, nil, func(mctx SessionContext) error { _, err := coll.Count(mctx, emptyDoc); return err }},
86+
{"Distinct", coll, nil, func(mctx SessionContext) error { _, err := coll.Distinct(mctx, "field", emptyDoc); return err }},
87+
{"Find", coll, nil, func(mctx SessionContext) error { _, err := coll.Find(mctx, emptyDoc); return err }},
88+
{"FindOne", coll, nil, func(mctx SessionContext) error { res := coll.FindOne(mctx, emptyDoc); return res.err }},
9089
}
9190

9291
_, err = coll.InsertOne(ctx, startingDoc)
@@ -137,25 +136,31 @@ func createWriteFuncMap(t *testing.T, dbName string, collName string) (*Client,
137136
manyIndexes := []IndexModel{barIndex, bazIndex}
138137

139138
functions := []CollFunction{
140-
{"InsertOne", reflect.ValueOf(coll.InsertOne), []interface{}{ctx, doc}},
141-
{"InsertMany", reflect.ValueOf(coll.InsertMany), []interface{}{ctx, []interface{}{doc2}}},
142-
{"DeleteOne", reflect.ValueOf(coll.DeleteOne), []interface{}{ctx, emptyDoc}},
143-
{"DeleteMany", reflect.ValueOf(coll.DeleteMany), []interface{}{ctx, emptyDoc}},
144-
{"UpdateOne", reflect.ValueOf(coll.UpdateOne), []interface{}{ctx, emptyDoc, updateDoc}},
145-
{"UpdateMany", reflect.ValueOf(coll.UpdateMany), []interface{}{ctx, emptyDoc, updateDoc}},
146-
{"ReplaceOne", reflect.ValueOf(coll.ReplaceOne), []interface{}{ctx, emptyDoc, emptyDoc}},
147-
{"FindOneAndDelete", reflect.ValueOf(coll.FindOneAndDelete), []interface{}{ctx, emptyDoc}},
148-
{"FindOneAndReplace", reflect.ValueOf(coll.FindOneAndReplace), []interface{}{ctx, emptyDoc, emptyDoc}},
149-
{"FindOneAndUpdate", reflect.ValueOf(coll.FindOneAndUpdate), []interface{}{ctx, emptyDoc, updateDoc}},
150-
{"DropCollection", reflect.ValueOf(coll.Drop), []interface{}{ctx}},
151-
{"DropDatabase", reflect.ValueOf(db.Drop), []interface{}{ctx}},
152-
{"ListCollections", reflect.ValueOf(db.ListCollections), []interface{}{ctx, emptyDoc}},
153-
{"ListDatabases", reflect.ValueOf(client.ListDatabases), []interface{}{ctx, emptyDoc}},
154-
{"CreateOneIndex", reflect.ValueOf(iv.CreateOne), []interface{}{ctx, fooIndex}},
155-
{"CreateManyIndexes", reflect.ValueOf(iv.CreateMany), []interface{}{ctx, manyIndexes}},
156-
{"DropOneIndex", reflect.ValueOf(iv.DropOne), []interface{}{ctx, "barIndex"}},
157-
{"DropAllIndexes", reflect.ValueOf(iv.DropAll), []interface{}{ctx}},
158-
{"ListIndexes", reflect.ValueOf(iv.List), []interface{}{ctx}},
139+
{"InsertOne", coll, nil, func(mctx SessionContext) error { _, err := coll.InsertOne(mctx, doc); return err }},
140+
{"InsertMany", coll, nil, func(mctx SessionContext) error { _, err := coll.InsertMany(mctx, []interface{}{doc2}); return err }},
141+
{"DeleteOne", coll, nil, func(mctx SessionContext) error { _, err := coll.DeleteOne(mctx, emptyDoc); return err }},
142+
{"DeleteMany", coll, nil, func(mctx SessionContext) error { _, err := coll.DeleteMany(mctx, emptyDoc); return err }},
143+
{"UpdateOne", coll, nil, func(mctx SessionContext) error { _, err := coll.UpdateOne(mctx, emptyDoc, updateDoc); return err }},
144+
{"UpdateMany", coll, nil, func(mctx SessionContext) error { _, err := coll.UpdateMany(mctx, emptyDoc, updateDoc); return err }},
145+
{"ReplaceOne", coll, nil, func(mctx SessionContext) error { _, err := coll.ReplaceOne(mctx, emptyDoc, emptyDoc); return err }},
146+
{"FindOneAndDelete", coll, nil, func(mctx SessionContext) error { res := coll.FindOneAndDelete(mctx, emptyDoc); return res.err }},
147+
{"FindOneAndReplace", coll, nil, func(mctx SessionContext) error {
148+
res := coll.FindOneAndReplace(mctx, emptyDoc, emptyDoc)
149+
return res.err
150+
}},
151+
{"FindOneAndUpdate", coll, nil, func(mctx SessionContext) error {
152+
res := coll.FindOneAndUpdate(mctx, emptyDoc, updateDoc)
153+
return res.err
154+
}},
155+
{"DropCollection", coll, nil, func(mctx SessionContext) error { err := coll.Drop(mctx); return err }},
156+
{"DropDatabase", coll, nil, func(mctx SessionContext) error { err := db.Drop(mctx); return err }},
157+
{"ListCollections", coll, nil, func(mctx SessionContext) error { _, err := db.ListCollections(mctx, emptyDoc); return err }},
158+
{"ListDatabases", coll, nil, func(mctx SessionContext) error { _, err := client.ListDatabases(mctx, emptyDoc); return err }},
159+
{"CreateOneIndex", coll, nil, func(mctx SessionContext) error { _, err := iv.CreateOne(mctx, fooIndex); return err }},
160+
{"CreateManyIndexes", coll, nil, func(mctx SessionContext) error { _, err := iv.CreateMany(mctx, manyIndexes); return err }},
161+
{"DropOneIndex", coll, nil, func(mctx SessionContext) error { _, err := iv.DropOne(mctx, "barIndex"); return err }},
162+
{"DropAllIndexes", coll, nil, func(mctx SessionContext) error { _, err := iv.DropAll(mctx); return err }},
163+
{"ListIndexes", coll, nil, func(mctx SessionContext) error { _, err := iv.List(mctx); return err }},
159164
}
160165

161166
return client, db, coll, functions
@@ -183,7 +188,7 @@ func TestCausalConsistency(t *testing.T) {
183188
testhelpers.RequireNil(t, err, "error creating session: %s", err)
184189
defer sess.EndSession(ctx)
185190

186-
if sess.OperationTime != nil {
191+
if sess.OperationTime() != nil {
187192
t.Fatal("operation time is not nil")
188193
}
189194
})
@@ -192,16 +197,17 @@ func TestCausalConsistency(t *testing.T) {
192197
// First read in causally consistent session must not send afterClusterTime to the server
193198

194199
client := createSessionsMonitoredClient(t, ccMonitor)
195-
sess, err := client.StartSession(sessionopt.CausalConsistency(true))
196-
testhelpers.RequireNil(t, err, "error creating session: %s", err)
197-
defer sess.EndSession(ctx)
198200

199201
db := client.Database("FirstCommandDB")
200-
err = db.Drop(ctx)
202+
err := db.Drop(ctx)
201203
testhelpers.RequireNil(t, err, "error dropping db: %s", err)
202204

203205
coll := db.Collection("FirstCommandColl")
204-
_, err = coll.Find(ctx, emptyDoc, sess)
206+
err = client.UseSessionWithOptions(ctx, []sessionopt.Session{sessionopt.CausalConsistency(true)},
207+
func(mctx SessionContext) error {
208+
_, err := coll.Find(mctx, emptyDoc)
209+
return err
210+
})
205211
testhelpers.RequireNil(t, err, "error running find: %s", err)
206212

207213
testhelpers.RequireNotNil(t, ccStarted, "no started command found")
@@ -228,13 +234,16 @@ func TestCausalConsistency(t *testing.T) {
228234
testhelpers.RequireNil(t, err, "error dropping db: %s", err)
229235

230236
coll := db.Collection("OptimeUpdateColl")
231-
_, _ = coll.Find(ctx, emptyDoc, sess)
237+
_ = WithSession(ctx, sess, func(mctx SessionContext) error {
238+
_, _ = coll.Find(mctx, emptyDoc)
239+
return nil
240+
})
232241

233242
testhelpers.RequireNotNil(t, ccSucceeded, "no succeeded command")
234243
serverT, serverI := ccSucceeded.Reply.Lookup("operationTime").Timestamp()
235244

236-
testhelpers.RequireNotNil(t, sess.OperationTime, "operation time nil after first command")
237-
compareOperationTimes(t, &bson.Timestamp{serverT, serverI}, sess.OperationTime)
245+
testhelpers.RequireNotNil(t, sess.OperationTime(), "operation time nil after first command")
246+
compareOperationTimes(t, &bson.Timestamp{serverT, serverI}, sess.OperationTime())
238247
})
239248

240249
t.Run("TestOperationTimeSent", func(t *testing.T) {
@@ -252,14 +261,15 @@ func TestCausalConsistency(t *testing.T) {
252261
testhelpers.RequireNil(t, err, "error creating session for %s: %s", tc.name, err)
253262
defer sess.EndSession(ctx)
254263

255-
opts := append(tc.opts, sess)
256-
docRes := coll.FindOne(ctx, emptyDoc, sess)
257-
testhelpers.RequireNil(t, docRes.err, "find one error for %s: %s", tc.name, docRes.err)
264+
err = WithSession(ctx, sess, func(mctx SessionContext) error {
265+
docRes := coll.FindOne(mctx, emptyDoc)
266+
return docRes.err
267+
})
268+
testhelpers.RequireNil(t, err, "find one error for %s: %s", tc.name, err)
258269

259-
currOptime := sess.OperationTime
270+
currOptime := sess.OperationTime()
260271

261-
returnVals := tc.f.Call(getOptValues(opts))
262-
err = getReturnError(returnVals)
272+
err = WithSession(ctx, sess, tc.f)
263273
testhelpers.RequireNil(t, err, "error running %s: %s", tc.name, err)
264274

265275
testhelpers.RequireNotNil(t, ccStarted, "no started command")
@@ -285,13 +295,15 @@ func TestCausalConsistency(t *testing.T) {
285295
testhelpers.RequireNil(t, err, "error starting session: %s", err)
286296
defer sess.EndSession(ctx)
287297

288-
opts := append(tc.opts, sess)
289-
returnVals := tc.f.Call(getOptValues(opts))
290-
err = getReturnError(returnVals)
298+
err = WithSession(ctx, sess, tc.f)
291299
testhelpers.RequireNil(t, err, "error running %s: %s", tc.name, err)
292300

293-
currentOptime := sess.OperationTime
294-
_ = coll.FindOne(ctx, emptyDoc, sess)
301+
currentOptime := sess.OperationTime()
302+
303+
_ = WithSession(ctx, sess, func(mctx SessionContext) error {
304+
_ = coll.FindOne(mctx, emptyDoc)
305+
return nil
306+
})
295307

296308
testhelpers.RequireNotNil(t, ccStarted, "no started command")
297309
sentOptime := getOperationTime(t, ccStarted.Command)
@@ -308,16 +320,17 @@ func TestCausalConsistency(t *testing.T) {
308320
skipIfBelow36(t)
309321

310322
client := createSessionsMonitoredClient(t, ccMonitor)
311-
sess, err := client.StartSession(sessionopt.CausalConsistency(false))
312-
testhelpers.RequireNil(t, err, "error creating session: %s", err)
313-
defer sess.EndSession(ctx)
314323

315324
db := client.Database("NonConsistentReadDB")
316-
err = db.Drop(ctx)
325+
err := db.Drop(ctx)
317326
testhelpers.RequireNil(t, err, "error dropping db: %s", err)
318327

319328
coll := db.Collection("NonConsistentReadColl")
320-
_, _ = coll.Find(ctx, emptyDoc, sess)
329+
_ = client.UseSessionWithOptions(ctx, []sessionopt.Session{sessionopt.CausalConsistency(false)},
330+
func(mctx SessionContext) error {
331+
_, _ = coll.Find(mctx, emptyDoc)
332+
return nil
333+
})
321334

322335
testhelpers.RequireNotNil(t, ccStarted, "no started command")
323336
if ccStarted.CommandName != "find" {
@@ -338,12 +351,12 @@ func TestCausalConsistency(t *testing.T) {
338351

339352
skipIfSessionsSupported(t, db)
340353

341-
sess, err := client.StartSession(sessionopt.CausalConsistency(true))
342-
testhelpers.RequireNil(t, err, "error starting session: %s", err)
343-
defer sess.EndSession(ctx)
344-
345354
coll := db.Collection("InvalidTopologyColl")
346-
_, _ = coll.Find(ctx, emptyDoc, sess)
355+
_ = client.UseSessionWithOptions(ctx, []sessionopt.Session{sessionopt.CausalConsistency(true)},
356+
func(mctx SessionContext) error {
357+
_, _ = coll.Find(mctx, emptyDoc)
358+
return nil
359+
})
347360

348361
testhelpers.RequireNotNil(t, ccStarted, "no started command found")
349362
if ccStarted.CommandName != "find" {
@@ -370,10 +383,16 @@ func TestCausalConsistency(t *testing.T) {
370383

371384
coll := db.Collection("DefaultReadConcernColl")
372385
coll.readConcern = readconcern.New()
373-
_ = coll.FindOne(ctx, emptyDoc, sess)
386+
_ = WithSession(ctx, sess, func(mctx SessionContext) error {
387+
_ = coll.FindOne(mctx, emptyDoc)
388+
return nil
389+
})
374390

375-
currOptime := sess.OperationTime
376-
_ = coll.FindOne(ctx, emptyDoc, sess)
391+
currOptime := sess.OperationTime()
392+
_ = WithSession(ctx, sess, func(mctx SessionContext) error {
393+
_ = coll.FindOne(mctx, emptyDoc)
394+
return nil
395+
})
377396

378397
testhelpers.RequireNotNil(t, ccStarted, "no started command found")
379398
if ccStarted.CommandName != "find" {
@@ -402,10 +421,17 @@ func TestCausalConsistency(t *testing.T) {
402421
coll := db.Collection("CustomReadConcernColl")
403422
coll.readConcern = readconcern.Majority()
404423

405-
_ = coll.FindOne(ctx, emptyDoc, sess)
406-
currOptime := sess.OperationTime
424+
_ = WithSession(ctx, sess, func(mctx SessionContext) error {
425+
_ = coll.FindOne(mctx, emptyDoc)
426+
return nil
427+
})
428+
429+
currOptime := sess.OperationTime()
407430

408-
_ = coll.FindOne(ctx, emptyDoc, sess)
431+
_ = WithSession(ctx, sess, func(mctx SessionContext) error {
432+
_ = coll.FindOne(mctx, emptyDoc)
433+
return nil
434+
})
409435

410436
testhelpers.RequireNotNil(t, ccStarted, "no started command found")
411437
if ccStarted.CommandName != "find" {
@@ -431,7 +457,7 @@ func TestCausalConsistency(t *testing.T) {
431457
coll.writeConcern = writeconcern.New(writeconcern.W(0))
432458
_, _ = coll.InsertOne(ctx, doc)
433459

434-
if sess.OperationTime != nil {
460+
if sess.OperationTime() != nil {
435461
t.Fatal("operation time updated for unacknowledged write")
436462
}
437463
})

mongo/change_stream.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,11 +46,13 @@ func newChangeStream(ctx context.Context, coll *Collection, pipeline interface{}
4646
return nil, err
4747
}
4848

49-
csOpts, sess, err := changestreamopt.BundleChangeStream(opts...).Unbundle(true)
49+
csOpts, _, err := changestreamopt.BundleChangeStream(opts...).Unbundle(true)
5050
if err != nil {
5151
return nil, err
5252
}
5353

54+
sess := sessionFromContext(ctx)
55+
5456
err = coll.client.ValidSession(sess)
5557
if err != nil {
5658
return nil, err

mongo/client.go

Lines changed: 53 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ func (c *Client) Ping(ctx context.Context, rp *readpref.ReadPref) error {
131131
}
132132

133133
// StartSession starts a new session.
134-
func (c *Client) StartSession(opts ...sessionopt.Session) (*Session, error) {
134+
func (c *Client) StartSession(opts ...sessionopt.Session) (Session, error) {
135135
if c.topology.SessionPool == nil {
136136
return nil, topology.ErrTopologyClosed
137137
}
@@ -156,7 +156,7 @@ func (c *Client) StartSession(opts ...sessionopt.Session) (*Session, error) {
156156

157157
sess.RetryWrite = c.retryWrites
158158

159-
return &Session{
159+
return &sessionImpl{
160160
Client: sess,
161161
topo: c.topology,
162162
}, nil
@@ -333,11 +333,13 @@ func (c *Client) ListDatabases(ctx context.Context, filter interface{}, opts ...
333333
if ctx == nil {
334334
ctx = context.Background()
335335
}
336-
listDbOpts, sess, err := listdbopt.BundleListDatabases(opts...).Unbundle(true)
336+
listDbOpts, _, err := listdbopt.BundleListDatabases(opts...).Unbundle(true)
337337
if err != nil {
338338
return ListDatabasesResult{}, err
339339
}
340340

341+
sess := sessionFromContext(ctx)
342+
341343
err = c.ValidSession(sess)
342344
if err != nil {
343345
return ListDatabasesResult{}, err
@@ -384,3 +386,51 @@ func (c *Client) ListDatabaseNames(ctx context.Context, filter interface{}, opts
384386

385387
return names, nil
386388
}
389+
390+
// WithSession allows a user to start a session themselves and manage
391+
// its lifetime. The only way to provide a session to a CRUD method is
392+
// to invoke that CRUD method with the mongo.SessionContext within the
393+
// closure. The mongo.SessionContext can be used as a regular context,
394+
// so methods like context.WithDeadline and context.WithTimeout are
395+
// supported.
396+
//
397+
// If the context.Context already has a mongo.Session attached, that
398+
// mongo.Session will be replaced with the one provided.
399+
//
400+
// Errors returned from the closure are transparently returned from
401+
// this function.
402+
func WithSession(ctx context.Context, sess Session, fn func(SessionContext) error) error {
403+
return fn(contextWithSession(ctx, sess))
404+
}
405+
406+
// UseSession creates a default session, that is only valid for the
407+
// lifetime of the closure. No cleanup outside of closing the session
408+
// is done upon exiting the closure. This means that an outstanding
409+
// transaction will be aborted, even if the closure returns an error.
410+
//
411+
// If ctx already contains a mongo.Session, that mongo.Session will be
412+
// replaced with the newly created mongo.Session.
413+
//
414+
// Errors returned from the closure are transparently returned from
415+
// this method.
416+
func (c *Client) UseSession(ctx context.Context, fn func(SessionContext) error) error {
417+
return c.UseSessionWithOptions(ctx, []sessionopt.Session{}, fn)
418+
}
419+
420+
// UseSessionWithOptions works like UseSession but allows the caller
421+
// to specify the options used to create the session.
422+
func (c *Client) UseSessionWithOptions(ctx context.Context, opts []sessionopt.Session, fn func(SessionContext) error) error {
423+
defaultSess, err := c.StartSession(opts...)
424+
if err != nil {
425+
return err
426+
}
427+
428+
defer defaultSess.EndSession(ctx)
429+
430+
sessCtx := sessionContext{
431+
Context: context.WithValue(ctx, sessionKey{}, defaultSess),
432+
Session: defaultSess,
433+
}
434+
435+
return fn(sessCtx)
436+
}

mongo/client_internal_test.go

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -390,19 +390,22 @@ func TestClient_CausalConsistency(t *testing.T) {
390390
err = c.Connect(ctx)
391391
require.NoError(t, err)
392392

393-
sess, err := c.StartSession(sessionopt.CausalConsistency(true))
393+
s, err := c.StartSession(sessionopt.CausalConsistency(true))
394+
sess := s.(*sessionImpl)
394395
require.NoError(t, err)
395396
require.NotNil(t, sess)
396397
require.True(t, sess.Consistent)
397398
sess.EndSession(ctx)
398399

399-
sess, err = c.StartSession(sessionopt.CausalConsistency(false))
400+
s, err = c.StartSession(sessionopt.CausalConsistency(false))
401+
sess = s.(*sessionImpl)
400402
require.NoError(t, err)
401403
require.NotNil(t, sess)
402404
require.False(t, sess.Consistent)
403405
sess.EndSession(ctx)
404406

405-
sess, err = c.StartSession()
407+
s, err = c.StartSession()
408+
sess = s.(*sessionImpl)
406409
require.NoError(t, err)
407410
require.NotNil(t, sess)
408411
require.True(t, sess.Consistent)

0 commit comments

Comments
 (0)