Skip to content

Commit 3629d5c

Browse files
committed
implement
Signed-off-by: pipiland2612 <[email protected]>
1 parent c5900a1 commit 3629d5c

File tree

2 files changed

+221
-46
lines changed

2 files changed

+221
-46
lines changed

exp/api/remote/remote_api.go

Lines changed: 78 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,12 @@ type APIOption func(o *apiOpts) error
5252
// err is the error that caused the retry.
5353
type RetryCallback func(err error)
5454

55-
// TODO(bwplotka): Add "too old sample" handling one day.
55+
// MessageFilter is a function that filters or modifies the message before each write attempt.
56+
// It receives the attempt number (0 = first attempt, 1+ = retries) and the message to be sent.
57+
// It returns a potentially modified message, or an error if the message should not be sent.
58+
// This can be used for age-based filtering, deduplication, or other application-level logic.
59+
type MessageFilter func(attempt int, msg any) (filtered any, err error)
60+
5661
type apiOpts struct {
5762
logger *slog.Logger
5863
client *http.Client
@@ -169,6 +174,7 @@ type WriteOption func(o *writeOpts)
169174

170175
type writeOpts struct {
171176
retryCallback RetryCallback
177+
filterFunc MessageFilter
172178
}
173179

174180
// WithWriteRetryCallback sets a retry callback for this Write request.
@@ -179,6 +185,16 @@ func WithWriteRetryCallback(callback RetryCallback) WriteOption {
179185
}
180186
}
181187

