Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 78 additions & 46 deletions exp/api/remote/remote_api.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,12 @@ type APIOption func(o *apiOpts) error
// err is the error that caused the retry.
type RetryCallback func(err error)

// TODO(bwplotka): Add "too old sample" handling one day.
// MessageFilter is a function that filters or modifies the message before each write attempt.
// It receives the attempt number (0 = first attempt, 1+ = retries) and the message to be sent.
// It returns a potentially modified message, or an error if the message should not be sent.
// This can be used for age-based filtering, deduplication, or other application-level logic.
type MessageFilter func(attempt int, msg any) (filtered any, err error)

type apiOpts struct {
logger *slog.Logger
client *http.Client
Expand Down Expand Up @@ -169,6 +174,7 @@ type WriteOption func(o *writeOpts)

type writeOpts struct {
retryCallback RetryCallback
filterFunc MessageFilter
}

// WithWriteRetryCallback sets a retry callback for this Write request.
Expand All @@ -179,6 +185,16 @@ func WithWriteRetryCallback(callback RetryCallback) WriteOption {
}
}

// WithWriteFilter sets a filter function for this Write request.
// The filter is invoked before each write attempt (including the initial attempt).
// This allows filtering out old samples, deduplication, or other application-level logic.
// If the filter returns an error, the Write operation will stop and return that error.
func WithWriteFilter(filter MessageFilter) WriteOption {
return func(o *writeOpts) {
o.filterFunc = filter
}
}

type vtProtoEnabled interface {
SizeVT() int
MarshalToSizedBufferVT(dAtA []byte) (int, error)
Expand All @@ -205,63 +221,79 @@ func (r *API) Write(ctx context.Context, msgType WriteMessageType, msg any, opts
opt(&writeOpts)
}

buf := r.bufPool.Get().(*[]byte)

if err := msgType.Validate(); err != nil {
return WriteResponseStats{}, err
}

// Encode the payload.
switch m := msg.(type) {
case vtProtoEnabled:
// Use optimized vtprotobuf if supported.
size := m.SizeVT()
if cap(*buf) < size {
*buf = make([]byte, size)
} else {
*buf = (*buf)[:size]
}
// Since we retry writes we need to track the total amount of accepted data
// across the various attempts.
accumulatedStats := WriteResponseStats{}

if _, err := m.MarshalToSizedBufferVT(*buf); err != nil {
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
}
case gogoProtoEnabled:
// Gogo proto if supported.
size := m.Size()
if cap(*buf) < size {
*buf = make([]byte, size)
} else {
*buf = (*buf)[:size]
b := backoff.New(ctx, r.opts.backoff)
for {
// Apply filter if provided.
currentMsg := msg
if writeOpts.filterFunc != nil {
filteredMsg, err := writeOpts.filterFunc(b.NumRetries(), msg)
if err != nil {
// Filter returned error, likely no data left to send.
return accumulatedStats, err
}
currentMsg = filteredMsg
}

if _, err := m.MarshalToSizedBuffer(*buf); err != nil {
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
// Encode the payload.
buf := r.bufPool.Get().(*[]byte)
switch m := currentMsg.(type) {
case vtProtoEnabled:
// Use optimized vtprotobuf if supported.
size := m.SizeVT()
if cap(*buf) < size {
*buf = make([]byte, size)
} else {
*buf = (*buf)[:size]
}

if _, err := m.MarshalToSizedBufferVT(*buf); err != nil {
r.bufPool.Put(buf)
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
}
case gogoProtoEnabled:
// Gogo proto if supported.
size := m.Size()
if cap(*buf) < size {
*buf = make([]byte, size)
} else {
*buf = (*buf)[:size]
}

if _, err := m.MarshalToSizedBuffer(*buf); err != nil {
r.bufPool.Put(buf)
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
}
case proto.Message:
// Generic proto.
*buf, err = (proto.MarshalOptions{}).MarshalAppend(*buf, m)
if err != nil {
r.bufPool.Put(buf)
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
}
default:
r.bufPool.Put(buf)
return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m)
}
case proto.Message:
// Generic proto.
*buf, err = (proto.MarshalOptions{}).MarshalAppend(*buf, m)

comprBuf := r.bufPool.Get().(*[]byte)
payload, err := compressPayload(comprBuf, r.opts.compression, *buf)
if err != nil {
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
r.bufPool.Put(buf)
r.bufPool.Put(comprBuf)
return WriteResponseStats{}, fmt.Errorf("compressing %w", err)
}
default:
return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m)
}
r.bufPool.Put(buf)

comprBuf := r.bufPool.Get().(*[]byte)
payload, err := compressPayload(comprBuf, r.opts.compression, *buf)
if err != nil {
return WriteResponseStats{}, fmt.Errorf("compressing %w", err)
}
r.bufPool.Put(buf)
defer r.bufPool.Put(comprBuf)

// Since we retry writes we need to track the total amount of accepted data
// across the various attempts.
accumulatedStats := WriteResponseStats{}

b := backoff.New(ctx, r.opts.backoff)
for {
rs, err := r.attemptWrite(ctx, r.opts.compression, msgType, payload, b.NumRetries())
r.bufPool.Put(comprBuf)
accumulatedStats.Add(rs)
if err == nil {
// Check the case mentioned in PRW 2.0.
Expand Down
143 changes: 143 additions & 0 deletions exp/api/remote/remote_api_test.go
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here's an idea for more realistic tests: we could implement a custom test server or modify mockStorage to handle retry logic. This would allow us to validate code behavior against scenarios such as 'fails, retries, fails, then succeeds.' Do you feel that this make sense?

Original file line number Diff line number Diff line change
Expand Up @@ -295,4 +295,147 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
t.Fatal("retry callback should not be invoked on successful request")
}
})

t.Run("filter invoked on each attempt", func(t *testing.T) {
tLogger := slog.Default()
mockCode := http.StatusInternalServerError
mStore := &mockStorage{
mockErr: errors.New("storage error"),
mockCode: &mockCode,
}
srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger)))
t.Cleanup(srv.Close)

