From 3629d5c09246593c4db35973a2541f1c998e322c Mon Sep 17 00:00:00 2001 From: pipiland2612 Date: Wed, 8 Oct 2025 15:10:51 +0300 Subject: [PATCH] implement Signed-off-by: pipiland2612 --- exp/api/remote/remote_api.go | 124 ++++++++++++++++---------- exp/api/remote/remote_api_test.go | 143 ++++++++++++++++++++++++++++++ 2 files changed, 221 insertions(+), 46 deletions(-) diff --git a/exp/api/remote/remote_api.go b/exp/api/remote/remote_api.go index abe533db1..e3bfd0a29 100644 --- a/exp/api/remote/remote_api.go +++ b/exp/api/remote/remote_api.go @@ -52,7 +52,12 @@ type APIOption func(o *apiOpts) error // err is the error that caused the retry. type RetryCallback func(err error) -// TODO(bwplotka): Add "too old sample" handling one day. +// MessageFilter is a function that filters or modifies the message before each write attempt. +// It receives the attempt number (0 = first attempt, 1+ = retries) and the message to be sent. +// It returns a potentially modified message, or an error if the message should not be sent. +// This can be used for age-based filtering, deduplication, or other application-level logic. +type MessageFilter func(attempt int, msg any) (filtered any, err error) + type apiOpts struct { logger *slog.Logger client *http.Client @@ -169,6 +174,7 @@ type WriteOption func(o *writeOpts) type writeOpts struct { retryCallback RetryCallback + filterFunc MessageFilter } // WithWriteRetryCallback sets a retry callback for this Write request. @@ -179,6 +185,16 @@ func WithWriteRetryCallback(callback RetryCallback) WriteOption { } } +// WithWriteFilter sets a filter function for this Write request. +// The filter is invoked before each write attempt (including the initial attempt). +// This allows filtering out old samples, deduplication, or other application-level logic. +// If the filter returns an error, the Write operation will stop and return that error. +func WithWriteFilter(filter MessageFilter) WriteOption { + return func(o *writeOpts) { + o.filterFunc = filter + } +} + type vtProtoEnabled interface { SizeVT() int MarshalToSizedBufferVT(dAtA []byte) (int, error) @@ -205,63 +221,79 @@ func (r *API) Write(ctx context.Context, msgType WriteMessageType, msg any, opts opt(&writeOpts) } - buf := r.bufPool.Get().(*[]byte) - if err := msgType.Validate(); err != nil { return WriteResponseStats{}, err } - // Encode the payload. - switch m := msg.(type) { - case vtProtoEnabled: - // Use optimized vtprotobuf if supported. - size := m.SizeVT() - if cap(*buf) < size { - *buf = make([]byte, size) - } else { - *buf = (*buf)[:size] - } + // Since we retry writes we need to track the total amount of accepted data + // across the various attempts. + accumulatedStats := WriteResponseStats{} - if _, err := m.MarshalToSizedBufferVT(*buf); err != nil { - return WriteResponseStats{}, fmt.Errorf("encoding request %w", err) - } - case gogoProtoEnabled: - // Gogo proto if supported. - size := m.Size() - if cap(*buf) < size { - *buf = make([]byte, size) - } else { - *buf = (*buf)[:size] + b := backoff.New(ctx, r.opts.backoff) + for { + // Apply filter if provided. + currentMsg := msg + if writeOpts.filterFunc != nil { + filteredMsg, err := writeOpts.filterFunc(b.NumRetries(), msg) + if err != nil { + // Filter returned error, likely no data left to send. + return accumulatedStats, err + } + currentMsg = filteredMsg } - if _, err := m.MarshalToSizedBuffer(*buf); err != nil { - return WriteResponseStats{}, fmt.Errorf("encoding request %w", err) + // Encode the payload. + buf := r.bufPool.Get().(*[]byte) + switch m := currentMsg.(type) { + case vtProtoEnabled: + // Use optimized vtprotobuf if supported. + size := m.SizeVT() + if cap(*buf) < size { + *buf = make([]byte, size) + } else { + *buf = (*buf)[:size] + } + + if _, err := m.MarshalToSizedBufferVT(*buf); err != nil { + r.bufPool.Put(buf) + return WriteResponseStats{}, fmt.Errorf("encoding request %w", err) + } + case gogoProtoEnabled: + // Gogo proto if supported. + size := m.Size() + if cap(*buf) < size { + *buf = make([]byte, size) + } else { + *buf = (*buf)[:size] + } + + if _, err := m.MarshalToSizedBuffer(*buf); err != nil { + r.bufPool.Put(buf) + return WriteResponseStats{}, fmt.Errorf("encoding request %w", err) + } + case proto.Message: + // Generic proto. + *buf, err = (proto.MarshalOptions{}).MarshalAppend(*buf, m) + if err != nil { + r.bufPool.Put(buf) + return WriteResponseStats{}, fmt.Errorf("encoding request %w", err) + } + default: + r.bufPool.Put(buf) + return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m) } - case proto.Message: - // Generic proto. - *buf, err = (proto.MarshalOptions{}).MarshalAppend(*buf, m) + + comprBuf := r.bufPool.Get().(*[]byte) + payload, err := compressPayload(comprBuf, r.opts.compression, *buf) if err != nil { - return WriteResponseStats{}, fmt.Errorf("encoding request %w", err) + r.bufPool.Put(buf) + r.bufPool.Put(comprBuf) + return WriteResponseStats{}, fmt.Errorf("compressing %w", err) } - default: - return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m) - } + r.bufPool.Put(buf) - comprBuf := r.bufPool.Get().(*[]byte) - payload, err := compressPayload(comprBuf, r.opts.compression, *buf) - if err != nil { - return WriteResponseStats{}, fmt.Errorf("compressing %w", err) - } - r.bufPool.Put(buf) - defer r.bufPool.Put(comprBuf) - - // Since we retry writes we need to track the total amount of accepted data - // across the various attempts. - accumulatedStats := WriteResponseStats{} - - b := backoff.New(ctx, r.opts.backoff) - for { rs, err := r.attemptWrite(ctx, r.opts.compression, msgType, payload, b.NumRetries()) + r.bufPool.Put(comprBuf) accumulatedStats.Add(rs) if err == nil { // Check the case mentioned in PRW 2.0. diff --git a/exp/api/remote/remote_api_test.go b/exp/api/remote/remote_api_test.go index f39913308..14dbbb544 100644 --- a/exp/api/remote/remote_api_test.go +++ b/exp/api/remote/remote_api_test.go @@ -295,4 +295,147 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) { t.Fatal("retry callback should not be invoked on successful request") } }) + + t.Run("filter invoked on each attempt", func(t *testing.T) { + tLogger := slog.Default() + mockCode := http.StatusInternalServerError + mStore := &mockStorage{ + mockErr: errors.New("storage error"), + mockCode: &mockCode, + } + srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger))) + t.Cleanup(srv.Close) + + var filterInvocations []int + client, err := NewAPI(srv.URL, + WithAPIHTTPClient(srv.Client()), + WithAPILogger(tLogger), + WithAPIPath("api/v1/write"), + WithAPIBackoff(backoff.Config{ + Min: 1 * time.Millisecond, + Max: 1 * time.Millisecond, + MaxRetries: 2, + }), + ) + if err != nil { + t.Fatal(err) + } + + req := testV2() + _, err = client.Write(context.Background(), WriteV2MessageType, req, + WithWriteFilter(func(attempt int, msg any) (any, error) { + filterInvocations = append(filterInvocations, attempt) + return msg, nil + }), + ) + if err == nil { + t.Fatal("expected error, got nil") + } + + // Filter should be invoked for initial attempt (0) and 2 retries (1, 2). + expectedInvocations := []int{0, 1, 2} + if diff := cmp.Diff(expectedInvocations, filterInvocations); diff != "" { + t.Fatalf("unexpected filter invocations (-want +got):\n%s", diff) + } + }) + + t.Run("filter can modify message on retries", func(t *testing.T) { + tLogger := slog.Default() + mStore := &mockStorage{} + srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger))) + t.Cleanup(srv.Close) + + client, err := NewAPI(srv.URL, + WithAPIHTTPClient(srv.Client()), + WithAPILogger(tLogger), + WithAPIPath("api/v1/write"), + ) + if err != nil { + t.Fatal(err) + } + + req := testV2() + originalTimeseriesCount := len(req.Timeseries) + + _, err = client.Write(context.Background(), WriteV2MessageType, req, + WithWriteFilter(func(attempt int, msg any) (any, error) { + r, ok := msg.(*writev2.Request) + if !ok { + t.Fatal("expected *writev2.Request") + } + + // On retries (attempt > 0), filter out the first timeseries. + if attempt > 0 { + filtered := &writev2.Request{ + Timeseries: r.Timeseries[1:], + Symbols: r.Symbols, + } + return filtered, nil + } + return msg, nil + }), + ) + if err != nil { + t.Fatal(err) + } + + // Verify original message was sent on first attempt. + if len(mStore.v2Reqs) != 1 { + t.Fatalf("expected 1 request stored, got %d", len(mStore.v2Reqs)) + } + if len(mStore.v2Reqs[0].Timeseries) != originalTimeseriesCount { + t.Fatalf("expected %d timeseries in stored request, got %d", + originalTimeseriesCount, len(mStore.v2Reqs[0].Timeseries)) + } + }) + + t.Run("filter error stops retries", func(t *testing.T) { + tLogger := slog.Default() + mockCode := http.StatusInternalServerError + mStore := &mockStorage{ + mockErr: errors.New("storage error"), + mockCode: &mockCode, + } + srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger))) + t.Cleanup(srv.Close) + + var attemptCount int + client, err := NewAPI(srv.URL, + WithAPIHTTPClient(srv.Client()), + WithAPILogger(tLogger), + WithAPIPath("api/v1/write"), + WithAPIBackoff(backoff.Config{ + Min: 1 * time.Millisecond, + Max: 1 * time.Millisecond, + MaxRetries: 5, + }), + ) + if err != nil { + t.Fatal(err) + } + + req := testV2() + _, err = client.Write(context.Background(), WriteV2MessageType, req, + WithWriteFilter(func(attempt int, msg any) (any, error) { + attemptCount++ + // Return error on second retry (attempt 2). + if attempt >= 2 { + return nil, errors.New("filter rejected message") + } + return msg, nil + }), + ) + + if err == nil { + t.Fatal("expected error, got nil") + } + if !strings.Contains(err.Error(), "filter rejected message") { + t.Fatalf("expected error to contain 'filter rejected message', got %v", err) + } + + // Should only reach attempt 2 (0, 1, 2) before filter stops it. + if attemptCount != 3 { + t.Fatalf("expected 3 filter invocations (attempts 0,1,2), got %d", attemptCount) + } + }) }