Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mongo/mongo.go
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,7 @@ func (m *Mongo) RoundTrip(msg *Message, tags []string) (_ *Message, err error) {
requestCursorID, _ := msg.Op.CursorID()
requestCommand, collection := msg.Op.CommandAndCollection()
txnDetails := msg.Op.TransactionDetails()
readPref, _ := msg.Op.ReadPref()

var conn driver.Connection
var server driver.Server
Expand All @@ -176,7 +177,7 @@ func (m *Mongo) RoundTrip(msg *Message, tags []string) (_ *Message, err error) {
}

if conn == nil {
server, err := m.selectServer(collection)
server, err := m.selectServer(collection, readPref)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -276,13 +277,13 @@ func (m *Mongo) RoundTrip(msg *Message, tags []string) (_ *Message, err error) {
}, nil
}

func (m *Mongo) selectServer(collection string) (server driver.Server, err error) {
func (m *Mongo) selectServer(collection string, readPref *readpref.ReadPref) (server driver.Server, err error) {
defer func(start time.Time) {
_ = m.statsd.Timing("server_selection", time.Since(start), []string{fmt.Sprintf("success:%v", err == nil)}, 1)
}(time.Now())
// Select a server
selector := description.CompositeSelector([]description.ServerSelector{
description.ReadPrefSelector(readpref.Primary()), // ignored by sharded clusters
description.ReadPrefSelector(readPref), // ignored by sharded clusters
description.LatencySelector(15 * time.Millisecond), // default localThreshold for the client
})
return m.topology.SelectServer(m.roundTripCtx, selector)
Expand Down
63 changes: 63 additions & 0 deletions mongo/operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ import (
"go.mongodb.org/mongo-driver/x/bsonx/bsoncore"
"go.mongodb.org/mongo-driver/x/mongo/driver"
"go.mongodb.org/mongo-driver/x/mongo/driver/wiremessage"

"go.mongodb.org/mongo-driver/mongo/readpref"
)

type Message struct {
Expand All @@ -33,6 +35,7 @@ type Operation interface {
Unacknowledged() bool
CommandAndCollection() (Command, string)
TransactionDetails() *TransactionDetails
ReadPref() (rpref *readpref.ReadPref, ok bool)
}

// see https://github.com/mongodb/mongo-go-driver/blob/v1.7.2/x/mongo/driver/operation.go#L1361-L1426
Expand Down Expand Up @@ -244,6 +247,7 @@ type opMsgSection interface {
isIsMaster() bool
append(buffer []byte) []byte
commandAndCollection() (Command, string)
ReadPref() (rpref *readpref.ReadPref, ok bool)
}

type opMsgSectionSingle struct {
Expand Down Expand Up @@ -277,6 +281,20 @@ func (o *opMsgSectionSingle) String() string {
return fmt.Sprintf("{ SectionSingle msg: %s }", o.msg.String())
}

func (o *opMsgSectionSingle) ReadPref() (rpref *readpref.ReadPref, ok bool) {
if prefDoc, ok := o.msg.Lookup("$readPreference").DocumentOK(); ok {
if prefStr, ok := prefDoc.Lookup("mode").StringValueOK(); ok {
// Note: only the mode is unpacked currently
mode, err := readpref.ModeFromString(prefStr)
if err == nil {
rpref, _ := readpref.New(mode)
return rpref, true
}
}
}
return readpref.Primary(), false
}

type opMsgSectionSequence struct {
identifier string
msgs []bsoncore.Document
Expand Down Expand Up @@ -676,6 +694,35 @@ func (g *opGetMore) String() string {
return fmt.Sprintf("{ OpGetMore fullCollectionName: %s, numberToReturn: %d, cursorID: %d }", g.fullCollectionName, g.numberToReturn, g.cursorID)
}

func (g *opGetMore) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func (r *opReply) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func (o *opMsgSectionSequence) ReadPref() (rpref *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func (m *opMsg) ReadPref() (rp *readpref.ReadPref, ok bool) {
for _, section := range m.sections {
if rpref, ok := section.ReadPref(); ok {
return rpref, ok
}
}
return readpref.Primary(), false
}

func (q *opQuery) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func (o *opUnknown) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

// https://docs.mongodb.com/manual/reference/mongodb-wire-protocol/#op_update
type opUpdate struct {
reqID int32
Expand All @@ -689,6 +736,10 @@ func (u *opUpdate) TransactionDetails() *TransactionDetails {
return nil
}

func (g *opUpdate) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func decodeUpdate(reqID int32, wm []byte) (*opUpdate, error) {
var ok bool
u := opUpdate{
Expand Down Expand Up @@ -773,6 +824,10 @@ func (i *opInsert) TransactionDetails() *TransactionDetails {
return nil
}

func (g *opInsert) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func decodeInsert(reqID int32, wm []byte) (*opInsert, error) {
var ok bool
i := opInsert{
Expand Down Expand Up @@ -857,6 +912,10 @@ func (d *opDelete) TransactionDetails() *TransactionDetails {
return nil
}

func (g *opDelete) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func decodeDelete(reqID int32, wm []byte) (*opDelete, error) {
var ok bool
d := opDelete{
Expand Down Expand Up @@ -938,6 +997,10 @@ func (k *opKillCursors) TransactionDetails() *TransactionDetails {
return nil
}

func (g *opKillCursors) ReadPref() (rp *readpref.ReadPref, ok bool) {
return readpref.Primary(), false
}

func decodeKillCursors(reqID int32, wm []byte) (*opKillCursors, error) {
var ok bool
k := opKillCursors{
Expand Down
2 changes: 1 addition & 1 deletion util/statsd_test.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
package util_test

import (
"github.com/coinbase/mongobetween/util"
"github.com/DataDog/datadog-go/statsd"
"github.com/coinbase/mongobetween/util"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"sync"
Expand Down