var filterInvocations []int
client, err := NewAPI(srv.URL,
WithAPIHTTPClient(srv.Client()),
WithAPILogger(tLogger),
WithAPIPath("api/v1/write"),
WithAPIBackoff(backoff.Config{
Min: 1 * time.Millisecond,
Max: 1 * time.Millisecond,
MaxRetries: 2,
}),
)
if err != nil {
t.Fatal(err)
}

req := testV2()
_, err = client.Write(context.Background(), WriteV2MessageType, req,
WithWriteFilter(func(attempt int, msg any) (any, error) {
filterInvocations = append(filterInvocations, attempt)
return msg, nil
}),
)
if err == nil {
t.Fatal("expected error, got nil")
}

// Filter should be invoked for initial attempt (0) and 2 retries (1, 2).
expectedInvocations := []int{0, 1, 2}
if diff := cmp.Diff(expectedInvocations, filterInvocations); diff != "" {
t.Fatalf("unexpected filter invocations (-want +got):\n%s", diff)
}
})

t.Run("filter can modify message on retries", func(t *testing.T) {
Copy link

@perebaj perebaj Oct 8, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might be misunderstanding, but I'm wondering if the test name fully matches what we're testing here? It seems like we're only attempting once, rather than simulating retry behavior. Does that sound right, or am I missing something?

tLogger := slog.Default()
mStore := &mockStorage{}
srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger)))
t.Cleanup(srv.Close)

client, err := NewAPI(srv.URL,
WithAPIHTTPClient(srv.Client()),
WithAPILogger(tLogger),
WithAPIPath("api/v1/write"),
)
if err != nil {
t.Fatal(err)
}

req := testV2()
originalTimeseriesCount := len(req.Timeseries)

_, err = client.Write(context.Background(), WriteV2MessageType, req,
WithWriteFilter(func(attempt int, msg any) (any, error) {
r, ok := msg.(*writev2.Request)
if !ok {
t.Fatal("expected *writev2.Request")
}

// On retries (attempt > 0), filter out the first timeseries.
if attempt > 0 {
filtered := &writev2.Request{
Timeseries: r.Timeseries[1:],
Symbols: r.Symbols,
}
return filtered, nil
}
return msg, nil
}),
)
if err != nil {
t.Fatal(err)
}

// Verify original message was sent on first attempt.
if len(mStore.v2Reqs) != 1 {
t.Fatalf("expected 1 request stored, got %d", len(mStore.v2Reqs))
}
if len(mStore.v2Reqs[0].Timeseries) != originalTimeseriesCount {
t.Fatalf("expected %d timeseries in stored request, got %d",
originalTimeseriesCount, len(mStore.v2Reqs[0].Timeseries))
}
})

t.Run("filter error stops retries", func(t *testing.T) {
tLogger := slog.Default()
mockCode := http.StatusInternalServerError
mStore := &mockStorage{
mockErr: errors.New("storage error"),
mockCode: &mockCode,
}
srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger)))
t.Cleanup(srv.Close)

var attemptCount int
client, err := NewAPI(srv.URL,
WithAPIHTTPClient(srv.Client()),
WithAPILogger(tLogger),
WithAPIPath("api/v1/write"),
WithAPIBackoff(backoff.Config{
Min: 1 * time.Millisecond,
Max: 1 * time.Millisecond,
MaxRetries: 5,
}),
)
if err != nil {
t.Fatal(err)
}

req := testV2()
_, err = client.Write(context.Background(), WriteV2MessageType, req,
WithWriteFilter(func(attempt int, msg any) (any, error) {
attemptCount++
// Return error on second retry (attempt 2).
if attempt >= 2 {
return nil, errors.New("filter rejected message")
}
return msg, nil
}),
)

if err == nil {
t.Fatal("expected error, got nil")
}
if !strings.Contains(err.Error(), "filter rejected message") {
t.Fatalf("expected error to contain 'filter rejected message', got %v", err)
}

// Should only reach attempt 2 (0, 1, 2) before filter stops it.
if attemptCount != 3 {
t.Fatalf("expected 3 filter invocations (attempts 0,1,2), got %d", attemptCount)
}
})
}
Loading