Skip to content

Commit 46805e2

Browse files
[-] fix data race in Prometheus sink, fixes #1136 (#1140)
* Fixed data race in Collect() * Added test to detect race in prometheus * Fixed linter issues * define `PromMetricCache` type * use testutil.TestContext --------- Co-authored-by: Pavlo Golub <pavlo.golub@gmail.com>
1 parent 17f7632 commit 46805e2

File tree

2 files changed

+82
-7
lines changed

2 files changed

+82
-7
lines changed

internal/sinks/prometheus.go

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,10 @@ func (promw *PrometheusWriter) Write(msg metrics.MeasurementEnvelope) error {
108108
return nil
109109
}
110110

111+
type PromMetricCache = map[string]map[string]metrics.MeasurementEnvelope // [dbUnique][metric]lastly_fetched_data
112+
111113
// Async Prom cache
112-
var promAsyncMetricCache = make(map[string]map[string]metrics.MeasurementEnvelope) // [dbUnique][metric]lastly_fetched_data
114+
var promAsyncMetricCache = make(PromMetricCache)
113115
var promAsyncMetricCacheLock = sync.RWMutex{}
114116

115117
func (promw *PrometheusWriter) PromAsyncCacheAddMetricData(dbUnique, metric string, msgArr metrics.MeasurementEnvelope) { // cache structure: [dbUnique][metric]lastly_fetched_data
@@ -124,8 +126,7 @@ func (promw *PrometheusWriter) PromAsyncCacheInitIfRequired(dbUnique, _ string)
124126
promAsyncMetricCacheLock.Lock()
125127
defer promAsyncMetricCacheLock.Unlock()
126128
if _, ok := promAsyncMetricCache[dbUnique]; !ok {
127-
metricMap := make(map[string]metrics.MeasurementEnvelope)
128-
promAsyncMetricCache[dbUnique] = metricMap
129+
promAsyncMetricCache[dbUnique] = make(map[string]metrics.MeasurementEnvelope)
129130
}
130131
}
131132

@@ -159,15 +160,21 @@ func (promw *PrometheusWriter) Collect(ch chan<- prometheus.Metric) {
159160
promw.totalScrapes.Add(1)
160161
ch <- promw.totalScrapes
161162

163+
promAsyncMetricCacheLock.Lock()
162164
if len(promAsyncMetricCache) == 0 {
165+
promAsyncMetricCacheLock.Unlock()
163166
promw.logger.Warning("No dbs configured for monitoring. Check config")
164167
ch <- promw.totalScrapeFailures
165168
promw.lastScrapeErrors.Set(0)
166169
ch <- promw.lastScrapeErrors
167170
return
168171
}
172+
snapshot := promAsyncMetricCache
173+
promAsyncMetricCache = make(PromMetricCache, len(snapshot))
174+
promAsyncMetricCacheLock.Unlock()
175+
169176
t1 := time.Now()
170-
for dbname, metricsMessages := range promAsyncMetricCache {
177+
for _, metricsMessages := range snapshot {
171178
for metric, metricMessages := range metricsMessages {
172179
if metric == "change_events" {
173180
continue // not supported
@@ -178,9 +185,6 @@ func (promw *PrometheusWriter) Collect(ch chan<- prometheus.Metric) {
178185
ch <- pm
179186
}
180187
}
181-
promAsyncMetricCacheLock.Lock()
182-
promAsyncMetricCache[dbname] = make(map[string]metrics.MeasurementEnvelope) // clear the cache for this db after metrics are collected
183-
promAsyncMetricCacheLock.Unlock()
184188
}
185189
promw.logger.WithField("count", rows).WithField("elapsed", time.Since(t1)).Info("measurements written")
186190
ch <- promw.totalScrapeFailures
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
package sinks
2+
3+
import (
4+
"sync"
5+
"testing"
6+
"time"
7+
8+
"github.com/cybertec-postgresql/pgwatch/v5/internal/metrics"
9+
"github.com/cybertec-postgresql/pgwatch/v5/internal/testutil"
10+
"github.com/prometheus/client_golang/prometheus"
11+
)
12+
13+
func TestCollect_RaceCondition_Real(_ *testing.T) {
14+
// 1. Initialize the real PrometheusWriter
15+
// Note: In the current buggy code, this shares the global 'promAsyncMetricCache'
16+
promw, _ := NewPrometheusWriter(testutil.TestContext, "127.0.0.1:0/pgwatch")
17+
18+
// 2. Register a metric so Write() actually puts data into the map
19+
_ = promw.SyncMetric("race_db", "test_metric", AddOp)
20+
21+
var wg sync.WaitGroup
22+
done := make(chan struct{})
23+
24+
// --- The Writer (Simulating Database Updates) ---
25+
wg.Go(func() {
26+
for {
27+
select {
28+
case <-done:
29+
return
30+
default:
31+
// Call the REAL Write method
32+
_ = promw.Write(metrics.MeasurementEnvelope{
33+
DBName: "race_db",
34+
MetricName: "test_metric",
35+
Data: metrics.Measurements{
36+
{
37+
metrics.EpochColumnName: time.Now().UnixNano(),
38+
"value": int64(100),
39+
},
40+
},
41+
})
42+
// No sleep here -> hammer the map as fast as possible
43+
}
44+
}
45+
})
46+
47+
// --- The Collector (Simulating Prometheus Scrapes) ---
48+
wg.Go(func() {
49+
// Prometheus provides a channel to receive metrics
50+
ch := make(chan prometheus.Metric, 10000)
51+
52+
// Scrape 50 times (more than enough to trigger a race in a tight loop)
53+
for range 50 {
54+
// Call the REAL Collect method
55+
promw.Collect(ch)
56+
57+
// Drain the channel so it doesn't block
58+
drainLoop:
59+
for {
60+
select {
61+
case <-ch:
62+
default:
63+
break drainLoop
64+
}
65+
}
66+
}
67+
close(done) // Tell the writer to stop
68+
})
69+
70+
wg.Wait()
71+
}

0 commit comments

Comments
 (0)