188+
// WithWriteFilter sets a filter function for this Write request.
189+
// The filter is invoked before each write attempt (including the initial attempt).
190+
// This allows filtering out old samples, deduplication, or other application-level logic.
191+
// If the filter returns an error, the Write operation will stop and return that error.
192+
func WithWriteFilter(filter MessageFilter) WriteOption {
193+
return func(o *writeOpts) {
194+
o.filterFunc = filter
195+
}
196+
}
197+
182198
type vtProtoEnabled interface {
183199
SizeVT() int
184200
MarshalToSizedBufferVT(dAtA []byte) (int, error)
@@ -205,63 +221,79 @@ func (r *API) Write(ctx context.Context, msgType WriteMessageType, msg any, opts
205221
opt(&writeOpts)
206222
}
207223

208-
buf := r.bufPool.Get().(*[]byte)
209-
210224
if err := msgType.Validate(); err != nil {
211225
return WriteResponseStats{}, err
212226
}
213227

214-
// Encode the payload.
215-
switch m := msg.(type) {
216-
case vtProtoEnabled:
217-
// Use optimized vtprotobuf if supported.
218-
size := m.SizeVT()
219-
if cap(*buf) < size {
220-
*buf = make([]byte, size)
221-
} else {
222-
*buf = (*buf)[:size]
223-
}
228+
// Since we retry writes we need to track the total amount of accepted data
229+
// across the various attempts.
230+
accumulatedStats := WriteResponseStats{}
224231

225-
if _, err := m.MarshalToSizedBufferVT(*buf); err != nil {
226-
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
227-
}
228-
case gogoProtoEnabled:
229-
// Gogo proto if supported.
230-
size := m.Size()
231-
if cap(*buf) < size {
232-
*buf = make([]byte, size)
233-
} else {
234-
*buf = (*buf)[:size]
232+
b := backoff.New(ctx, r.opts.backoff)
233+
for {
234+
// Apply filter if provided.
235+
currentMsg := msg
236+
if writeOpts.filterFunc != nil {
237+
filteredMsg, err := writeOpts.filterFunc(b.NumRetries(), msg)
238+
if err != nil {
239+
// Filter returned error, likely no data left to send.
240+
return accumulatedStats, err
241+
}
242+
currentMsg = filteredMsg
235243
}
236244

237-
if _, err := m.MarshalToSizedBuffer(*buf); err != nil {
238-
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
245+
// Encode the payload.
246+
buf := r.bufPool.Get().(*[]byte)
247+
switch m := currentMsg.(type) {
248+
case vtProtoEnabled:
249+
// Use optimized vtprotobuf if supported.
250+
size := m.SizeVT()
251+
if cap(*buf) < size {
252+
*buf = make([]byte, size)
253+
} else {
254+
*buf = (*buf)[:size]
255+
}
256+
257+
if _, err := m.MarshalToSizedBufferVT(*buf); err != nil {
258+
r.bufPool.Put(buf)
259+
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
260+
}
261+
case gogoProtoEnabled:
262+
// Gogo proto if supported.
263+
size := m.Size()
264+
if cap(*buf) < size {
265+
*buf = make([]byte, size)
266+
} else {
267+
*buf = (*buf)[:size]
268+
}
269+
270+
if _, err := m.MarshalToSizedBuffer(*buf); err != nil {
271+
r.bufPool.Put(buf)
272+
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
273+
}
274+
case proto.Message:
275+
// Generic proto.
276+
*buf, err = (proto.MarshalOptions{}).MarshalAppend(*buf, m)
277+
if err != nil {
278+
r.bufPool.Put(buf)
279+
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
280+
}
281+
default:
282+
r.bufPool.Put(buf)
283+
return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m)
239284
}
240-
case proto.Message:
241-
// Generic proto.
242-
*buf, err = (proto.MarshalOptions{}).MarshalAppend(*buf, m)
285+
286+
comprBuf := r.bufPool.Get().(*[]byte)
287+
payload, err := compressPayload(comprBuf, r.opts.compression, *buf)
243288
if err != nil {
244-
return WriteResponseStats{}, fmt.Errorf("encoding request %w", err)
289+
r.bufPool.Put(buf)
290+
r.bufPool.Put(comprBuf)
291+
return WriteResponseStats{}, fmt.Errorf("compressing %w", err)
245292
}
246-
default:
247-
return WriteResponseStats{}, fmt.Errorf("unknown message type %T", m)
248-
}
293+
r.bufPool.Put(buf)
249294

250-
comprBuf := r.bufPool.Get().(*[]byte)
251-
payload, err := compressPayload(comprBuf, r.opts.compression, *buf)
252-
if err != nil {
253-
return WriteResponseStats{}, fmt.Errorf("compressing %w", err)
254-
}
255-
r.bufPool.Put(buf)
256-
defer r.bufPool.Put(comprBuf)
257-
258-
// Since we retry writes we need to track the total amount of accepted data
259-
// across the various attempts.
260-
accumulatedStats := WriteResponseStats{}
261-
262-
b := backoff.New(ctx, r.opts.backoff)
263-
for {
264295
rs, err := r.attemptWrite(ctx, r.opts.compression, msgType, payload, b.NumRetries())
296+
r.bufPool.Put(comprBuf)
265297
accumulatedStats.Add(rs)
266298
if err == nil {
267299
// Check the case mentioned in PRW 2.0.

exp/api/remote/remote_api_test.go

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,4 +295,147 @@ func TestRemoteAPI_Write_WithHandler(t *testing.T) {
295295
t.Fatal("retry callback should not be invoked on successful request")
296296
}
297297
})
298+
299+
t.Run("filter invoked on each attempt", func(t *testing.T) {
300+
tLogger := slog.Default()
301+
mockCode := http.StatusInternalServerError
302+
mStore := &mockStorage{
303+
mockErr: errors.New("storage error"),
304+
mockCode: &mockCode,
305+
}
306+
srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger)))
307+
t.Cleanup(srv.Close)
308+
309+
var filterInvocations []int
310+
client, err := NewAPI(srv.URL,
311+
WithAPIHTTPClient(srv.Client()),
312+
WithAPILogger(tLogger),
313+
WithAPIPath("api/v1/write"),
314+
WithAPIBackoff(backoff.Config{
315+
Min: 1 * time.Millisecond,
316+
Max: 1 * time.Millisecond,
317+
MaxRetries: 2,
318+
}),
319+
)
320+
if err != nil {
321+
t.Fatal(err)
322+
}
323+
324+
req := testV2()
325+
_, err = client.Write(context.Background(), WriteV2MessageType, req,
326+
WithWriteFilter(func(attempt int, msg any) (any, error) {
327+
filterInvocations = append(filterInvocations, attempt)
328+
return msg, nil
329+
}),
330+
)
331+
if err == nil {
332+
t.Fatal("expected error, got nil")
333+
}
334+
335+
// Filter should be invoked for initial attempt (0) and 2 retries (1, 2).
336+
expectedInvocations := []int{0, 1, 2}
337+
if diff := cmp.Diff(expectedInvocations, filterInvocations); diff != "" {
338+
t.Fatalf("unexpected filter invocations (-want +got):\n%s", diff)
339+
}
340+
})
341+
342+
t.Run("filter can modify message on retries", func(t *testing.T) {
343+
tLogger := slog.Default()
344+
mStore := &mockStorage{}
345+
srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger)))
346+
t.Cleanup(srv.Close)
347+
348+
client, err := NewAPI(srv.URL,
349+
WithAPIHTTPClient(srv.Client()),
350+
WithAPILogger(tLogger),
351+
WithAPIPath("api/v1/write"),
352+
)
353+
if err != nil {
354+
t.Fatal(err)
355+
}
356+
357+
req := testV2()
358+
originalTimeseriesCount := len(req.Timeseries)
359+
360+
_, err = client.Write(context.Background(), WriteV2MessageType, req,
361+
WithWriteFilter(func(attempt int, msg any) (any, error) {
362+
r, ok := msg.(*writev2.Request)
363+
if !ok {
364+
t.Fatal("expected *writev2.Request")
365+
}
366+
367+
// On retries (attempt > 0), filter out the first timeseries.
368+
if attempt > 0 {
369+
filtered := &writev2.Request{
370+
Timeseries: r.Timeseries[1:],
371+
Symbols: r.Symbols,
372+
}
373+
return filtered, nil
374+
}
375+
return msg, nil
376+
}),
377+
)
378+
if err != nil {
379+
t.Fatal(err)
380+
}
381+
382+
// Verify original message was sent on first attempt.
383+
if len(mStore.v2Reqs) != 1 {
384+
t.Fatalf("expected 1 request stored, got %d", len(mStore.v2Reqs))
385+
}
386+
if len(mStore.v2Reqs[0].Timeseries) != originalTimeseriesCount {
387+
t.Fatalf("expected %d timeseries in stored request, got %d",
388+
originalTimeseriesCount, len(mStore.v2Reqs[0].Timeseries))
389+
}
390+
})
391+
392+
t.Run("filter error stops retries", func(t *testing.T) {
393+
tLogger := slog.Default()
394+
mockCode := http.StatusInternalServerError
395+
mStore := &mockStorage{
396+
mockErr: errors.New("storage error"),
397+
mockCode: &mockCode,
398+
}
399+
srv := httptest.NewServer(NewWriteHandler(mStore, MessageTypes{WriteV2MessageType}, WithWriteHandlerLogger(tLogger)))
400+
t.Cleanup(srv.Close)
401+
402+
var attemptCount int
403+
client, err := NewAPI(srv.URL,
404+
WithAPIHTTPClient(srv.Client()),
405+
WithAPILogger(tLogger),
406+
WithAPIPath("api/v1/write"),
407+
WithAPIBackoff(backoff.Config{
408+
Min: 1 * time.Millisecond,
409+
Max: 1 * time.Millisecond,
410+
MaxRetries: 5,
411+
}),
412+
)
413+
if err != nil {
414+
t.Fatal(err)
415+
}
416+
417+
req := testV2()
418+
_, err = client.Write(context.Background(), WriteV2MessageType, req,
419+
WithWriteFilter(func(attempt int, msg any) (any, error) {
420+
attemptCount++
421+
// Return error on second retry (attempt 2).
422+
if attempt >= 2 {
423+
return nil, errors.New("filter rejected message")
424+
}
425+
return msg, nil
426+
}),
427+
)
428+
429+
if err == nil {
430+
t.Fatal("expected error, got nil")
431+
}
432+
if !strings.Contains(err.Error(), "filter rejected message") {
433+
t.Fatalf("expected error to contain 'filter rejected message', got %v", err)
434+
}
435+
436+
// Should only reach attempt 2 (0, 1, 2) before filter stops it.
437+
if attemptCount != 3 {
438+
t.Fatalf("expected 3 filter invocations (attempts 0,1,2), got %d", attemptCount)
439+
}
440+
})
298441
}

0 commit comments

Comments
 (0)