@@ -399,3 +399,117 @@ func TestEventsMiddleware_NilObserver(t *testing.T) {
399399 t .Error ("Expected successful execution" )
400400 }
401401}
402+
403+ // TestEventsMiddleware_WithRetryAttempt tests retry attempt tracking
404+ func TestEventsMiddleware_WithRetryAttempt (t * testing.T ) {
405+ obs := & mockObserver {}
406+ opts := DefaultEventsOptions ()
407+
408+ middleware := EventsMiddleware (obs , opts )
409+
410+ tu := types.ToolUse {
411+ ID : "test-retry" ,
412+ Name : "retry_tool" ,
413+ Input : map [string ]interface {}{},
414+ }
415+
416+ next := func (_ context.Context , _ types.ToolUse ) types.ToolResult {
417+ return types.ToolResult {
418+ ToolUseID : tu .ID ,
419+ Content : "success" ,
420+ IsError : false ,
421+ }
422+ }
423+
424+ // Test with retry_attempt set in context
425+ ctx := context .WithValue (context .Background (), retryAttemptKey , 3 )
426+ middleware (ctx , tu , next )
427+
428+ events := obs .getEvents ()
429+ if len (events ) != 2 {
430+ t .Fatalf ("Expected 2 events, got %d" , len (events ))
431+ }
432+
433+ // Both events should have Attempt = 3
434+ if events [0 ].Attempt != 3 {
435+ t .Errorf ("Started event: expected Attempt=3, got %d" , events [0 ].Attempt )
436+ }
437+ if events [1 ].Attempt != 3 {
438+ t .Errorf ("Succeeded event: expected Attempt=3, got %d" , events [1 ].Attempt )
439+ }
440+ }
441+
442+ // TestRetryMiddleware_WithEvents tests that retry attempts are propagated to events
443+ func TestRetryMiddleware_WithEvents (t * testing.T ) {
444+ obs := & mockObserver {}
445+ opts := DefaultEventsOptions ()
446+
447+ // Wrap RetryMiddleware with EventsMiddleware
448+ retryMiddleware := RetryMiddleware (2 , 0 ) // 2 retries, no delay
449+ eventsMiddleware := EventsMiddleware (obs , opts )
450+
451+ tu := types.ToolUse {
452+ ID : "test-retry-events" ,
453+ Name : "flaky_tool" ,
454+ Input : map [string ]interface {}{},
455+ }
456+
457+ attemptCount := 0
458+ next := func (_ context.Context , _ types.ToolUse ) types.ToolResult {
459+ attemptCount ++
460+ if attemptCount < 3 {
461+ // First two attempts fail with retryable error
462+ return types.ToolResult {
463+ ToolUseID : tu .ID ,
464+ Content : "Error: connection timeout" ,
465+ IsError : true ,
466+ }
467+ }
468+ // Third attempt succeeds
469+ return types.ToolResult {
470+ ToolUseID : tu .ID ,
471+ Content : "success" ,
472+ IsError : false ,
473+ }
474+ }
475+
476+ // Chain middlewares: retry wraps events wraps next
477+ ctx := context .Background ()
478+ result := retryMiddleware (ctx , tu , func (ctx context.Context , tu types.ToolUse ) types.ToolResult {
479+ return eventsMiddleware (ctx , tu , next )
480+ })
481+
482+ // Verify result is success
483+ if result .IsError {
484+ t .Errorf ("Expected success after retries, got error: %s" , result .Content )
485+ }
486+
487+ // Verify events
488+ events := obs .getEvents ()
489+ // Should have: started+failed (attempt 1), started+failed (attempt 2), started+succeeded (attempt 3)
490+ expectedEvents := 6
491+ if len (events ) != expectedEvents {
492+ t .Fatalf ("Expected %d events, got %d" , expectedEvents , len (events ))
493+ }
494+
495+ // Check attempt numbers
496+ for i , event := range events {
497+ expectedAttempt := (i / 2 ) + 1 // Each pair of events (started, completed) belongs to one attempt
498+ if event .Attempt != expectedAttempt {
499+ t .Errorf ("Event %d: expected Attempt=%d, got %d" , i , expectedAttempt , event .Attempt )
500+ }
501+ }
502+
503+ // Verify event types
504+ if events [0 ].Type != types .ToolEventStarted || events [1 ].Type != types .ToolEventFailed {
505+ t .Error ("First attempt should be started → failed" )
506+ }
507+ if events [2 ].Type != types .ToolEventStarted || events [3 ].Type != types .ToolEventFailed {
508+ t .Error ("Second attempt should be started → failed" )
509+ }
510+ if events [4 ].Type != types .ToolEventStarted || events [5 ].Type != types .ToolEventSucceeded {
511+ t .Error ("Third attempt should be started → succeeded" )
512+ }
513+ }
514+
515+
0 commit comments