Skip to content

Commit 0f940a6

Browse files
committed
bufix: fix Incorrect round context string formatting and retry attempt not propagated to events middleware
1 parent abf0548 commit 0f940a6

File tree

4 files changed

+127
-3
lines changed

4 files changed

+127
-3
lines changed

pkg/agent/state.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package agent
22

33
import (
4+
"fmt"
45
"sync"
56
"time"
67
)
@@ -144,7 +145,7 @@ func (s *State) RecordError(err error) {
144145
entry := ErrorEntry{
145146
Time: time.Now(),
146147
Error: err,
147-
Context: "Round " + string(rune(s.roundCount)),
148+
Context: fmt.Sprintf("Round %d", s.roundCount),
148149
}
149150

150151
s.errorLog = append(s.errorLog, entry)

pkg/dispatcher/events.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ import (
99
"github.com/Zerofisher/goai/pkg/types"
1010
)
1111

12+
// contextKey is a custom type for context keys to avoid collisions
13+
type contextKey string
14+
15+
// retryAttemptKey is the context key for retry attempt number
16+
const retryAttemptKey contextKey = "retry_attempt"
17+
1218
// ToolObserver defines the interface for observing tool execution events
1319
type ToolObserver interface {
1420
// OnToolEvent is called when a tool event occurs
@@ -196,7 +202,7 @@ func truncateOutput(output string, maxChars int) (string, map[string]interface{}
196202
// attemptFromContext extracts the retry attempt number from context
197203
// Returns 1 if not present (first attempt)
198204
func attemptFromContext(ctx context.Context) int {
199-
if attempt, ok := ctx.Value("retry_attempt").(int); ok {
205+
if attempt, ok := ctx.Value(retryAttemptKey).(int); ok {
200206
return attempt
201207
}
202208
return 1

pkg/dispatcher/events_test.go

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -399,3 +399,117 @@ func TestEventsMiddleware_NilObserver(t *testing.T) {
399399
t.Error("Expected successful execution")
400400
}
401401
}
402+
403+
// TestEventsMiddleware_WithRetryAttempt tests retry attempt tracking
404+
func TestEventsMiddleware_WithRetryAttempt(t *testing.T) {
405+
obs := &mockObserver{}
406+
opts := DefaultEventsOptions()
407+
408+
middleware := EventsMiddleware(obs, opts)
409+
410+
tu := types.ToolUse{
411+
ID: "test-retry",
412+
Name: "retry_tool",
413+
Input: map[string]interface{}{},
414+
}
415+
416+
next := func(_ context.Context, _ types.ToolUse) types.ToolResult {
417+
return types.ToolResult{
418+
ToolUseID: tu.ID,
419+
Content: "success",
420+
IsError: false,
421+
}
422+
}
423+
424+
// Test with retry_attempt set in context
425+
ctx := context.WithValue(context.Background(), retryAttemptKey, 3)
426+
middleware(ctx, tu, next)
427+
428+
events := obs.getEvents()
429+
if len(events) != 2 {
430+
t.Fatalf("Expected 2 events, got %d", len(events))
431+
}
432+
433+
// Both events should have Attempt = 3
434+
if events[0].Attempt != 3 {
435+
t.Errorf("Started event: expected Attempt=3, got %d", events[0].Attempt)
436+
}
437+
if events[1].Attempt != 3 {
438+
t.Errorf("Succeeded event: expected Attempt=3, got %d", events[1].Attempt)
439+
}
440+
}
441+
442+
// TestRetryMiddleware_WithEvents tests that retry attempts are propagated to events
443+
func TestRetryMiddleware_WithEvents(t *testing.T) {
444+
obs := &mockObserver{}
445+
opts := DefaultEventsOptions()
446+
447+
// Wrap RetryMiddleware with EventsMiddleware
448+
retryMiddleware := RetryMiddleware(2, 0) // 2 retries, no delay
449+
eventsMiddleware := EventsMiddleware(obs, opts)
450+
451+
tu := types.ToolUse{
452+
ID: "test-retry-events",
453+
Name: "flaky_tool",
454+
Input: map[string]interface{}{},
455+
}
456+
457+
attemptCount := 0
458+
next := func(_ context.Context, _ types.ToolUse) types.ToolResult {
459+
attemptCount++
460+
if attemptCount < 3 {
461+
// First two attempts fail with retryable error
462+
return types.ToolResult{
463+
ToolUseID: tu.ID,
464+
Content: "Error: connection timeout",
465+
IsError: true,
466+
}
467+
}
468+
// Third attempt succeeds
469+
return types.ToolResult{
470+
ToolUseID: tu.ID,
471+
Content: "success",
472+
IsError: false,
473+
}
474+
}
475+
476+
// Chain middlewares: retry wraps events wraps next
477+
ctx := context.Background()
478+
result := retryMiddleware(ctx, tu, func(ctx context.Context, tu types.ToolUse) types.ToolResult {
479+
return eventsMiddleware(ctx, tu, next)
480+
})
481+
482+
// Verify result is success
483+
if result.IsError {
484+
t.Errorf("Expected success after retries, got error: %s", result.Content)
485+
}
486+
487+
// Verify events
488+
events := obs.getEvents()
489+
// Should have: started+failed (attempt 1), started+failed (attempt 2), started+succeeded (attempt 3)
490+
expectedEvents := 6
491+
if len(events) != expectedEvents {
492+
t.Fatalf("Expected %d events, got %d", expectedEvents, len(events))
493+
}
494+
495+
// Check attempt numbers
496+
for i, event := range events {
497+
expectedAttempt := (i / 2) + 1 // Each pair of events (started, completed) belongs to one attempt
498+
if event.Attempt != expectedAttempt {
499+
t.Errorf("Event %d: expected Attempt=%d, got %d", i, expectedAttempt, event.Attempt)
500+
}
501+
}
502+
503+
// Verify event types
504+
if events[0].Type != types.ToolEventStarted || events[1].Type != types.ToolEventFailed {
505+
t.Error("First attempt should be started → failed")
506+
}
507+
if events[2].Type != types.ToolEventStarted || events[3].Type != types.ToolEventFailed {
508+
t.Error("Second attempt should be started → failed")
509+
}
510+
if events[4].Type != types.ToolEventStarted || events[5].Type != types.ToolEventSucceeded {
511+
t.Error("Third attempt should be started → succeeded")
512+
}
513+
}
514+
515+

pkg/dispatcher/middleware.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,11 @@ func RetryMiddleware(maxRetries int, retryDelay time.Duration) Middleware {
9292
var lastErr error
9393

9494
for i := 0; i <= maxRetries; i++ {
95+
// Set retry attempt in context for EventsMiddleware
96+
ctxWithAttempt := context.WithValue(ctx, retryAttemptKey, i+1)
97+
9598
// Execute
96-
result = next(ctx, toolUse)
99+
result = next(ctxWithAttempt, toolUse)
97100

98101
// Success or non-retryable error
99102
if !result.IsError || i == maxRetries {

0 commit comments

Comments
 (0)