Skip to content

Commit 915670b

Browse files
committed
Add custom Dialer to ClientOptions
GODRIVER-195 Change-Id: I4060ae2af015d13b0ba206eb0a597c319a550c49
1 parent 6b85f6b commit 915670b

File tree

4 files changed

+71
-7
lines changed

4 files changed

+71
-7
lines changed

mongo/client.go

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,12 +25,13 @@ const defaultLocalThreshold = 15 * time.Millisecond
2525

2626
// Client performs operations on a given topology.
2727
type Client struct {
28-
topology *topology.Topology
29-
connString connstring.ConnString
30-
localThreshold time.Duration
31-
readPreference *readpref.ReadPref
32-
readConcern *readconcern.ReadConcern
33-
writeConcern *writeconcern.WriteConcern
28+
topologyOptions []topology.Option
29+
topology *topology.Topology
30+
connString connstring.ConnString
31+
localThreshold time.Duration
32+
readPreference *readpref.ReadPref
33+
readConcern *readconcern.ReadConcern
34+
writeConcern *writeconcern.WriteConcern
3435
}
3536

3637
// NewClient creates a new client to connect to a cluster specified by the uri.
@@ -78,7 +79,11 @@ func newClient(cs connstring.ConnString, opts *ClientOptions) (*Client, error) {
7879
}
7980
}
8081

81-
topo, err := topology.New(topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return client.connString }))
82+
topts := append(
83+
client.topologyOptions,
84+
topology.WithConnString(func(connstring.ConnString) connstring.ConnString { return client.connString }),
85+
)
86+
topo, err := topology.New(topts...)
8287
if err != nil {
8388
return nil, err
8489
}

mongo/client_options.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ package mongo
99
import (
1010
"time"
1111

12+
"github.com/mongodb/mongo-go-driver/core/connection"
1213
"github.com/mongodb/mongo-go-driver/core/connstring"
14+
"github.com/mongodb/mongo-go-driver/core/topology"
1315
)
1416

1517
type option func(*Client) error
@@ -83,6 +85,30 @@ func (co *ClientOptions) ConnectTimeout(d time.Duration) *ClientOptions {
8385
return &ClientOptions{next: co, opt: fn}
8486
}
8587

88+
// Dialer specifies a custom dialer used to dial new connections to a server.
89+
func (co *ClientOptions) Dialer(d Dialer) *ClientOptions {
90+
var fn option = func(c *Client) error {
91+
c.topologyOptions = append(
92+
c.topologyOptions,
93+
topology.WithServerOptions(func(opts ...topology.ServerOption) []topology.ServerOption {
94+
return append(
95+
opts,
96+
topology.WithConnectionOptions(func(opts ...connection.Option) []connection.Option {
97+
return append(
98+
opts,
99+
connection.WithDialer(func(connection.Dialer) connection.Dialer {
100+
return d
101+
}),
102+
)
103+
}),
104+
)
105+
}),
106+
)
107+
return nil
108+
}
109+
return &ClientOptions{next: co, opt: fn}
110+
}
111+
86112
// HeartbeatInterval specifies the interval to wait between server monitoring checks.
87113
func (co *ClientOptions) HeartbeatInterval(d time.Duration) *ClientOptions {
88114
var fn option = func(c *Client) error {

mongo/client_options_test.go

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package mongo
22

33
import (
4+
"context"
5+
"net"
6+
"sync/atomic"
47
"testing"
58

69
"time"
@@ -97,3 +100,26 @@ func TestClientOptions_chainAll(t *testing.T) {
97100
opts = opts.next
98101
}
99102
}
103+
104+
func TestClientOptions_CustomDialer(t *testing.T) {
105+
td := &testDialer{d: &net.Dialer{}}
106+
opts := ClientOpt.Dialer(td)
107+
client, err := newClient(testutil.ConnString(t), opts)
108+
require.NoError(t, err)
109+
_, err = client.ListDatabases(context.Background(), nil)
110+
require.NoError(t, err)
111+
got := atomic.LoadInt32(&td.called)
112+
if got < 1 {
113+
t.Errorf("Custom dialer was not used when dialing new connections")
114+
}
115+
}
116+
117+
type testDialer struct {
118+
called int32
119+
d Dialer
120+
}
121+
122+
func (td *testDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
123+
atomic.AddInt32(&td.called, 1)
124+
return td.d.DialContext(ctx, network, address)
125+
}

mongo/mongo.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,23 @@
77
package mongo
88

99
import (
10+
"context"
1011
"errors"
1112
"fmt"
1213
"io"
14+
"net"
1315
"reflect"
1416
"strings"
1517

1618
"github.com/mongodb/mongo-go-driver/bson"
1719
"github.com/mongodb/mongo-go-driver/bson/objectid"
1820
)
1921

22+
// Dialer is used to make network connections.
23+
type Dialer interface {
24+
DialContext(ctx context.Context, network, address string) (net.Conn, error)
25+
}
26+
2027
// TransformDocument handles transforming a document of an allowable type into
2128
// a *bson.Document. This method is called directly after most methods that
2229
// have one or more parameters that are documents.

0 commit comments

Comments
 (0)