diff --git a/x/mongo/driver/topology/rtt_monitor.go b/x/mongo/driver/topology/rtt_monitor.go index 88856b6b50..a55b70e9fb 100644 --- a/x/mongo/driver/topology/rtt_monitor.go +++ b/x/mongo/driver/topology/rtt_monitor.go @@ -41,11 +41,14 @@ type rttMonitor struct { // connMu guards connecting and disconnecting. This is necessary since // disconnecting will await the cancellation of a started connection. The // use case for rttMonitor.connect needs to be goroutine safe. - connMu sync.Mutex - averageRTT time.Duration - averageRTTSet bool - movingMin *list.List - minRTT time.Duration + connMu sync.Mutex + averageRTT time.Duration + averageRTTSet bool + movingMin *list.List + minRTT time.Duration + stddevRTT time.Duration + stddevSum float64 + callsToAppendMovingMin int closeWg sync.WaitGroup cfg *rttConfig @@ -179,8 +182,8 @@ func (r *rttMonitor) runHellos(conn *connection) { } } -// reset sets the average and min RTT to 0. This should only be called from the server monitor when an error -// occurs during a server check. Errors in the RTT monitor should not reset the RTTs. +// reset sets the average, min, and stddev RTT to 0. This should only be called from the server monitor +// when an error occurs during a server check. Errors in the RTT monitor should not reset the RTTs. func (r *rttMonitor) reset() { r.mu.Lock() defer r.mu.Unlock() @@ -188,11 +191,15 @@ func (r *rttMonitor) reset() { r.movingMin = list.New() r.averageRTT = 0 r.averageRTTSet = false + r.stddevSum = 0 + r.callsToAppendMovingMin = 0 } // appendMovingMin will append the RTT to the movingMin list which tracks a // minimum RTT within the last "minRTTSamplesForMovingMin" RTT samples. func (r *rttMonitor) appendMovingMin(rtt time.Duration) { + r.callsToAppendMovingMin++ + if r.movingMin == nil || rtt < 0 { return } @@ -202,6 +209,12 @@ func (r *rttMonitor) appendMovingMin(rtt time.Duration) { } r.movingMin.PushBack(rtt) + + // Collect a sum of stddevs over maxRTTSamplesForMovingMin calls, ignore if calls are less than max + if r.callsToAppendMovingMin >= maxRTTSamplesForMovingMin { + stddev := standardDeviationList(r.movingMin) + r.stddevSum += stddev + } } // min will return the minimum value in the movingMin list. @@ -222,6 +235,21 @@ func (r *rttMonitor) min() time.Duration { return min } +// stddev will return the current moving stddev. +func (r *rttMonitor) stddev() time.Duration { + var stddev time.Duration + + if r.callsToAppendMovingMin < maxRTTSamplesForMovingMin { + return 0 + } + + // Get the number of times stddev was updated and calculate the average stddev + frequency := (r.callsToAppendMovingMin + 1) - maxRTTSamplesForMovingMin + stddev = time.Duration(r.stddevSum / float64(frequency)) + + return stddev +} + func (r *rttMonitor) addSample(rtt time.Duration) { // Lock for the duration of this method. We're doing compuationally inexpensive work very infrequently, so lock // contention isn't expected. @@ -230,6 +258,7 @@ func (r *rttMonitor) addSample(rtt time.Duration) { r.appendMovingMin(rtt) r.minRTT = r.min() + r.stddevRTT = r.stddev() if !r.averageRTTSet { r.averageRTT = rtt @@ -262,7 +291,8 @@ func (r *rttMonitor) Stats() string { defer r.mu.RUnlock() return fmt.Sprintf( - "network round-trip time stats: moving avg: %v, min: %v", + "network round-trip time stats: moving avg: %v, min: %v, moving stddev: %v", r.averageRTT, - r.minRTT) + r.minRTT, + r.stddevRTT) } diff --git a/x/mongo/driver/topology/rtt_monitor_test.go b/x/mongo/driver/topology/rtt_monitor_test.go index c53535a8f2..086a5a489c 100644 --- a/x/mongo/driver/topology/rtt_monitor_test.go +++ b/x/mongo/driver/topology/rtt_monitor_test.go @@ -416,3 +416,71 @@ func TestRTTMonitor_min(t *testing.T) { }) } } + +func TestRTTMonitor_stddev(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + samples []time.Duration + want float64 + }{ + { + name: "empty", + samples: []time.Duration{}, + want: 0, + }, + { + name: "one", + samples: makeArithmeticSamples(1, 1), + want: 0, + }, + { + name: "below maxRTTSamples", + samples: makeArithmeticSamples(1, 5), + want: 0, + }, + { + name: "equal maxRTTSamples", + samples: makeArithmeticSamples(1, 10), + want: 2.872281e+06, + }, + { + name: "exceed maxRTTSamples", + samples: makeArithmeticSamples(1, 15), + want: 2.872281e+06, + }, + { + name: "non-sequential", + samples: []time.Duration{ + 2 * time.Millisecond, + 1 * time.Millisecond, + 4 * time.Millisecond, + 3 * time.Millisecond, + 7 * time.Millisecond, + 12 * time.Millisecond, + 6 * time.Millisecond, + 8 * time.Millisecond, + 5 * time.Millisecond, + 13 * time.Millisecond, + }, + want: 3.806573e+06, + }, + } + + for _, test := range tests { + test := test // capture the range variable + + t.Run(test.name, func(t *testing.T) { + t.Parallel() + + rtt := &rttMonitor{ + movingMin: list.New(), + } + for _, sample := range test.samples { + rtt.appendMovingMin(sample) + } + assert.Equal(t, test.want, float64(rtt.stddev())) + }) + } +} diff --git a/x/mongo/driver/topology/stats.go b/x/mongo/driver/topology/stats.go new file mode 100644 index 0000000000..62062de562 --- /dev/null +++ b/x/mongo/driver/topology/stats.go @@ -0,0 +1,33 @@ +// Copyright (C) MongoDB, Inc. 2024-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package topology + +import ( + "container/list" + "math" + "time" +) + +func standardDeviationList(l *list.List) float64 { + if l.Len() == 0 { + return 0 + } + + var mean, variance float64 + count := 0.0 + + for el := l.Front(); el != nil; el = el.Next() { + count++ + sample := float64(el.Value.(time.Duration)) + + delta := sample - mean + mean += delta / count + variance += delta * (sample - mean) + } + + return math.Sqrt(variance / count) +} diff --git a/x/mongo/driver/topology/stats_test.go b/x/mongo/driver/topology/stats_test.go new file mode 100644 index 0000000000..e458cfbfcc --- /dev/null +++ b/x/mongo/driver/topology/stats_test.go @@ -0,0 +1,51 @@ +// Copyright (C) MongoDB, Inc. 2022-present. +// +// Licensed under the Apache License, Version 2.0 (the "License"); you may +// not use this file except in compliance with the License. You may obtain +// a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 + +package topology + +import ( + "container/list" + "testing" + "time" + + "go.mongodb.org/mongo-driver/v2/internal/assert" +) + +func TestStandardDeviationList_Duration(t *testing.T) { + tests := []struct { + name string + data []time.Duration + want float64 + }{ + { + name: "empty", + data: []time.Duration{}, + want: 0, + }, + { + name: "multiple", + data: []time.Duration{ + time.Millisecond, + 2 * time.Millisecond, + time.Microsecond, + }, + want: 816088.36667497, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + l := list.New() + for _, d := range test.data { + l.PushBack(d) + } + + got := standardDeviationList(l) + + assert.InDelta(t, test.want, got, 1e-6) + }) + } +}