Skip to content

Commit 939bdf3

Browse files
committed
Improve early returns
1 parent 76b288a commit 939bdf3

File tree

1 file changed

+123
-62
lines changed

1 file changed

+123
-62
lines changed

internal/actionwait/wait.go

Lines changed: 123 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"context"
1111
"errors"
1212
"slices"
13+
"strings"
1314
"time"
1415
)
1516

@@ -88,7 +89,15 @@ type UnexpectedStateError struct {
8889
}
8990

9091
func (e *UnexpectedStateError) Error() string {
91-
return "operation entered unexpected state: " + string(e.Status)
92+
if len(e.Allowed) == 0 {
93+
return "operation entered unexpected state: " + string(e.Status)
94+
}
95+
allowedStr := make([]string, len(e.Allowed))
96+
for i, s := range e.Allowed {
97+
allowedStr[i] = string(s)
98+
}
99+
return "operation entered unexpected state: " + string(e.Status) + " (allowed: " +
100+
strings.Join(allowedStr, ", ") + ")"
92101
}
93102

94103
// sentinel errors helpers
@@ -102,21 +111,13 @@ var (
102111
// context cancellation, or fetch error occurs.
103112
// On success, the final FetchResult is returned with nil error.
104113
func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[T]) (FetchResult[T], error) { //nolint:cyclop // complexity driven by classification/state machine; readability preferred
105-
var zero FetchResult[T]
106-
107-
if opts.Timeout <= 0 {
108-
return zero, errors.New("actionwait: Timeout must be > 0")
109-
}
110-
if len(opts.SuccessStates) == 0 {
111-
return zero, errors.New("actionwait: at least one SuccessState required")
112-
}
113-
if opts.ConsecutiveSuccess <= 0 {
114-
opts.ConsecutiveSuccess = 1
115-
}
116-
if opts.Interval == nil {
117-
opts.Interval = FixedInterval(DefaultPollInterval)
114+
if err := validateOptions(opts); err != nil {
115+
var zero FetchResult[T]
116+
return zero, err
118117
}
119118

119+
normalizeOptions(&opts)
120+
120121
start := time.Now()
121122
deadline := start.Add(opts.Timeout)
122123
var lastProgress time.Time
@@ -130,64 +131,37 @@ func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[
130131
allowedTransient = append(allowedTransient, opts.TransitionalStates...)
131132

132133
for {
134+
// Early return: context cancelled
133135
if ctx.Err() != nil {
134136
return last, ctx.Err()
135137
}
136-
now := time.Now()
137-
if now.After(deadline) {
138+
139+
// Early return: timeout exceeded
140+
if time.Now().After(deadline) {
138141
return last, &TimeoutError{LastStatus: last.Status, Timeout: opts.Timeout}
139142
}
140143

144+
// Fetch current status
141145
fr, err := fetch(ctx)
142146
if err != nil {
143-
return fr, err
147+
return fr, err // Early return: fetch error
144148
}
145149
last = fr
146150

147-
// Classification precedence: failure -> success -> transitional -> unexpected
148-
if contains(opts.FailureStates, fr.Status) {
149-
return fr, &FailureStateError{Status: fr.Status}
150-
}
151-
if contains(opts.SuccessStates, fr.Status) {
152-
successStreak++
153-
if successStreak >= opts.ConsecutiveSuccess {
154-
return fr, nil
155-
}
156-
} else {
157-
successStreak = 0
158-
if len(opts.TransitionalStates) > 0 {
159-
if !contains(opts.TransitionalStates, fr.Status) {
160-
return fr, &UnexpectedStateError{Status: fr.Status, Allowed: allowedTransient}
161-
}
162-
}
151+
// Classify status and determine if we should terminate
152+
isTerminal, classifyErr := classifyStatus(fr, opts, &successStreak, allowedTransient)
153+
if isTerminal {
154+
return fr, classifyErr // Early return: terminal state (success or failure)
163155
}
164156

165-
// Progress callback throttling
166-
if opts.ProgressSink != nil && opts.ProgressInterval > 0 {
167-
if lastProgress.IsZero() || time.Since(lastProgress) >= opts.ProgressInterval {
168-
nextPoll := opts.Interval.NextPoll(attempt)
169-
opts.ProgressSink(anyFetchResult(fr), ProgressMeta{
170-
Attempt: attempt,
171-
Elapsed: time.Since(start),
172-
Remaining: maxDuration(0, time.Until(deadline)), // time.Until for clarity
173-
Deadline: deadline,
174-
NextPollIn: nextPoll,
175-
})
176-
lastProgress = time.Now()
177-
}
178-
}
157+
// Handle progress reporting
158+
handleProgressReport(opts, fr, start, deadline, attempt, &lastProgress)
179159

180-
// Sleep until next attempt
181-
sleep := opts.Interval.NextPoll(attempt)
182-
if sleep > 0 {
183-
timer := time.NewTimer(sleep)
184-
select {
185-
case <-ctx.Done():
186-
timer.Stop()
187-
return last, ctx.Err()
188-
case <-timer.C:
189-
}
160+
// Sleep until next attempt, with context cancellation check
161+
if err := sleepWithContext(ctx, opts.Interval.NextPoll(attempt)); err != nil {
162+
return last, err // Early return: context cancelled during sleep
190163
}
164+
191165
attempt++
192166
}
193167
}
@@ -197,14 +171,101 @@ func anyFetchResult[T any](fr FetchResult[T]) FetchResult[any] {
197171
return FetchResult[any]{Status: fr.Status, Value: any(fr.Value)}
198172
}
199173

200-
// contains tests membership in a slice of Status.
201-
func contains(haystack []Status, needle Status) bool {
202-
return slices.Contains(haystack, needle)
203-
}
204-
205174
func maxDuration(a, b time.Duration) time.Duration {
206175
if a > b {
207176
return a
208177
}
209178
return b
210179
}
180+
181+
// validateOptions performs early validation of required options.
182+
func validateOptions[T any](opts Options[T]) error {
183+
if opts.Timeout <= 0 {
184+
return errors.New("actionwait: Timeout must be > 0")
185+
}
186+
if len(opts.SuccessStates) == 0 {
187+
return errors.New("actionwait: at least one SuccessState required")
188+
}
189+
if opts.ConsecutiveSuccess < 0 {
190+
return errors.New("actionwait: ConsecutiveSuccess cannot be negative")
191+
}
192+
if opts.ProgressInterval < 0 {
193+
return errors.New("actionwait: ProgressInterval cannot be negative")
194+
}
195+
return nil
196+
}
197+
198+
// normalizeOptions sets defaults for optional configuration.
199+
func normalizeOptions[T any](opts *Options[T]) {
200+
if opts.ConsecutiveSuccess <= 0 {
201+
opts.ConsecutiveSuccess = 1
202+
}
203+
if opts.Interval == nil {
204+
opts.Interval = FixedInterval(DefaultPollInterval)
205+
}
206+
}
207+
208+
// classifyStatus determines the next action based on the current status.
209+
// Returns: (isTerminal, error) - if isTerminal is true, polling should stop.
210+
func classifyStatus[T any](fr FetchResult[T], opts Options[T], successStreak *int, allowedTransient []Status) (bool, error) {
211+
// Classification precedence: failure -> success -> transitional -> unexpected
212+
if slices.Contains(opts.FailureStates, fr.Status) {
213+
return true, &FailureStateError{Status: fr.Status}
214+
}
215+
216+
if slices.Contains(opts.SuccessStates, fr.Status) {
217+
*successStreak++
218+
if *successStreak >= opts.ConsecutiveSuccess {
219+
return true, nil // Success!
220+
}
221+
return false, nil // Continue polling for consecutive successes
222+
}
223+
224+
// Not a success state, reset streak
225+
*successStreak = 0
226+
227+
// Check if transitional state is allowed
228+
// If TransitionalStates is specified, status must be in that list
229+
// If TransitionalStates is empty, any non-success/non-failure state is allowed
230+
if len(opts.TransitionalStates) > 0 && !slices.Contains(opts.TransitionalStates, fr.Status) {
231+
return true, &UnexpectedStateError{Status: fr.Status, Allowed: allowedTransient}
232+
}
233+
234+
return false, nil // Continue polling
235+
}
236+
237+
// handleProgressReport sends progress updates if conditions are met.
238+
func handleProgressReport[T any](opts Options[T], fr FetchResult[T], start time.Time, deadline time.Time, attempt uint, lastProgress *time.Time) {
239+
if opts.ProgressSink == nil || opts.ProgressInterval <= 0 {
240+
return
241+
}
242+
243+
if lastProgress.IsZero() || time.Since(*lastProgress) >= opts.ProgressInterval {
244+
nextPoll := opts.Interval.NextPoll(attempt)
245+
opts.ProgressSink(anyFetchResult(fr), ProgressMeta{
246+
Attempt: attempt,
247+
Elapsed: time.Since(start),
248+
Remaining: maxDuration(0, time.Until(deadline)),
249+
Deadline: deadline,
250+
NextPollIn: nextPoll,
251+
})
252+
*lastProgress = time.Now()
253+
}
254+
}
255+
256+
// sleepWithContext sleeps for the specified duration while respecting context cancellation.
257+
func sleepWithContext(ctx context.Context, duration time.Duration) error {
258+
if duration <= 0 {
259+
return nil
260+
}
261+
262+
timer := time.NewTimer(duration)
263+
defer timer.Stop()
264+
265+
select {
266+
case <-ctx.Done():
267+
return ctx.Err()
268+
case <-timer.C:
269+
return nil
270+
}
271+
}

0 commit comments

Comments
 (0)