Skip to content

Commit 84ddf5a

Browse files
bartle-stripeDivjot Arora
authored andcommitted
Don't mutate inputs in DiffTopology or Topology.DiffHostlist.
Also guard against a similar race in `compareHosts`. GODRIVER-1301
1 parent 7f54217 commit 84ddf5a

File tree

3 files changed

+102
-77
lines changed

3 files changed

+102
-77
lines changed

x/mongo/driver/description/topology.go

Lines changed: 30 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,6 @@
77
package description
88

99
import (
10-
"sort"
11-
"strings"
12-
1310
"go.mongodb.org/mongo-driver/x/mongo/driver/address"
1411
)
1512

@@ -41,39 +38,24 @@ type TopologyDiff struct {
4138
func DiffTopology(old, new Topology) TopologyDiff {
4239
var diff TopologyDiff
4340

44-
// TODO: do this without sorting...
45-
oldServers := serverSorter(old.Servers)
46-
newServers := serverSorter(new.Servers)
47-
48-
sort.Sort(oldServers)
49-
sort.Sort(newServers)
50-
51-
i := 0
52-
j := 0
53-
for {
54-
if i < len(oldServers) && j < len(newServers) {
55-
comp := strings.Compare(oldServers[i].Addr.String(), newServers[j].Addr.String())
56-
switch comp {
57-
case 1:
58-
//left is bigger than
59-
diff.Added = append(diff.Added, newServers[j])
60-
j++
61-
case -1:
62-
// right is bigger
63-
diff.Removed = append(diff.Removed, oldServers[i])
64-
i++
65-
case 0:
66-
i++
67-
j++
68-
}
69-
} else if i < len(oldServers) {
70-
diff.Removed = append(diff.Removed, oldServers[i])
71-
i++
72-
} else if j < len(newServers) {
73-
diff.Added = append(diff.Added, newServers[j])
74-
j++
41+
oldServers := make(map[string]bool)
42+
for _, s := range old.Servers {
43+
oldServers[s.Addr.String()] = true
44+
}
45+
46+
for _, s := range new.Servers {
47+
addr := s.Addr.String()
48+
if oldServers[addr] {
49+
delete(oldServers, addr)
7550
} else {
76-
break
51+
diff.Added = append(diff.Added, s)
52+
}
53+
}
54+
55+
for _, s := range old.Servers {
56+
addr := s.Addr.String()
57+
if oldServers[addr] {
58+
diff.Removed = append(diff.Removed, s)
7759
}
7860
}
7961

@@ -90,47 +72,22 @@ type HostlistDiff struct {
9072
func (t Topology) DiffHostlist(hostlist []string) HostlistDiff {
9173
var diff HostlistDiff
9274

93-
oldServers := serverSorter(t.Servers)
94-
sort.Sort(oldServers)
95-
sort.Strings(hostlist)
96-
97-
i := 0
98-
j := 0
99-
for {
100-
if i < len(oldServers) && j < len(hostlist) {
101-
oldServer := oldServers[i].Addr.String()
102-
comp := strings.Compare(oldServer, hostlist[j])
103-
switch comp {
104-
case 1:
105-
// oldServers[i] is bigger
106-
diff.Added = append(diff.Added, hostlist[j])
107-
j++
108-
case -1:
109-
// hostlist[j] is bigger
110-
diff.Removed = append(diff.Removed, oldServer)
111-
i++
112-
case 0:
113-
i++
114-
j++
115-
}
116-
} else if i < len(oldServers) {
117-
diff.Removed = append(diff.Removed, oldServers[i].Addr.String())
118-
i++
119-
} else if j < len(hostlist) {
120-
diff.Added = append(diff.Added, hostlist[j])
121-
j++
75+
oldServers := make(map[string]bool)
76+
for _, s := range t.Servers {
77+
oldServers[s.Addr.String()] = true
78+
}
79+
80+
for _, addr := range hostlist {
81+
if oldServers[addr] {
82+
delete(oldServers, addr)
12283
} else {
123-
break
84+
diff.Added = append(diff.Added, addr)
12485
}
12586
}
12687

127-
return diff
128-
}
129-
130-
type serverSorter []Server
88+
for addr := range oldServers {
89+
diff.Removed = append(diff.Removed, addr)
90+
}
13191

132-
func (ss serverSorter) Len() int { return len(ss) }
133-
func (ss serverSorter) Swap(i, j int) { ss[i], ss[j] = ss[j], ss[i] }
134-
func (ss serverSorter) Less(i, j int) bool {
135-
return strings.Compare(ss[i].Addr.String(), ss[j].Addr.String()) < 0
92+
return diff
13693
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// Copyright (C) MongoDB, Inc. 2019-present.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License"); you may
4+
// not use this file except in compliance with the License. You may obtain
5+
// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
6+
7+
package description
8+
9+
import (
10+
"testing"
11+
12+
"github.com/stretchr/testify/assert"
13+
)
14+
15+
func TestDiffTopology(t *testing.T) {
16+
s1 := Server{Addr: "1.0.0.0:27017"}
17+
s2 := Server{Addr: "2.0.0.0:27017"}
18+
s3 := Server{Addr: "3.0.0.0:27017"}
19+
s4 := Server{Addr: "4.0.0.0:27017"}
20+
s5 := Server{Addr: "5.0.0.0:27017"}
21+
s6 := Server{Addr: "6.0.0.0:27017"}
22+
23+
t1 := Topology{
24+
Servers: []Server{s6, s1, s3, s2},
25+
}
26+
t2 := Topology{
27+
Servers: []Server{s2, s4, s3, s5},
28+
}
29+
30+
diff := DiffTopology(t1, t2)
31+
32+
assert.ElementsMatch(t, []Server{s4, s5}, diff.Added)
33+
assert.ElementsMatch(t, []Server{s1, s6}, diff.Removed)
34+
35+
// Ensure that original topology servers were not reordered.
36+
assert.EqualValues(t, []Server{s6, s1, s3, s2}, t1.Servers)
37+
assert.EqualValues(t, []Server{s2, s4, s3, s5}, t2.Servers)
38+
}
39+
40+
func TestTopology_DiffHostlist(t *testing.T) {
41+
h1 := "1.0.0.0:27017"
42+
h2 := "2.0.0.0:27017"
43+
h3 := "3.0.0.0:27017"
44+
h4 := "4.0.0.0:27017"
45+
h5 := "5.0.0.0:27017"
46+
h6 := "6.0.0.0:27017"
47+
s1 := Server{Addr: "1.0.0.0:27017"}
48+
s2 := Server{Addr: "2.0.0.0:27017"}
49+
s3 := Server{Addr: "3.0.0.0:27017"}
50+
s6 := Server{Addr: "6.0.0.0:27017"}
51+
52+
topo := Topology{
53+
Servers: []Server{s6, s1, s3, s2},
54+
}
55+
hostlist := []string{h2, h4, h3, h5}
56+
57+
diff := topo.DiffHostlist(hostlist)
58+
59+
assert.ElementsMatch(t, []string{h4, h5}, diff.Added)
60+
assert.ElementsMatch(t, []string{h1, h6}, diff.Removed)
61+
62+
// Ensure that original topology servers and hostlist were not reordered.
63+
assert.EqualValues(t, []Server{s6, s1, s3, s2}, topo.Servers)
64+
assert.EqualValues(t, []string{h2, h4, h3, h5}, hostlist)
65+
}

x/mongo/driver/topology/polling_srv_records_test.go

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -108,14 +108,17 @@ func compareHosts(t *testing.T, received []description.Server, expected []string
108108
t.Fatalf("Number of hosts in topology does not match expected value. Got %v; want %v.", len(received), len(expected))
109109
}
110110

111-
actual := serverSorter(received)
111+
// Take a copy of servers so we don't risk a data race similar to GODRIVER-1301.
112+
servers := make([]description.Server, len(received))
113+
copy(servers, received)
114+
actual := serverSorter(servers)
112115
sort.Sort(actual)
113116
sort.Strings(expected)
114117

115-
for i := range received {
116-
if received[i].Addr.String() != expected[i] {
118+
for i := range servers {
119+
if servers[i].Addr.String() != expected[i] {
117120
t.Errorf("Hosts in topology differ from expected values. Got %v; want %v.",
118-
received[i].Addr.String(), expected[i])
121+
servers[i].Addr.String(), expected[i])
119122
}
120123
}
121124
}

0 commit comments

Comments
 (0)