55 "encoding/gob"
66 "errors"
77 "fmt"
8+ "math"
89 "reflect"
910 "runtime"
1011 "sync"
@@ -55,8 +56,9 @@ type WorkflowStatus struct {
5556
// WorkflowState holds the runtime state for a workflow execution.
type WorkflowState struct {
	// WorkflowID identifies the workflow execution this state belongs to.
	WorkflowID string
	// stepCounter is the counter used to hand out step IDs (see NextStepID).
	stepCounter int
	// isWithinStep is set on the copy of the state that RunAsStep passes to a
	// running step function; when it is set, RunAsStep calls a nested step
	// function directly.
	isWithinStep bool
}
6163
6264// NextStepID returns the next step ID and increments the counter
@@ -554,24 +556,40 @@ func runAsWorkflow[P any, R any](ctx context.Context, fn WorkflowFunc[P, R], inp
// StepFunc is the signature of a function that can be run as a workflow step.
// It receives a context and an input of type P and returns an output of type R
// or an error.
type StepFunc[P any, R any] func(ctx context.Context, input P) (R, error)
555557
// StepParams holds the retry configuration applied when a step is executed.
type StepParams struct {
	// MaxRetries is the maximum number of retries for a failed step.
	// Retries only happen when this is greater than zero.
	MaxRetries int
	// BackoffFactor is the multiplier used for exponential backoff between retries.
	BackoffFactor float64
	// BaseInterval is the delay before the first retry.
	BaseInterval time.Duration
	// MaxInterval is the maximum delay for retries.
	MaxInterval time.Duration
}

// StepOption is a functional option for configuring step parameters.
type StepOption func(*StepParams)

// WithStepMaxRetries sets the maximum number of retries for a step.
func WithStepMaxRetries(maxRetries int) StepOption {
	return func(p *StepParams) {
		p.MaxRetries = maxRetries
	}
}

// WithBackoffFactor sets the backoff factor for retries (the multiplier for
// exponential backoff).
func WithBackoffFactor(backoffFactor float64) StepOption {
	return func(p *StepParams) {
		p.BackoffFactor = backoffFactor
	}
}

// WithBaseInterval sets the base delay for the first retry.
func WithBaseInterval(baseInterval time.Duration) StepOption {
	return func(p *StepParams) {
		p.BaseInterval = baseInterval
	}
}

// WithMaxInterval sets the maximum delay for retries.
func WithMaxInterval(maxInterval time.Duration) StepOption {
	return func(p *StepParams) {
		p.MaxInterval = maxInterval
	}
}
577595
@@ -582,16 +600,26 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op
582600
583601 operationName := runtime .FuncForPC (reflect .ValueOf (fn ).Pointer ()).Name ()
584602
585- // Apply options to build params
586- params := StepParams {}
603+ // Apply options to build params with defaults
604+ params := StepParams {
605+ MaxRetries : 0 ,
606+ BackoffFactor : 2.0 ,
607+ BaseInterval : 500 * time .Millisecond ,
608+ MaxInterval : 1 * time .Hour ,
609+ }
587610 for _ , opt := range opts {
588611 opt (& params )
589612 }
590613
591614 // Get workflow state from context
592615 workflowState , ok := ctx .Value (WorkflowStateKey ).(* WorkflowState )
593616 if ! ok || workflowState == nil {
594- return * new (R ), NewStepExecutionError ("" , operationName , "workflow state not found in context" )
617+ return * new (R ), NewStepExecutionError ("" , operationName , "workflow state not found in context: are you running this step within a workflow?" )
618+ }
619+
620+ // If within a step, just run the function directly
621+ if workflowState .isWithinStep {
622+ return fn (ctx , input )
595623 }
596624
597625 // Get next step ID
@@ -610,7 +638,59 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op
610638 return recordedOutput .output .(R ), recordedOutput .err
611639 }
612640
613- stepOutput , stepError := fn (ctx , input )
641+ // Execute step with retry logic if MaxRetries > 0
642+ stepState := WorkflowState {
643+ WorkflowID : workflowState .WorkflowID ,
644+ stepCounter : workflowState .stepCounter ,
645+ isWithinStep : true ,
646+ }
647+ stepCtx := context .WithValue (ctx , WorkflowStateKey , & stepState )
648+
649+ stepOutput , stepError := fn (stepCtx , input )
650+
651+ // Retry if MaxRetries > 0 and the first execution failed
652+ var joinedErrors error
653+ if stepError != nil && params .MaxRetries > 0 {
654+ joinedErrors = errors .Join (joinedErrors , stepError )
655+
656+ for retry := 1 ; retry <= params .MaxRetries ; retry ++ {
657+ // Calculate delay for exponential backoff
658+ delay := params .BaseInterval
659+ if retry > 1 {
660+ exponentialDelay := float64 (params .BaseInterval ) * math .Pow (params .BackoffFactor , float64 (retry - 1 ))
661+ delay = time .Duration (math .Min (exponentialDelay , float64 (params .MaxInterval )))
662+ }
663+
664+ fmt .Printf ("step %s failed, retrying %d/%d in %v: %v\n " , operationName , retry , params .MaxRetries , delay , stepError )
665+
666+ // Wait before retry
667+ select {
668+ case <- ctx .Done ():
669+ return * new (R ), NewStepExecutionError (workflowState .WorkflowID , operationName , fmt .Sprintf ("context cancelled during retry: %v" , ctx .Err ()))
670+ case <- time .After (delay ):
671+ // Continue to retry
672+ }
673+
674+ // Execute the retry
675+ stepOutput , stepError = fn (stepCtx , input )
676+
677+ // If successful, break
678+ if stepError == nil {
679+ break
680+ }
681+
682+ // Join the error with existing errors
683+ joinedErrors = errors .Join (joinedErrors , stepError )
684+
685+ // If max retries reached, create MaxStepRetriesExceeded error
686+ if retry == params .MaxRetries {
687+ stepError = NewMaxStepRetriesExceededError (workflowState .WorkflowID , operationName , params .MaxRetries , joinedErrors )
688+ break
689+ }
690+ }
691+ }
692+
693+ // Record the final result
614694 dbInput := recordOperationResultDBInput {
615695 workflowID : workflowState .WorkflowID ,
616696 operationName : operationName ,
@@ -620,9 +700,9 @@ func RunAsStep[P any, R any](ctx context.Context, fn StepFunc[P, R], input P, op
620700 }
621701 recErr := getExecutor ().systemDB .RecordOperationResult (ctx , dbInput )
622702 if recErr != nil {
623- // fmt.Println("failed to record step error:", err)
624- return * new (R ), NewStepExecutionError (workflowState .WorkflowID , operationName , fmt .Sprintf ("recording step outcome: %v" , err ))
703+ return * new (R ), NewStepExecutionError (workflowState .WorkflowID , operationName , fmt .Sprintf ("recording step outcome: %v" , recErr ))
625704 }
705+
626706 return stepOutput , stepError
627707}
628708
0 commit comments