Skip to content

Commit acbd1d0

Browse files
Make eventSequencer thread safe
1 parent 95e4d63 commit acbd1d0

File tree

4 files changed

+81
-51
lines changed

4 files changed

+81
-51
lines changed

internal/integration/unified/client_entity.go

Lines changed: 61 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,14 @@ var securitySensitiveCommands = []string{
4848
// recorded. After pool initialization completes, we set eventCutoffSeq to the
4949
// current sequence number. Event accessors for CMAP and SDAM types then
5050
// filter out any events with sequence <= eventCutoffSeq.
51+
//
52+
// Sequencing is thread-safe to support concurrent operations that may generate
53+
// events (e.g., connection checkouts generating CMAP events).
5154
type eventSequencer struct {
5255
counter atomic.Int64
53-
cutoff int64
56+
cutoff atomic.Int64
57+
58+
mu sync.RWMutex
5459

5560
// pool events are heterogeneous, so we track their sequence separately
5661
poolSeq []int64
@@ -59,27 +64,42 @@ type eventSequencer struct {
5964

6065
// setCutoff marks the current sequence as the filtering cutoff point.
6166
func (es *eventSequencer) setCutoff() {
62-
es.cutoff = es.counter.Load()
67+
es.cutoff.Store(es.counter.Load())
6368
}
6469

6570
// recordEvent stores the sequence number for a given event type.
6671
func (es *eventSequencer) recordEvent(eventType monitoringEventType) {
6772
next := es.counter.Add(1)
73+
74+
es.mu.Lock()
6875
es.seqByEventType[eventType] = append(es.seqByEventType[eventType], next)
76+
es.mu.Unlock()
6977
}
7078

7179
func (es *eventSequencer) recordPooledEvent() {
7280
next := es.counter.Add(1)
81+
82+
es.mu.Lock()
7383
es.poolSeq = append(es.poolSeq, next)
84+
es.mu.Unlock()
7485
}
7586

7687
// shouldFilter returns true if the event at the given index should be filtered.
7788
func (es *eventSequencer) shouldFilter(eventType monitoringEventType, index int) bool {
78-
if es.cutoff == 0 {
89+
cutoff := es.cutoff.Load()
90+
if cutoff == 0 {
7991
return false
8092
}
8193

82-
return es.seqByEventType[eventType][index] <= es.cutoff
94+
es.mu.RLock()
95+
defer es.mu.RUnlock()
96+
97+
seqs, ok := es.seqByEventType[eventType]
98+
if !ok || index < 0 || index >= len(seqs) {
99+
return false
100+
}
101+
102+
return seqs[index] <= cutoff
83103
}
84104

85105
// clientEntity is a wrapper for a mongo.Client object that also holds additional information required during test
@@ -352,17 +372,43 @@ func (c *clientEntity) failedEvents() []*event.CommandFailedEvent {
352372
return events
353373
}
354374

355-
// filterEventsBySeq filters events by sequence number using the provided
356-
// sequence slice. See comments on eventSequencer for more details.
357-
func filterEventsBySeq[T any](c *clientEntity, events []T, seqSlice []int64) []T {
358-
if c.eventSequencer.cutoff == 0 {
375+
// filterEventsBySeq filters events by sequence number for the given eventType.
376+
// See comments on eventSequencer for more details.
377+
func filterEventsBySeq[T any](c *clientEntity, events []T, eventType monitoringEventType) []T {
378+
cutoff := c.eventSequencer.cutoff.Load()
379+
if cutoff == 0 {
359380
return events
360381
}
361382

362-
var filtered []T
363-
for i, evt := range events {
364-
if seqSlice[i] > c.eventSequencer.cutoff {
365-
filtered = append(filtered, evt)
383+
// Lock order: eventProcessMu -> eventSequencer.mu (matches writers)
384+
c.eventProcessMu.RLock()
385+
c.eventSequencer.mu.RLock()
386+
387+
// Snapshot to minimize time under locks and avoid races
388+
localEvents := append([]T(nil), events...)
389+
390+
var seqSlice []int64
391+
if eventType == poolAnyEvent {
392+
seqSlice = c.eventSequencer.poolSeq
393+
} else {
394+
seqSlice = c.eventSequencer.seqByEventType[eventType]
395+
}
396+
397+
localSeqs := append([]int64(nil), seqSlice...)
398+
399+
c.eventSequencer.mu.RUnlock()
400+
c.eventProcessMu.RUnlock()
401+
402+
// guard against index out of range.
403+
n := len(localEvents)
404+
if len(localSeqs) < n {
405+
n = len(localSeqs)
406+
}
407+
408+
filtered := make([]T, 0, n)
409+
for i := 0; i < n; i++ {
410+
if localSeqs[i] > cutoff {
411+
filtered = append(filtered, localEvents[i])
366412
}
367413
}
368414

@@ -555,8 +601,10 @@ func (c *clientEntity) processPoolEvent(evt *event.PoolEvent) {
555601

556602
eventType := monitoringEventTypeFromPoolEvent(evt)
557603
if _, ok := c.observedEvents[eventType]; ok {
604+
c.eventProcessMu.Lock()
558605
c.pooled = append(c.pooled, evt)
559606
c.eventSequencer.recordPooledEvent()
607+
c.eventProcessMu.Unlock()
560608
}
561609

562610
c.addEventsCount(eventType)
@@ -787,9 +835,7 @@ func awaitMinimumPoolSize(ctx context.Context, entity *clientEntity, minPoolSize
787835
case <-awaitCtx.Done():
788836
return fmt.Errorf("timed out waiting for client to reach minPoolSize")
789837
case <-ticker.C:
790-
if uint64(entity.eventsCount[connectionReadyEvent]) >= minPoolSize {
791-
// Clear all CMAP and SDAM events that occurred during pool
792-
// initialization.
838+
if uint64(entity.getEventCount(connectionReadyEvent)) >= minPoolSize {
793839
entity.eventSequencer.setCutoff()
794840

795841
return nil

internal/integration/unified/client_entity_test.go

Lines changed: 8 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,6 @@ func recordTopologyOpening(c *clientEntity) {
3030
c.eventSequencer.recordEvent(topologyOpeningEvent)
3131
}
3232

33-
func recordHeartbeatSucceeded(c *clientEntity) {
34-
c.serverHeartbeatSucceeded = append(c.serverHeartbeatSucceeded, &event.ServerHeartbeatSucceededEvent{})
35-
c.eventSequencer.recordEvent(serverHeartbeatSucceededEvent)
36-
}
37-
3833
func Test_eventSequencer(t *testing.T) {
3934
tests := []struct {
4035
name string
@@ -91,19 +86,6 @@ func Test_eventSequencer(t *testing.T) {
9186
topologyOpeningEvent: 1,
9287
},
9388
},
94-
{
95-
name: "cutoff at beginning filters nothing",
96-
cutoffAfter: 0,
97-
setupEvents: func(c *clientEntity) {
98-
// Cutoff will be set immediately (before any events)
99-
recordPoolEvent(c)
100-
recordHeartbeatSucceeded(c)
101-
},
102-
expectedPooled: 1,
103-
expectedSDAM: map[monitoringEventType]int{
104-
serverHeartbeatSucceededEvent: 1,
105-
},
106-
},
10789
{
10890
name: "cutoff after all events filters everything",
10991
cutoffAfter: 3,
@@ -135,25 +117,24 @@ func Test_eventSequencer(t *testing.T) {
135117
// Set cutoff if specified
136118
if tt.cutoffAfter > 0 {
137119
// Manually set cutoff to the specified event sequence
138-
client.eventSequencer.cutoff = int64(tt.cutoffAfter)
120+
client.eventSequencer.cutoff.Store(int64(tt.cutoffAfter))
139121
}
140122

141123
// Test pool event filtering
142-
filteredPool := filterEventsBySeq(client, client.pooled, client.eventSequencer.poolSeq)
124+
filteredPool := filterEventsBySeq(client, client.pooled, poolAnyEvent)
143125
assert.Equal(t, tt.expectedPooled, len(filteredPool), "pool events count mismatch")
144126

145127
// Test SDAM event filtering
146128
for eventType, expectedCount := range tt.expectedSDAM {
147129
var actualCount int
148-
seqs := client.eventSequencer.seqByEventType[eventType]
149130

150131
switch eventType {
151132
case serverDescriptionChangedEvent:
152-
actualCount = len(filterEventsBySeq(client, client.serverDescriptionChanged, seqs))
133+
actualCount = len(filterEventsBySeq(client, client.serverDescriptionChanged, serverDescriptionChangedEvent))
153134
case serverHeartbeatSucceededEvent:
154-
actualCount = len(filterEventsBySeq(client, client.serverHeartbeatSucceeded, seqs))
135+
actualCount = len(filterEventsBySeq(client, client.serverHeartbeatSucceeded, serverHeartbeatSucceededEvent))
155136
case topologyOpeningEvent:
156-
actualCount = len(filterEventsBySeq(client, client.topologyOpening, seqs))
137+
actualCount = len(filterEventsBySeq(client, client.topologyOpening, topologyOpeningEvent))
157138
}
158139

159140
assert.Equal(t, expectedCount, actualCount, "%s count mismatch", eventType)
@@ -180,14 +161,14 @@ func Test_eventSequencer_setCutoff(t *testing.T) {
180161
client.eventSequencer.setCutoff()
181162

182163
// Verify cutoff matches counter
183-
assert.Equal(t, int64(2), client.eventSequencer.cutoff, "cutoff should be 2")
164+
assert.Equal(t, int64(2), client.eventSequencer.cutoff.Load(), "cutoff should be 2")
184165

185166
// Record more events
186167
recordPoolEvent(client)
187168

188169
// Verify counter incremented but cutoff didn't
189170
assert.Equal(t, int64(3), client.eventSequencer.counter.Load(), "counter should be 3")
190-
assert.Equal(t, int64(2), client.eventSequencer.cutoff, "cutoff should still be 2")
171+
assert.Equal(t, int64(2), client.eventSequencer.cutoff.Load(), "cutoff should still be 2")
191172
}
192173

193174
func Test_eventSequencer_shouldFilter(t *testing.T) {
@@ -245,7 +226,7 @@ func Test_eventSequencer_shouldFilter(t *testing.T) {
245226

246227
for _, tt := range tests {
247228
t.Run(tt.name, func(t *testing.T) {
248-
es.cutoff = tt.cutoff
229+
es.cutoff.Store(tt.cutoff)
249230
result := es.shouldFilter(tt.eventType, tt.index)
250231
assert.Equal(t, tt.expected, result, "shouldFilter result mismatch")
251232
})

internal/integration/unified/event.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ const (
3737
topologyDescriptionChangedEvent monitoringEventType = "TopologyDescriptionChangedEvent"
3838
topologyOpeningEvent monitoringEventType = "TopologyOpeningEvent"
3939
topologyClosedEvent monitoringEventType = "TopologyClosedEvent"
40+
41+
// sentinel: indicates "use pooled (CMAP) sequence".
42+
poolAnyEvent monitoringEventType = "_PoolAny"
4043
)
4144

4245
func monitoringEventTypeFromString(eventStr string) (monitoringEventType, bool) {

internal/integration/unified/event_verification.go

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ func verifyCommandEvents(ctx context.Context, client *clientEntity, expectedEven
312312
}
313313

314314
func verifyCMAPEvents(client *clientEntity, expectedEvents *expectedEvents) error {
315-
pooled := filterEventsBySeq(client, client.pooled, client.eventSequencer.poolSeq)
315+
pooled := filterEventsBySeq(client, client.pooled, poolAnyEvent)
316316
if len(expectedEvents.CMAPEvents) == 0 && len(pooled) != 0 {
317317
return fmt.Errorf("expected no cmap events to be sent but got %s", stringifyEventsForClient(client))
318318
}
@@ -443,7 +443,7 @@ func stringifyEventsForClient(client *clientEntity) string {
443443
}
444444

445445
str.WriteString("\nPool Events\n\n")
446-
for _, evt := range filterEventsBySeq(client, client.pooled, client.eventSequencer.poolSeq) {
446+
for _, evt := range filterEventsBySeq(client, client.pooled, poolAnyEvent) {
447447
str.WriteString(fmt.Sprintf("[%s] Event Type: %q\n", evt.Address, evt.Type))
448448
}
449449

@@ -522,13 +522,13 @@ func getNextTopologyClosedEvent(
522522

523523
func verifySDAMEvents(client *clientEntity, expectedEvents *expectedEvents) error {
524524
var (
525-
changed = filterEventsBySeq(client, client.serverDescriptionChanged, client.eventSequencer.seqByEventType[serverDescriptionChangedEvent])
526-
started = filterEventsBySeq(client, client.serverHeartbeatStartedEvent, client.eventSequencer.seqByEventType[serverHeartbeatStartedEvent])
527-
succeeded = filterEventsBySeq(client, client.serverHeartbeatSucceeded, client.eventSequencer.seqByEventType[serverHeartbeatSucceededEvent])
528-
failed = filterEventsBySeq(client, client.serverHeartbeatFailedEvent, client.eventSequencer.seqByEventType[serverHeartbeatFailedEvent])
529-
tchanged = filterEventsBySeq(client, client.topologyDescriptionChanged, client.eventSequencer.seqByEventType[topologyDescriptionChangedEvent])
530-
topening = filterEventsBySeq(client, client.topologyOpening, client.eventSequencer.seqByEventType[topologyOpeningEvent])
531-
tclosed = filterEventsBySeq(client, client.topologyClosed, client.eventSequencer.seqByEventType[topologyClosedEvent])
525+
changed = filterEventsBySeq(client, client.serverDescriptionChanged, serverDescriptionChangedEvent)
526+
started = filterEventsBySeq(client, client.serverHeartbeatStartedEvent, serverHeartbeatStartedEvent)
527+
succeeded = filterEventsBySeq(client, client.serverHeartbeatSucceeded, serverHeartbeatSucceededEvent)
528+
failed = filterEventsBySeq(client, client.serverHeartbeatFailedEvent, serverHeartbeatFailedEvent)
529+
tchanged = filterEventsBySeq(client, client.topologyDescriptionChanged, topologyDescriptionChangedEvent)
530+
topening = filterEventsBySeq(client, client.topologyOpening, topologyOpeningEvent)
531+
tclosed = filterEventsBySeq(client, client.topologyClosed, topologyClosedEvent)
532532
)
533533

534534
vol := func() int {

0 commit comments

Comments
 (0)