Skip to content

Commit 42728fe

Browse files
author
Isabella Siu
committed
GODRIVER-898 make topology update serverDescriptions synchronously
Change-Id: I0a9e1719adced1ad315d32d82d04e9d1585717e1
1 parent f3add1f commit 42728fe

File tree

11 files changed

+144
-192
lines changed

11 files changed

+144
-192
lines changed

mongo/transactions_test.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,6 @@ func runTransactionTestFile(t *testing.T, filepath string) {
142142

143143
func runTransactionsTestCase(t *testing.T, test *transTestCase, testfile transTestFile, dbAdmin *Database) {
144144
t.Run(test.Description, func(t *testing.T) {
145-
146145
// kill sessions from previously failed tests
147146
killSessions(t, dbAdmin.client)
148147

x/mongo/driver/topology/connection.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ func (sc *sconn) processErr(err error) {
6060
return
6161
}
6262

63-
ne, ok := err.(connection.NetworkError)
63+
ne, ok := err.(connection.Error)
6464
if !ok {
6565
return
6666
}

x/mongo/driver/topology/connection_test.go

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ func (n netErr) Temporary() bool {
3333
}
3434

3535
type connect struct {
36-
err *connection.NetworkError
36+
err *connection.Error
3737
}
3838

3939
func (c connect) WriteWireMessage(ctx context.Context, wm wiremessage.WireMessage) error {
@@ -58,7 +58,7 @@ func (c connect) ID() string {
5858
// Test case for sconn processErr
5959
func TestConnectionProcessErrSpec(t *testing.T) {
6060
ctx := context.Background()
61-
s, err := NewServer(address.Address("localhost"))
61+
s, err := NewServer(address.Address("localhost"), nil)
6262
require.NoError(t, err)
6363

6464
desc := s.Description()
@@ -67,7 +67,7 @@ func TestConnectionProcessErrSpec(t *testing.T) {
6767
s.connectionstate = connected
6868

6969
innerErr := netErr{}
70-
connectErr := connection.NetworkError{"blah", innerErr}
70+
connectErr := connection.Error{ConnectionID: "blah", Wrapped: innerErr}
7171
c := connect{&connectErr}
7272
sc := sconn{c, s, 1}
7373
err = sc.WriteWireMessage(ctx, nil)

x/mongo/driver/topology/server.go

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,14 @@ type Server struct {
9999
currentSubscriberID uint64
100100

101101
subscriptionsClosed bool
102+
103+
updateTopologyCallback atomic.Value
102104
}
103105

104106
// ConnectServer creates a new Server and then initializes it using the
105107
// Connect method.
106-
func ConnectServer(ctx context.Context, addr address.Address, opts ...ServerOption) (*Server, error) {
107-
srvr, err := NewServer(addr, opts...)
108+
func ConnectServer(ctx context.Context, addr address.Address, topo func(description.Server), opts ...ServerOption) (*Server, error) {
109+
srvr, err := NewServer(addr, topo, opts...)
108110
if err != nil {
109111
return nil, err
110112
}
@@ -117,7 +119,7 @@ func ConnectServer(ctx context.Context, addr address.Address, opts ...ServerOpti
117119

118120
// NewServer creates a new server. The mongodb server at the address will be monitored
119121
// on an internal monitoring goroutine.
120-
func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
122+
func NewServer(addr address.Address, topo func(description.Server), opts ...ServerOption) (*Server, error) {
121123
cfg, err := newServerConfig(opts...)
122124
if err != nil {
123125
return nil, err
@@ -133,6 +135,7 @@ func NewServer(addr address.Address, opts ...ServerOption) (*Server, error) {
133135
subscribers: make(map[uint64]chan description.Server),
134136
}
135137
s.desc.Store(description.Server{Addr: addr})
138+
s.updateTopologyCallback.Store(topo)
136139

137140
var maxConns uint64
138141
if cfg.maxConns == 0 {
@@ -175,6 +178,8 @@ func (s *Server) Disconnect(ctx context.Context) error {
175178
return ErrServerClosed
176179
}
177180

181+
s.updateTopologyCallback.Store((func(description.Server))(nil))
182+
178183
// For every call to Connect there must be at least 1 goroutine that is
179184
// waiting on the done channel.
180185
s.done <- struct{}{}
@@ -373,6 +378,11 @@ func (s *Server) updateDescription(desc description.Server, initial bool) {
373378
}()
374379
s.desc.Store(desc)
375380

381+
topo := s.updateTopologyCallback.Load().(func(description.Server))
382+
if topo != nil {
383+
topo(desc)
384+
}
385+
376386
s.subLock.Lock()
377387
for _, c := range s.subscribers {
378388
select {

x/mongo/driver/topology/server_test.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func TestServer(t *testing.T) {
7474

7575
for _, tt := range serverTestTable {
7676
t.Run(tt.name, func(t *testing.T) {
77-
s, err := NewServer(address.Address("localhost"))
77+
s, err := NewServer(address.Address("localhost"), nil)
7878
require.NoError(t, err)
7979

8080
var desc *description.Server
@@ -103,7 +103,7 @@ func TestServer(t *testing.T) {
103103
})
104104
}
105105
t.Run("WriteConcernError", func(t *testing.T) {
106-
s, err := NewServer(address.Address("localhost"))
106+
s, err := NewServer(address.Address("localhost"), nil)
107107
require.NoError(t, err)
108108

109109
var desc *description.Server
@@ -127,7 +127,7 @@ func TestServer(t *testing.T) {
127127
require.Equal(t, drained, true)
128128
})
129129
t.Run("no WriteConcernError", func(t *testing.T) {
130-
s, err := NewServer(address.Address("localhost"))
130+
s, err := NewServer(address.Address("localhost"), nil)
131131
require.NoError(t, err)
132132

133133
var desc *description.Server
@@ -148,4 +148,12 @@ func TestServer(t *testing.T) {
148148
drained := s.pool.(*pool).drainCalled.Load().(bool)
149149
require.Equal(t, drained, false)
150150
})
151+
t.Run("update topology", func(t *testing.T) {
152+
var updated bool
153+
s, err := NewServer(address.Address("localhost"), func(description.Server) { updated = true })
154+
require.NoError(t, err)
155+
s.updateDescription(description.Server{Addr: s.address}, false)
156+
require.True(t, updated)
157+
158+
})
151159
}

x/mongo/driver/topology/topology.go

Lines changed: 45 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -63,9 +63,7 @@ type Topology struct {
6363

6464
done chan struct{}
6565

66-
fsm *fsm
67-
changes chan description.Server
68-
changeswg sync.WaitGroup
66+
fsm *fsm
6967

7068
SessionPool *session.Pool
7169

@@ -84,8 +82,6 @@ type Topology struct {
8482
serversLock sync.Mutex
8583
serversClosed bool
8684
servers map[address.Address]*Server
87-
88-
wg sync.WaitGroup
8985
}
9086

9187
// New creates a new topology.
@@ -99,7 +95,6 @@ func New(opts ...Option) (*Topology, error) {
9995
cfg: cfg,
10096
done: make(chan struct{}),
10197
fsm: newFSM(),
102-
changes: make(chan description.Server),
10398
subscribers: make(map[uint64]chan description.Topology),
10499
servers: make(map[address.Address]*Server),
105100
}
@@ -134,9 +129,6 @@ func (t *Topology) Connect(ctx context.Context) error {
134129
}
135130
t.serversLock.Unlock()
136131

137-
go t.update()
138-
t.changeswg.Add(1)
139-
140132
t.subscriptionsClosed = false // explicitly set in case topology was disconnected and then reconnected
141133

142134
atomic.StoreInt32(&t.connectionstate, connected)
@@ -154,16 +146,25 @@ func (t *Topology) Disconnect(ctx context.Context) error {
154146
return ErrTopologyClosed
155147
}
156148

149+
servers := make(map[address.Address]*Server)
157150
t.serversLock.Lock()
158151
t.serversClosed = true
159152
for addr, server := range t.servers {
160-
t.removeServer(ctx, addr, server)
153+
servers[addr] = server
161154
}
162155
t.serversLock.Unlock()
163156

164-
t.wg.Wait()
165-
t.done <- struct{}{}
166-
t.changeswg.Wait()
157+
for _, server := range servers {
158+
_ = server.Disconnect(ctx)
159+
}
160+
161+
t.subLock.Lock()
162+
for id, ch := range t.subscribers {
163+
close(ch)
164+
delete(t.subscribers, id)
165+
}
166+
t.subscriptionsClosed = true
167+
t.subLock.Unlock()
167168

168169
t.desc.Store(description.Topology{})
169170

@@ -328,110 +329,73 @@ func (t *Topology) selectServer(ctx context.Context, subscriptionCh <-chan descr
328329
}
329330
}
330331

331-
func (t *Topology) update() {
332-
defer t.changeswg.Done()
333-
defer func() {
334-
// ¯\_(ツ)_/¯
335-
if r := recover(); r != nil {
336-
<-t.done
337-
}
338-
}()
332+
func (t *Topology) apply(ctx context.Context, desc description.Server) {
333+
var err error
339334

340-
for {
341-
select {
342-
case change := <-t.changes:
343-
current, err := t.apply(context.TODO(), change)
344-
if err != nil {
345-
continue
346-
}
335+
t.serversLock.Lock()
336+
defer t.serversLock.Unlock()
347337

348-
t.desc.Store(current)
349-
t.subLock.Lock()
350-
for _, ch := range t.subscribers {
351-
// We drain the description if there's one in the channel
352-
select {
353-
case <-ch:
354-
default:
355-
}
356-
ch <- current
357-
}
358-
t.subLock.Unlock()
359-
case <-t.done:
360-
t.subLock.Lock()
361-
for id, ch := range t.subscribers {
362-
close(ch)
363-
delete(t.subscribers, id)
364-
}
365-
t.subscriptionsClosed = true
366-
t.subLock.Unlock()
367-
return
368-
}
338+
if _, ok := t.servers[desc.Addr]; t.serversClosed || !ok {
339+
return
369340
}
370-
}
371341

372-
func (t *Topology) apply(ctx context.Context, desc description.Server) (description.Topology, error) {
373-
var err error
374342
prev := t.fsm.Topology
375343

376344
current, err := t.fsm.apply(desc)
377345
if err != nil {
378-
return description.Topology{}, err
346+
return
379347
}
380348

381349
diff := description.DiffTopology(prev, current)
382-
t.serversLock.Lock()
383-
if t.serversClosed {
384-
t.serversLock.Unlock()
385-
return description.Topology{}, nil
386-
}
387350

388351
for _, removed := range diff.Removed {
389352
if s, ok := t.servers[removed.Addr]; ok {
390-
t.removeServer(ctx, removed.Addr, s)
353+
go func() {
354+
cancelCtx, cancel := context.WithCancel(ctx)
355+
cancel()
356+
_ = s.Disconnect(cancelCtx)
357+
}()
358+
delete(t.servers, removed.Addr)
391359
}
392360
}
393361

394362
for _, added := range diff.Added {
395363
_ = t.addServer(ctx, added.Addr)
396364
}
397-
t.serversLock.Unlock()
398-
return current, nil
365+
366+
t.desc.Store(current)
367+
368+
t.subLock.Lock()
369+
for _, ch := range t.subscribers {
370+
// We drain the description if there's one in the channel
371+
select {
372+
case <-ch:
373+
default:
374+
}
375+
ch <- current
376+
}
377+
t.subLock.Unlock()
378+
399379
}
400380

401381
func (t *Topology) addServer(ctx context.Context, addr address.Address) error {
402382
if _, ok := t.servers[addr]; ok {
403383
return nil
404384
}
405385

406-
svr, err := ConnectServer(ctx, addr, t.cfg.serverOpts...)
407-
if err != nil {
408-
return err
386+
topoFunc := func(desc description.Server) {
387+
t.apply(context.TODO(), desc)
409388
}
410-
411-
t.servers[addr] = svr
412-
var sub *ServerSubscription
413-
sub, err = svr.Subscribe()
389+
svr, err := ConnectServer(ctx, addr, topoFunc, t.cfg.serverOpts...)
414390
if err != nil {
415391
return err
416392
}
417393

418-
t.wg.Add(1)
419-
go func() {
420-
for c := range sub.C {
421-
t.changes <- c
422-
}
423-
424-
t.wg.Done()
425-
}()
394+
t.servers[addr] = svr
426395

427396
return nil
428397
}
429398

430-
func (t *Topology) removeServer(ctx context.Context, addr address.Address, server *Server) {
431-
_ = server.Disconnect(ctx)
432-
delete(t.servers, addr)
433-
}
434-
435399
// String implements the Stringer interface
436400
func (t *Topology) String() string {
437401
desc := t.Description()

0 commit comments

Comments
 (0)