Commit 282fd57

[!] optimize PostgresWriter.flush() with custom CopyFrom iterator (#777)
Optimize `PostgresWriter.flush()` by eliminating the per-row JSON marshaling and the separate COPY issued for every row. `flush()` now sorts the batch by metric name, and the new `copyFromMeasurements` iterator streams each metric's measurements into a single PostgreSQL COPY, marshaling rows on the fly. This reduces memory allocations and temporary objects, making `flush()` faster with lower GC pressure.
1 parent e35caa6 commit 282fd57
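For context: pgx's `CopyFrom` accepts any value implementing its `CopyFromSource` interface, which is what lets `flush()` hand a custom iterator to COPY instead of materializing row slices. The contract, as defined by pgx v5 and reproduced here for reference, is just three methods:

```go
// CopyFromSource is the interface pgx.Conn.CopyFrom consumes
// (github.com/jackc/pgx/v5); copyFromMeasurements implements it.
type CopyFromSource interface {
	Next() bool             // advance to the next row; false ends the COPY
	Values() ([]any, error) // column values for the current row
	Err() error             // terminal error, inspected after Next returns false
}
```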

File tree

4 files changed, +458 -92 lines changed

internal/metrics/metrics.yaml

Lines changed: 2 additions & 0 deletions

```diff
@@ -334,6 +334,8 @@ metrics:
         coalesce(reset_val, '') as value
       from
         pg_settings
+      where
+        name <> 'connection_ID'
   cpu_load:
     sqls:
       11: |-
```

internal/reaper/reaper.go

Lines changed: 16 additions & 1 deletion

```diff
@@ -4,6 +4,7 @@ import (
 	"cmp"
 	"context"
 	"fmt"
+	"runtime"
 	"slices"
 	"strings"
 	"time"
@@ -36,7 +37,7 @@ type Reaper struct {
 func NewReaper(ctx context.Context, opts *cmdopts.Options) (r *Reaper) {
 	return &Reaper{
 		Options:          opts,
-		measurementCh:    make(chan metrics.MeasurementEnvelope, 10000),
+		measurementCh:    make(chan metrics.MeasurementEnvelope, 256),
 		measurementCache: NewInstanceMetricCache(),
 		logger:           log.GetLogger(ctx),
 		monitoredSources: make(sources.SourceConns, 0),
@@ -50,6 +51,17 @@ func (r *Reaper) Ready() bool {
 	return r.ready.Load()
 }
 
+func (r *Reaper) PrintMemStats() {
+	var m runtime.MemStats
+	runtime.ReadMemStats(&m)
+
+	bToKb := func(b uint64) uint64 {
+		return b / 1024
+	}
+	r.logger.Debugf("Alloc: %d Kb, TotalAlloc: %d Kb, Sys: %d Kb, NumGC: %d, HeapAlloc: %d Kb, HeapSys: %d Kb",
+		bToKb(m.Alloc), bToKb(m.TotalAlloc), bToKb(m.Sys), m.NumGC, bToKb(m.HeapAlloc), bToKb(m.HeapSys))
+}
+
 // Reap() starts the main monitoring loop. It is responsible for fetching metrics measurements
 // from the sources and storing them to the sinks. It also manages the lifecycle of
 // the metric gatherers. In case of a source or metric definition change, it will
@@ -64,6 +76,9 @@ func (r *Reaper) Reap(ctx context.Context) {
 	r.ready.Store(true)
 
 	for { //main loop
+		if r.Logging.LogLevel == "debug" {
+			r.PrintMemStats()
+		}
		if err = r.LoadSources(); err != nil {
 			logger.WithError(err).Error("could not refresh active sources, using last valid cache")
 		}
```
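A side note on shrinking `measurementCh` from 10000 to 256: a huge buffer mostly hides a slow sink by queueing work in memory, while a small one makes producers block sooner, keeping in-flight memory bounded. A minimal, hypothetical illustration of that backpressure mechanism (not pgwatch code):

```go
package main

import "fmt"

// With a small buffer the producer blocks as soon as the consumer lags,
// capping in-flight memory instead of queueing thousands of envelopes.
func main() {
	ch := make(chan int, 2) // cf. measurementCh shrinking from 10000 to 256
	done := make(chan struct{})
	go func() {
		for v := range ch {
			fmt.Println("stored", v) // the "sink" draining measurements
		}
		close(done)
	}()
	for i := 0; i < 5; i++ {
		ch <- i // blocks whenever the buffer is full: natural backpressure
	}
	close(ch)
	<-done
}
```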

internal/sinks/postgres.go

Lines changed: 103 additions & 91 deletions

```diff
@@ -3,13 +3,14 @@ package sinks
 import (
 	"context"
 	_ "embed"
-	"encoding/json"
 	"errors"
 	"fmt"
-	"maps"
+	"slices"
 	"strings"
 	"time"
 
+	jsoniter "github.com/json-iterator/go"
+
 	"github.com/cybertec-postgresql/pgwatch/v3/internal/db"
 	"github.com/cybertec-postgresql/pgwatch/v3/internal/log"
 	"github.com/cybertec-postgresql/pgwatch/v3/internal/metrics"
@@ -21,6 +22,7 @@ var (
 	cacheLimit      = 256
 	highLoadTimeout = time.Second * 5
 	deleterDelay    = time.Hour
+	targetColumns   = [...]string{"time", "dbname", "data", "tag_data"}
 )
 
 func NewPostgresWriter(ctx context.Context, connstr string, opts *CmdOpts) (pgw *PostgresWriter, err error) {
@@ -221,62 +223,105 @@ func (pgw *PostgresWriter) poll() {
 	}
 }
 
+func newCopyFromMeasurements(rows []metrics.MeasurementEnvelope) *copyFromMeasurements {
+	return &copyFromMeasurements{envelopes: rows, envelopeIdx: -1, measurementIdx: -1}
+}
+
+type copyFromMeasurements struct {
+	envelopes      []metrics.MeasurementEnvelope
+	envelopeIdx    int
+	measurementIdx int // index of the current measurement in the envelope
+	metricName     string
+}
+
+func (c *copyFromMeasurements) Next() bool {
+	for {
+		// Check if we need to advance to the next envelope
+		if c.envelopeIdx < 0 || c.measurementIdx+1 >= len(c.envelopes[c.envelopeIdx].Data) {
+			// Advance to next envelope
+			c.envelopeIdx++
+			if c.envelopeIdx >= len(c.envelopes) {
+				return false // No more envelopes
+			}
+			c.measurementIdx = -1 // Reset measurement index for new envelope
+
+			// Set metric name from first envelope, or detect metric boundary
+			if c.metricName == "" {
+				c.metricName = c.envelopes[c.envelopeIdx].MetricName
+			} else if c.metricName != c.envelopes[c.envelopeIdx].MetricName {
+				// We've hit a different metric - we're done with current metric
+				// Reset position to process this envelope on next call
+				c.envelopeIdx--
+				c.measurementIdx = len(c.envelopes[c.envelopeIdx].Data) // Set to length so we've "finished" this envelope
+				c.metricName = ""                                       // Reset for next metric
+				return false
+			}
+		}
+
+		// Advance to next measurement in current envelope
+		c.measurementIdx++
+		if c.measurementIdx < len(c.envelopes[c.envelopeIdx].Data) {
+			return true // Found valid measurement
+		}
+		// If we reach here, we've exhausted current envelope, loop will advance to next envelope
+	}
+}
+
+func (c *copyFromMeasurements) EOF() bool {
+	return c.envelopeIdx >= len(c.envelopes)
+}
+
+func (c *copyFromMeasurements) Values() ([]any, error) {
+	row := c.envelopes[c.envelopeIdx].Data[c.measurementIdx]
+	tagRow := c.envelopes[c.envelopeIdx].CustomTags
+	if tagRow == nil {
+		tagRow = make(map[string]string)
+	}
+	for k, v := range row {
+		if strings.HasPrefix(k, metrics.TagPrefix) {
+			tagRow[strings.TrimPrefix(k, metrics.TagPrefix)] = fmt.Sprintf("%v", v)
+			delete(row, k)
+		}
+	}
+	jsonTags, terr := jsoniter.ConfigFastest.MarshalToString(tagRow)
+	json, err := jsoniter.ConfigFastest.MarshalToString(row)
+	if err != nil || terr != nil {
+		return nil, errors.Join(err, terr)
+	}
+	return []any{time.Unix(0, c.envelopes[c.envelopeIdx].Data.GetEpoch()), c.envelopes[c.envelopeIdx].DBName, json, jsonTags}, nil
+}
+
+func (c *copyFromMeasurements) Err() error {
+	return nil
+}
+
+func (c *copyFromMeasurements) MetricName() pgx.Identifier {
+	return pgx.Identifier{c.envelopes[c.envelopeIdx+1].MetricName} // Metric name is taken from the next envelope
+}
+
 // flush sends the cached measurements to the database
 func (pgw *PostgresWriter) flush(msgs []metrics.MeasurementEnvelope) {
 	if len(msgs) == 0 {
 		return
 	}
 	logger := log.GetLogger(pgw.ctx)
-	metricsToStorePerMetric := make(map[string][]MeasurementMessagePostgres)
-	rowsBatched := 0
-	totalRows := 0
+	// metricsToStorePerMetric := make(map[string][]MeasurementMessagePostgres)
 	pgPartBounds := make(map[string]ExistingPartitionInfo)                  // metric=min/max
 	pgPartBoundsDbName := make(map[string]map[string]ExistingPartitionInfo) // metric=[dbname=min/max]
 	var err error
 
-	for _, msg := range msgs {
-		if len(msg.Data) == 0 {
-			continue
+	slices.SortFunc(msgs, func(a, b metrics.MeasurementEnvelope) int {
+		if a.MetricName < b.MetricName {
+			return -1
+		} else if a.MetricName > b.MetricName {
+			return 1
 		}
-		for _, dataRow := range msg.Data {
-			var epochTime time.Time
-
-			tags := make(map[string]string)
-			fields := make(map[string]any)
-
-			totalRows++
-
-			if msg.CustomTags != nil {
-				tags = maps.Clone(msg.CustomTags)
-			}
-			epochTime = time.Unix(0, metrics.Measurement(dataRow).GetEpoch())
-			for k, v := range dataRow {
-				if v == nil || v == "" || k == metrics.EpochColumnName {
-					continue // not storing NULLs
-				}
-				if strings.HasPrefix(k, metrics.TagPrefix) {
-					tag := k[4:]
-					tags[tag] = fmt.Sprintf("%v", v)
-				} else {
-					fields[k] = v
-				}
-			}
-
-			var metricsArr []MeasurementMessagePostgres
-			var ok bool
-
-			metricNameTemp := msg.MetricName
-
-			metricsArr, ok = metricsToStorePerMetric[metricNameTemp]
-			if !ok {
-				metricsToStorePerMetric[metricNameTemp] = make([]MeasurementMessagePostgres, 0)
-			}
-			metricsArr = append(metricsArr, MeasurementMessagePostgres{Time: epochTime, DBName: msg.DBName,
-				Metric: msg.MetricName, Data: fields, TagData: tags})
-			metricsToStorePerMetric[metricNameTemp] = metricsArr
-
-			rowsBatched++
+		return 0
+	})
 
+	for _, msg := range msgs {
+		for _, dataRow := range msg.Data {
+			epochTime := time.Unix(0, metrics.Measurement(dataRow).GetEpoch())
 			switch pgw.metricSchema {
 			case DbStorageSchemaTimescale:
 				// set min/max timestamps to check/create partitions
@@ -317,60 +362,27 @@ func (pgw *PostgresWriter) flush(msgs []metrics.MeasurementEnvelope) {
 	default:
 		logger.Fatal("unknown storage schema...")
 	}
-	if forceRecreatePartitions {
-		forceRecreatePartitions = false
-	}
+	forceRecreatePartitions = false
 	if err != nil {
 		pgw.lastError <- err
 	}
 
-	// send data to PG, with a separate COPY for all metrics
+	var rowsBatched, n int64
 	t1 := time.Now()
-
-	for metricName, metrics := range metricsToStorePerMetric {
-
-		getTargetTable := func() pgx.Identifier {
-			return pgx.Identifier{metricName}
-		}
-
-		getTargetColumns := func() []string {
-			return []string{"time", "dbname", "data", "tag_data"}
-		}
-
-		for _, m := range metrics {
-			l := logger.WithField("db", m.DBName).WithField("metric", m.Metric)
-			jsonBytes, err := json.Marshal(m.Data)
-			if err != nil {
-				logger.Errorf("Skipping 1 metric for [%s:%s] due to JSON conversion error: %s", m.DBName, m.Metric, err)
-				continue
-			}
-
-			getTagData := func() any {
-				if len(m.TagData) > 0 {
-					jsonBytesTags, err := json.Marshal(m.TagData)
-					if err != nil {
-						l.Error(err)
-						return nil
-					}
-					return string(jsonBytesTags)
-				}
-				return nil
+	cfm := newCopyFromMeasurements(msgs)
+	for !cfm.EOF() {
+		n, err = pgw.sinkDb.CopyFrom(context.Background(), cfm.MetricName(), targetColumns[:], cfm)
+		rowsBatched += n
+		if err != nil {
+			logger.Error(err)
+			if PgError, ok := err.(*pgconn.PgError); ok {
+				forceRecreatePartitions = PgError.Code == "23514"
 			}
-
-			rows := [][]any{{m.Time, m.DBName, string(jsonBytes), getTagData()}}
-
-			if _, err = pgw.sinkDb.CopyFrom(context.Background(), getTargetTable(), getTargetColumns(), pgx.CopyFromRows(rows)); err != nil {
-				l.Error(err)
-				if PgError, ok := err.(*pgconn.PgError); ok {
-					forceRecreatePartitions = PgError.Code == "23514"
-				}
-				if forceRecreatePartitions {
-					logger.Warning("Some metric partitions might have been removed, halting all metric storage. Trying to re-create all needed partitions on next run")
-				}
+			if forceRecreatePartitions {
+				logger.Warning("Some metric partitions might have been removed, halting all metric storage. Trying to re-create all needed partitions on next run")
 			}
 		}
 	}
-
 	diff := time.Since(t1)
 	if err == nil {
 		logger.WithField("rows", rowsBatched).WithField("elapsed", diff).Info("measurements written")
```
