diff --git a/.teamcity/scripts/provider_tests/acceptance_tests.sh b/.teamcity/scripts/provider_tests/acceptance_tests.sh index f42d935c248d..539ee5a129f8 100644 --- a/.teamcity/scripts/provider_tests/acceptance_tests.sh +++ b/.teamcity/scripts/provider_tests/acceptance_tests.sh @@ -34,6 +34,7 @@ fi TF_ACC=1 go test \ ./internal/acctest/... \ + ./internal/actionwait/... \ ./internal/attrmap/... \ ./internal/backoff/... \ ./internal/conns/... \ diff --git a/.teamcity/scripts/provider_tests/unit_tests.sh b/.teamcity/scripts/provider_tests/unit_tests.sh index 2f0072dc45e0..c3bf8454caf7 100644 --- a/.teamcity/scripts/provider_tests/unit_tests.sh +++ b/.teamcity/scripts/provider_tests/unit_tests.sh @@ -6,6 +6,7 @@ set -euo pipefail go test \ ./internal/acctest/... \ + ./internal/actionwait/... \ ./internal/attrmap/... \ ./internal/backoff/... \ ./internal/conns/... \ diff --git a/internal/actionwait/errors.go b/internal/actionwait/errors.go new file mode 100644 index 000000000000..58e440763b88 --- /dev/null +++ b/internal/actionwait/errors.go @@ -0,0 +1,70 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package actionwait + +import ( + "errors" + "strings" + "time" +) + +// TimeoutError is returned when the operation does not reach a success state within Timeout. +type TimeoutError struct { + LastStatus Status + Timeout time.Duration +} + +func (e *TimeoutError) Error() string { + return "timeout waiting for target status after " + e.Timeout.String() +} + +// FailureStateError indicates the operation entered a declared failure state. +type FailureStateError struct { + Status Status +} + +func (e *FailureStateError) Error() string { + return "operation entered failure state: " + string(e.Status) +} + +// UnexpectedStateError indicates the operation entered a state outside success/transitional/failure sets. 
+type UnexpectedStateError struct { + Status Status + Allowed []Status +} + +func (e *UnexpectedStateError) Error() string { + if len(e.Allowed) == 0 { + return "operation entered unexpected state: " + string(e.Status) + } + allowedStr := make([]string, len(e.Allowed)) + for i, s := range e.Allowed { + allowedStr[i] = string(s) + } + return "operation entered unexpected state: " + string(e.Status) + " (allowed: " + + strings.Join(allowedStr, ", ") + ")" +} + +// Error type assertions for compile-time verification +var ( + _ error = (*TimeoutError)(nil) + _ error = (*FailureStateError)(nil) + _ error = (*UnexpectedStateError)(nil) +) + +// Helper functions for error type checking +func IsTimeout(err error) bool { + var timeoutErr *TimeoutError + return errors.As(err, &timeoutErr) +} + +func IsFailureState(err error) bool { + var failureErr *FailureStateError + return errors.As(err, &failureErr) +} + +func IsUnexpectedState(err error) bool { + var unexpectedErr *UnexpectedStateError + return errors.As(err, &unexpectedErr) +} diff --git a/internal/actionwait/errors_test.go b/internal/actionwait/errors_test.go new file mode 100644 index 000000000000..c7379df86521 --- /dev/null +++ b/internal/actionwait/errors_test.go @@ -0,0 +1,283 @@ +// Copyright (c) HashiCorp, Inc. 
+// SPDX-License-Identifier: MPL-2.0 + +package actionwait + +import ( + "errors" + "strings" + "testing" + "time" +) + +func TestTimeoutError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *TimeoutError + wantMsg string + wantType string + }{ + { + name: "with last status", + err: &TimeoutError{ + LastStatus: "CREATING", + Timeout: 5 * time.Minute, + }, + wantMsg: "timeout waiting for target status after 5m0s", + wantType: "*actionwait.TimeoutError", + }, + { + name: "with empty status", + err: &TimeoutError{ + LastStatus: "", + Timeout: 30 * time.Second, + }, + wantMsg: "timeout waiting for target status after 30s", + wantType: "*actionwait.TimeoutError", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.err.Error(); got != tt.wantMsg { + t.Errorf("TimeoutError.Error() = %q, want %q", got, tt.wantMsg) + } + + // Verify it implements error interface + var err error = tt.err + if got := err.Error(); got != tt.wantMsg { + t.Errorf("TimeoutError as error.Error() = %q, want %q", got, tt.wantMsg) + } + }) + } +} + +func TestFailureStateError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *FailureStateError + wantMsg string + }{ + { + name: "with status", + err: &FailureStateError{ + Status: "FAILED", + }, + wantMsg: "operation entered failure state: FAILED", + }, + { + name: "with empty status", + err: &FailureStateError{ + Status: "", + }, + wantMsg: "operation entered failure state: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.err.Error(); got != tt.wantMsg { + t.Errorf("FailureStateError.Error() = %q, want %q", got, tt.wantMsg) + } + }) + } +} + +func TestUnexpectedStateError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *UnexpectedStateError + wantMsg string + }{ + { + name: "no allowed states", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: 
nil, + }, + wantMsg: "operation entered unexpected state: UNKNOWN", + }, + { + name: "empty allowed states", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: []Status{}, + }, + wantMsg: "operation entered unexpected state: UNKNOWN", + }, + { + name: "single allowed state", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: []Status{"AVAILABLE"}, + }, + wantMsg: "operation entered unexpected state: UNKNOWN (allowed: AVAILABLE)", + }, + { + name: "multiple allowed states", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: []Status{"CREATING", "AVAILABLE", "UPDATING"}, + }, + wantMsg: "operation entered unexpected state: UNKNOWN (allowed: CREATING, AVAILABLE, UPDATING)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.err.Error(); got != tt.wantMsg { + t.Errorf("UnexpectedStateError.Error() = %q, want %q", got, tt.wantMsg) + } + }) + } +} + +func TestErrorTypeChecking(t *testing.T) { + t.Parallel() + + // Create instances of each error type + timeoutErr := &TimeoutError{LastStatus: "CREATING", Timeout: time.Minute} + failureErr := &FailureStateError{Status: "FAILED"} + unexpectedErr := &UnexpectedStateError{Status: "UNKNOWN", Allowed: []Status{"AVAILABLE"}} + genericErr := errors.New("generic error") + + tests := []struct { + name string + err error + wantIsTimeout bool + wantIsFailure bool + wantIsUnexpected bool + }{ + { + name: "TimeoutError", + err: timeoutErr, + wantIsTimeout: true, + wantIsFailure: false, + wantIsUnexpected: false, + }, + { + name: "FailureStateError", + err: failureErr, + wantIsTimeout: false, + wantIsFailure: true, + wantIsUnexpected: false, + }, + { + name: "UnexpectedStateError", + err: unexpectedErr, + wantIsTimeout: false, + wantIsFailure: false, + wantIsUnexpected: true, + }, + { + name: "generic error", + err: genericErr, + wantIsTimeout: false, + wantIsFailure: false, + wantIsUnexpected: false, + }, + { + name: "nil error", + err: nil, + 
wantIsTimeout: false, + wantIsFailure: false, + wantIsUnexpected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := IsTimeout(tt.err); got != tt.wantIsTimeout { + t.Errorf("IsTimeout(%v) = %v, want %v", tt.err, got, tt.wantIsTimeout) + } + + if got := IsFailureState(tt.err); got != tt.wantIsFailure { + t.Errorf("IsFailureState(%v) = %v, want %v", tt.err, got, tt.wantIsFailure) + } + + if got := IsUnexpectedState(tt.err); got != tt.wantIsUnexpected { + t.Errorf("IsUnexpectedState(%v) = %v, want %v", tt.err, got, tt.wantIsUnexpected) + } + }) + } +} + +func TestWrappedErrors(t *testing.T) { + t.Parallel() + + // Test that error type checking works with wrapped errors + baseErr := &TimeoutError{LastStatus: "CREATING", Timeout: time.Minute} + wrappedErr := errors.New("wrapped: " + baseErr.Error()) + + // Direct error should be detected + if !IsTimeout(baseErr) { + t.Errorf("IsTimeout should detect direct TimeoutError") + } + + // Wrapped string error should NOT be detected (this is expected behavior) + if IsTimeout(wrappedErr) { + t.Errorf("IsTimeout should not detect string-wrapped error") + } + + // But wrapped with errors.Join should work + joinedErr := errors.Join(baseErr, errors.New("additional context")) + if !IsTimeout(joinedErr) { + t.Errorf("IsTimeout should detect error in errors.Join") + } +} + +func TestErrorMessages(t *testing.T) { + t.Parallel() + + // Verify error messages contain expected components for debugging + timeoutErr := &TimeoutError{ + LastStatus: "PENDING", + Timeout: 2 * time.Minute, + } + + msg := timeoutErr.Error() + if !strings.Contains(msg, "timeout") { + t.Errorf("TimeoutError message should contain 'timeout', got: %q", msg) + } + if !strings.Contains(msg, "2m0s") { + t.Errorf("TimeoutError message should contain timeout duration, got: %q", msg) + } + + failureErr := &FailureStateError{Status: "ERROR"} + msg = failureErr.Error() + if !strings.Contains(msg, "failure state") 
{ + t.Errorf("FailureStateError message should contain 'failure state', got: %q", msg) + } + if !strings.Contains(msg, "ERROR") { + t.Errorf("FailureStateError message should contain status, got: %q", msg) + } + + unexpectedErr := &UnexpectedStateError{ + Status: "WEIRD", + Allowed: []Status{"GOOD", "BETTER"}, + } + msg = unexpectedErr.Error() + if !strings.Contains(msg, "unexpected state") { + t.Errorf("UnexpectedStateError message should contain 'unexpected state', got: %q", msg) + } + if !strings.Contains(msg, "WEIRD") { + t.Errorf("UnexpectedStateError message should contain actual status, got: %q", msg) + } + if !strings.Contains(msg, "GOOD, BETTER") { + t.Errorf("UnexpectedStateError message should contain allowed states, got: %q", msg) + } +} diff --git a/internal/actionwait/wait.go b/internal/actionwait/wait.go new file mode 100644 index 000000000000..1f095f7deef0 --- /dev/null +++ b/internal/actionwait/wait.go @@ -0,0 +1,254 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +// Package actionwait provides a lightweight, action-focused polling helper +// for imperative Terraform actions which need to await asynchronous AWS +// operation completion with periodic user progress events. +package actionwait + +import ( + "context" + "errors" + "slices" + "time" + + "github.com/hashicorp/terraform-provider-aws/internal/backoff" +) + +// DefaultPollInterval is the default fixed polling interval used when no custom IntervalStrategy is provided. +const DefaultPollInterval = 30 * time.Second + +// Status represents a string status value returned from a polled API. +type Status string + +// FetchResult wraps the latest status (and optional value) from a poll attempt. +// Value may be a richer SDK structure (pointer) or zero for simple cases. +type FetchResult[T any] struct { + Status Status + Value T +} + +// FetchFunc retrieves the latest state of an asynchronous operation. +// It should be side-effect free aside from the remote read. 
+type FetchFunc[T any] func(context.Context) (FetchResult[T], error) + +// IntervalStrategy allows pluggable poll interval behavior (fixed, backoff, etc.). +type IntervalStrategy interface { //nolint:interfacebloat // single method interface (tiny intentional interface) + NextPoll(attempt uint) time.Duration +} + +// FixedInterval implements IntervalStrategy with a constant delay. +type FixedInterval time.Duration + +// NextPoll returns the fixed duration. +func (fi FixedInterval) NextPoll(uint) time.Duration { return time.Duration(fi) } + +// BackoffInterval implements IntervalStrategy using a backoff.Delay strategy. +// This allows actionwait to leverage sophisticated backoff algorithms while +// maintaining the declarative status-based polling approach. +type BackoffInterval struct { + delay backoff.Delay +} + +// NextPoll returns the next polling interval using the wrapped backoff delay strategy. +func (bi BackoffInterval) NextPoll(attempt uint) time.Duration { + return bi.delay.Next(attempt) +} + +// WithBackoffDelay creates an IntervalStrategy that uses the provided backoff.Delay. +// This bridges actionwait's IntervalStrategy interface with the backoff package's +// delay strategies (fixed, exponential, SDK-compatible, etc.). +// +// Example usage: +// +// opts := actionwait.Options[MyType]{ +// Interval: actionwait.WithBackoffDelay(backoff.FixedDelay(time.Second)), +// // ... other options +// } +func WithBackoffDelay(delay backoff.Delay) IntervalStrategy { + return BackoffInterval{delay: delay} +} + +// Options configure the WaitForStatus loop. +type Options[T any] struct { + Timeout time.Duration // Required total timeout. + Interval IntervalStrategy // Poll interval strategy (default: 30s fixed). + ProgressInterval time.Duration // Throttle for ProgressSink (default: disabled if <=0). + SuccessStates []Status // Required (>=1) terminal success states. + TransitionalStates []Status // Optional allowed in-flight states. 
+ FailureStates []Status // Optional explicit failure states. + ConsecutiveSuccess int // Number of consecutive successes required (default 1). + ProgressSink func(fr FetchResult[any], meta ProgressMeta) +} + +// ProgressMeta supplies metadata for progress callbacks. +type ProgressMeta struct { + Attempt uint + Elapsed time.Duration + Remaining time.Duration + Deadline time.Time + NextPollIn time.Duration +} + +// WaitForStatus polls using fetch until a success state, failure state, timeout, unexpected state, +// context cancellation, or fetch error occurs. +// On success, the final FetchResult is returned with nil error. +func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[T]) (FetchResult[T], error) { //nolint:cyclop // complexity driven by classification/state machine; readability preferred + if err := validateOptions(opts); err != nil { + var zero FetchResult[T] + return zero, err + } + + normalizeOptions(&opts) + + start := time.Now() + deadline := start.Add(opts.Timeout) + var lastProgress time.Time + var attempt uint + var successStreak int + var last FetchResult[T] + + // Precompute the allowed states reported by UnexpectedStateError (success + transitional only). + // Failure states are excluded from Allowed to ensure they classify distinctly. + allowedTransient := append([]Status{}, opts.SuccessStates...) + allowedTransient = append(allowedTransient, opts.TransitionalStates...) 
+ + for { + // Early return: context cancelled + if ctx.Err() != nil { + return last, ctx.Err() + } + + // Early return: timeout exceeded + if time.Now().After(deadline) { + return last, &TimeoutError{LastStatus: last.Status, Timeout: opts.Timeout} + } + + // Fetch current status + fr, err := fetch(ctx) + if err != nil { + return fr, err // Early return: fetch error + } + last = fr + + // Classify status and determine if we should terminate + isTerminal, classifyErr := classifyStatus(fr, opts, &successStreak, allowedTransient) + if isTerminal { + return fr, classifyErr // Early return: terminal state (success or failure) + } + + // Handle progress reporting + handleProgressReport(opts, fr, start, deadline, attempt, &lastProgress) + + // Sleep until next attempt, with context cancellation check + if err := sleepWithContext(ctx, opts.Interval.NextPoll(attempt)); err != nil { + return last, err // Early return: context cancelled during sleep + } + + attempt++ + } +} + +// anyFetchResult converts a typed FetchResult[T] into FetchResult[any] for ProgressSink. +func anyFetchResult[T any](fr FetchResult[T]) FetchResult[any] { + return FetchResult[any]{Status: fr.Status, Value: any(fr.Value)} +} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} + +// validateOptions performs early validation of required options. +func validateOptions[T any](opts Options[T]) error { + if opts.Timeout <= 0 { + return errors.New("actionwait: Timeout must be > 0") + } + if len(opts.SuccessStates) == 0 { + return errors.New("actionwait: at least one SuccessState required") + } + if opts.ConsecutiveSuccess < 0 { + return errors.New("actionwait: ConsecutiveSuccess cannot be negative") + } + if opts.ProgressInterval < 0 { + return errors.New("actionwait: ProgressInterval cannot be negative") + } + return nil +} + +// normalizeOptions sets defaults for optional configuration. 
+func normalizeOptions[T any](opts *Options[T]) { + if opts.ConsecutiveSuccess <= 0 { + opts.ConsecutiveSuccess = 1 + } + if opts.Interval == nil { + opts.Interval = FixedInterval(DefaultPollInterval) + } +} + +// classifyStatus determines the next action based on the current status. +// Returns: (isTerminal, error) - if isTerminal is true, polling should stop. +func classifyStatus[T any](fr FetchResult[T], opts Options[T], successStreak *int, allowedTransient []Status) (bool, error) { + // Classification precedence: failure -> success -> transitional -> unexpected + if slices.Contains(opts.FailureStates, fr.Status) { + return true, &FailureStateError{Status: fr.Status} + } + + if slices.Contains(opts.SuccessStates, fr.Status) { + *successStreak++ + if *successStreak >= opts.ConsecutiveSuccess { + return true, nil // Success! + } + return false, nil // Continue polling for consecutive successes + } + + // Not a success state, reset streak + *successStreak = 0 + + // Check if transitional state is allowed + // If TransitionalStates is specified, status must be in that list + // If TransitionalStates is empty, any non-success/non-failure state is allowed + if len(opts.TransitionalStates) > 0 && !slices.Contains(opts.TransitionalStates, fr.Status) { + return true, &UnexpectedStateError{Status: fr.Status, Allowed: allowedTransient} + } + + return false, nil // Continue polling +} + +// handleProgressReport sends progress updates if conditions are met. 
+func handleProgressReport[T any](opts Options[T], fr FetchResult[T], start time.Time, deadline time.Time, attempt uint, lastProgress *time.Time) { + if opts.ProgressSink == nil || opts.ProgressInterval <= 0 { + return + } + + if lastProgress.IsZero() || time.Since(*lastProgress) >= opts.ProgressInterval { + nextPoll := opts.Interval.NextPoll(attempt) + opts.ProgressSink(anyFetchResult(fr), ProgressMeta{ + Attempt: attempt, + Elapsed: time.Since(start), + Remaining: maxDuration(0, time.Until(deadline)), + Deadline: deadline, + NextPollIn: nextPoll, + }) + *lastProgress = time.Now() + } +} + +// sleepWithContext sleeps for the specified duration while respecting context cancellation. +func sleepWithContext(ctx context.Context, duration time.Duration) error { + if duration <= 0 { + return nil + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go new file mode 100644 index 000000000000..22096adb77f3 --- /dev/null +++ b/internal/actionwait/wait_test.go @@ -0,0 +1,423 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package actionwait + +import ( + "context" + "errors" + "strings" + "sync/atomic" + "testing" + "time" + + "github.com/hashicorp/terraform-provider-aws/internal/backoff" +) + +// fastFixedInterval returns a very small fixed interval to speed tests. +const fastFixedInterval = 5 * time.Millisecond + +// makeCtx creates a context with generous overall test timeout safeguard. +func makeCtx(t *testing.T) context.Context { // test helper + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + return ctx +} + +func TestWaitForStatus_ValidationErrors(t *testing.T) { + t.Parallel() + // Subtests parallelized; each uses its own context with timeout. 
+ cases := map[string]Options[struct{}]{ + "missing timeout": {SuccessStates: []Status{"ok"}}, + "missing success": {Timeout: time.Second}, + "negative consecutive": {Timeout: time.Second, SuccessStates: []Status{"ok"}, ConsecutiveSuccess: -1}, + "negative progress interval": {Timeout: time.Second, SuccessStates: []Status{"ok"}, ProgressInterval: -time.Second}, + } + + for name, opts := range cases { + t.Run(name, func(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "irrelevant"}, nil + }, opts) + if err == nil { + t.Fatalf("expected validation error") + } + }) + } +} + +func TestWaitForStatus_SuccessImmediate(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "DONE", Value: 42}, nil + }, Options[int]{ + Timeout: 250 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Value != 42 || fr.Status != "DONE" { + t.Fatalf("unexpected result: %#v", fr) + } +} + +func TestWaitForStatus_SuccessAfterTransitions(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var calls int32 + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[string], error) { + c := atomic.AddInt32(&calls, 1) + switch c { + case 1, 2: + return FetchResult[string]{Status: "IN_PROGRESS", Value: "step"}, nil + default: + return FetchResult[string]{Status: "COMPLETE", Value: "done"}, nil + } + }, Options[string]{ + Timeout: 500 * time.Millisecond, + SuccessStates: []Status{"COMPLETE"}, + TransitionalStates: []Status{"IN_PROGRESS"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Status != "COMPLETE" || fr.Value != "done" { + t.Fatalf("unexpected final result: 
%#v", fr) + } +} + +func TestWaitForStatus_FailureState(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "FAILED"}, nil + }, Options[struct{}]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"SUCCEEDED"}, + FailureStates: []Status{"FAILED"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err == nil { + t.Fatal("expected failure error") + } + if _, ok := err.(*FailureStateError); !ok { //nolint:errorlint // direct type assertion adequate in tests + t.Fatalf("expected FailureStateError, got %T", err) + } + if fr.Status != "FAILED" { + t.Fatalf("unexpected status: %v", fr.Status) + } +} + +func TestWaitForStatus_UnexpectedState_WithTransitional(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "UNKNOWN"}, nil + }, Options[int]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"OK"}, + TransitionalStates: []Status{"PENDING"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err == nil { + t.Fatal("expected unexpected state error") + } + if _, ok := err.(*UnexpectedStateError); !ok { //nolint:errorlint // direct type assertion adequate in tests + t.Fatalf("expected UnexpectedStateError, got %T", err) + } +} + +func TestWaitForStatus_NoTransitionalListAllowsAnyUntilTimeout(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + start := time.Now() + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "WHATEVER"}, nil + }, Options[struct{}]{ + Timeout: 50 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(10 * time.Millisecond), + }) + if err == nil { + t.Fatal("expected timeout error") + } + if _, ok := err.(*TimeoutError); !ok { //nolint:errorlint // direct type assertion adequate in tests + 
t.Fatalf("expected TimeoutError, got %T", err) + } + if time.Since(start) < 40*time.Millisecond { // sanity that we actually waited + t.Fatalf("timeout returned too early") + } +} + +func TestWaitForStatus_ContextCancel(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(makeCtx(t)) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "PENDING"}, nil + }, Options[struct{}]{ + Timeout: 500 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(fastFixedInterval), + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestWaitForStatus_FetchErrorPropagation(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + testErr := errors.New("boom") + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{}, testErr + }, Options[int]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"OK"}, + Interval: FixedInterval(fastFixedInterval), + }) + if !errors.Is(err, testErr) { + t.Fatalf("expected fetch error, got %v", err) + } +} + +func TestWaitForStatus_ConsecutiveSuccess(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var toggle int32 + // return one transitional status, then success on every later poll, until two consecutive successes are observed + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[string], error) { + n := atomic.AddInt32(&toggle, 1) + // Pattern: BUILDING, READY, READY, READY ... 
ensures at least two consecutive successes by third attempt + if n == 1 { + return FetchResult[string]{Status: "BUILDING", Value: "val"}, nil + } + return FetchResult[string]{Status: "READY", Value: "val"}, nil + }, Options[string]{ + Timeout: 750 * time.Millisecond, + SuccessStates: []Status{"READY"}, + TransitionalStates: []Status{"BUILDING"}, + ConsecutiveSuccess: 2, + Interval: FixedInterval(2 * time.Millisecond), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Status != "READY" { + t.Fatalf("expected READY, got %v", fr.Status) + } + if atomic.LoadInt32(&toggle) < 3 { // at least three fetches required (BUILDING, READY, READY) + t.Fatalf("expected multiple attempts, got %d", toggle) + } +} + +func TestWaitForStatus_ProgressSinkThrottling(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var progressCalls int32 + var fetchCalls int32 + _, _ = WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + atomic.AddInt32(&fetchCalls, 1) + if fetchCalls >= 5 { + return FetchResult[int]{Status: "DONE"}, nil + } + return FetchResult[int]{Status: "WORKING"}, nil + }, Options[int]{ + Timeout: 500 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + TransitionalStates: []Status{"WORKING"}, + Interval: FixedInterval(5 * time.Millisecond), + ProgressInterval: 15 * time.Millisecond, // should group roughly 3 polls + ProgressSink: func(fr FetchResult[any], meta ProgressMeta) { + atomic.AddInt32(&progressCalls, 1) + if fr.Status != "WORKING" && fr.Status != "DONE" { + t.Fatalf("unexpected status in progress sink: %v", fr.Status) + } + if meta.NextPollIn <= 0 { + t.Fatalf("expected positive NextPollIn") + } + }, + }) + // With 5 fetch calls and 15ms progress vs 5ms poll, expect fewer progress events than fetches + if progressCalls <= 1 || progressCalls >= fetchCalls { + t.Fatalf("unexpected progress call count: %d (fetches %d)", progressCalls, fetchCalls) + } +} + +func TestWaitForStatus_ConsecutiveSuccessDefault(t *testing.T) { + 
t.Parallel() + ctx := makeCtx(t) + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "READY"}, nil + }, Options[struct{}]{ + Timeout: 100 * time.Millisecond, + SuccessStates: []Status{"READY"}, + Interval: FixedInterval(fastFixedInterval), + // ConsecutiveSuccess left zero to trigger defaulting logic + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Status != "READY" { + t.Fatalf("unexpected status: %v", fr.Status) + } +} + +func TestWaitForStatus_ProgressSinkDisabled(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var progressCalls int32 + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "DONE"}, nil + }, Options[int]{ + Timeout: 100 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(fastFixedInterval), + ProgressInterval: 0, // disabled + ProgressSink: func(FetchResult[any], ProgressMeta) { + atomic.AddInt32(&progressCalls, 1) + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if progressCalls != 0 { // should not be invoked when ProgressInterval <= 0 + t.Fatalf("expected zero progress sink calls, got %d", progressCalls) + } +} + +func TestWaitForStatus_UnexpectedStateErrorMessage(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "UNKNOWN"}, nil + }, Options[int]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"OK"}, + TransitionalStates: []Status{"PENDING", "IN_PROGRESS"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err == nil { + t.Fatal("expected unexpected state error") + } + var unexpectedErr *UnexpectedStateError + if !errors.As(err, &unexpectedErr) { + t.Fatalf("expected UnexpectedStateError, got %T", err) + } + errMsg := unexpectedErr.Error() + if !strings.Contains(errMsg, "UNKNOWN") { + 
t.Errorf("error message should contain status 'UNKNOWN', got: %s", errMsg) + } + if !strings.Contains(errMsg, "allowed:") { + t.Errorf("error message should list allowed states, got: %s", errMsg) + } + if !strings.Contains(errMsg, "PENDING") { + t.Errorf("error message should contain allowed state 'PENDING', got: %s", errMsg) + } +} + +func TestBackoffInterval(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + delay backoff.Delay + attempts []uint + expectedDurations []time.Duration + }{ + { + name: "fixed delay", + delay: backoff.FixedDelay(100 * time.Millisecond), + attempts: []uint{0, 1, 2, 3}, + expectedDurations: []time.Duration{0, 100 * time.Millisecond, 100 * time.Millisecond, 100 * time.Millisecond}, + }, + { + name: "zero delay", + delay: backoff.ZeroDelay, + attempts: []uint{0, 1, 2}, + expectedDurations: []time.Duration{0, 0, 0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + interval := BackoffInterval{delay: tt.delay} + + for i, attempt := range tt.attempts { + got := interval.NextPoll(attempt) + want := tt.expectedDurations[i] + if got != want { + t.Errorf("NextPoll(%d) = %v, want %v", attempt, got, want) + } + } + }) + } +} + +func TestWithBackoffDelay(t *testing.T) { + t.Parallel() + + delay := backoff.FixedDelay(50 * time.Millisecond) + interval := WithBackoffDelay(delay) + + // Test that it wraps the delay correctly + if got := interval.NextPoll(0); got != 0 { + t.Errorf("NextPoll(0) = %v, want 0", got) + } + if got := interval.NextPoll(1); got != 50*time.Millisecond { + t.Errorf("NextPoll(1) = %v, want 50ms", got) + } +} + +func TestBackoffIntegration(t *testing.T) { + t.Parallel() + + ctx := makeCtx(t) + + var callCount atomic.Int32 + fetch := func(context.Context) (FetchResult[string], error) { + count := callCount.Add(1) + switch count { + case 1: + return FetchResult[string]{Status: "CREATING", Value: "attempt1"}, nil + case 2: + return FetchResult[string]{Status: 
"AVAILABLE", Value: "success"}, nil + default: + t.Errorf("unexpected call count: %d", count) + return FetchResult[string]{}, errors.New("too many calls") + } + } + + opts := Options[string]{ + Timeout: 2 * time.Second, + Interval: WithBackoffDelay(backoff.FixedDelay(fastFixedInterval)), + SuccessStates: []Status{"AVAILABLE"}, + TransitionalStates: []Status{"CREATING"}, + } + + result, err := WaitForStatus(ctx, fetch, opts) + if err != nil { + t.Fatalf("WaitForStatus() error = %v", err) + } + + if result.Status != "AVAILABLE" { + t.Errorf("result.Status = %q, want %q", result.Status, "AVAILABLE") + } + if result.Value != "success" { + t.Errorf("result.Value = %q, want %q", result.Value, "success") + } + if callCount.Load() != 2 { + t.Errorf("expected 2 fetch calls, got %d", callCount.Load()) + } +} diff --git a/internal/service/cloudfront/create_invalidation_action.go b/internal/service/cloudfront/create_invalidation_action.go index 0cd271486ef5..1f8275ce5dd1 100644 --- a/internal/service/cloudfront/create_invalidation_action.go +++ b/internal/service/cloudfront/create_invalidation_action.go @@ -5,8 +5,8 @@ package cloudfront import ( "context" + "errors" "fmt" - "slices" "time" "github.com/YakDriver/regexache" @@ -23,6 +23,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/id" + "github.com/hashicorp/terraform-provider-aws/internal/actionwait" "github.com/hashicorp/terraform-provider-aws/internal/framework" fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" "github.com/hashicorp/terraform-provider-aws/names" @@ -216,13 +217,51 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke Message: fmt.Sprintf("Invalidation %s created, waiting for completion...", invalidationID), }) - // Wait for invalidation to complete with periodic progress updates - err = a.waitForInvalidationComplete(ctx, 
conn, distributionID, invalidationID, timeout, resp) + // Wait for invalidation to complete with periodic progress updates using actionwait + // Use fixed interval since CloudFront invalidations have predictable timing and + // don't benefit from exponential backoff - status changes are infrequent and consistent + _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { + input := cloudfront.GetInvalidationInput{ + DistributionId: aws.String(distributionID), + Id: aws.String(invalidationID), + } + output, gerr := conn.GetInvalidation(ctx, &input) + if gerr != nil { + return actionwait.FetchResult[struct{}]{}, fmt.Errorf("getting invalidation status: %w", gerr) + } + status := aws.ToString(output.Invalidation.Status) + return actionwait.FetchResult[struct{}]{Status: actionwait.Status(status)}, nil + }, actionwait.Options[struct{}]{ + Timeout: timeout, + Interval: actionwait.FixedInterval(actionwait.DefaultPollInterval), + ProgressInterval: 60 * time.Second, + SuccessStates: []actionwait.Status{"Completed"}, + TransitionalStates: []actionwait.Status{ + "InProgress", + }, + ProgressSink: func(fr actionwait.FetchResult[any], meta actionwait.ProgressMeta) { + resp.SendProgress(action.InvokeProgressEvent{Message: fmt.Sprintf("Invalidation %s is currently '%s', continuing to wait for completion...", invalidationID, fr.Status)}) + }, + }) if err != nil { - resp.Diagnostics.AddError( - "Timeout Waiting for Invalidation to Complete", - fmt.Sprintf("CloudFront invalidation %s did not complete within %s: %s", invalidationID, timeout, err), - ) + var timeoutErr *actionwait.TimeoutError + var unexpectedErr *actionwait.UnexpectedStateError + if errors.As(err, &timeoutErr) { + resp.Diagnostics.AddError( + "Timeout Waiting for Invalidation to Complete", + fmt.Sprintf("CloudFront invalidation %s did not complete within %s: %s", invalidationID, timeout, err), + ) + } else if errors.As(err, &unexpectedErr) { + 
resp.Diagnostics.AddError( + "Invalid Invalidation State", + fmt.Sprintf("CloudFront invalidation %s entered unexpected state: %s", invalidationID, err), + ) + } else { + resp.Diagnostics.AddError( + "Failed While Waiting for Invalidation", + fmt.Sprintf("Error waiting for CloudFront invalidation %s: %s", invalidationID, err), + ) + } return } @@ -237,64 +276,3 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke "paths": paths, }) } - -// waitForInvalidationComplete waits for an invalidation to complete with progress updates -func (a *createInvalidationAction) waitForInvalidationComplete(ctx context.Context, conn *cloudfront.Client, distributionID, invalidationID string, timeout time.Duration, resp *action.InvokeResponse) error { - const ( - pollInterval = 30 * time.Second - progressInterval = 60 * time.Second - ) - - deadline := time.Now().Add(timeout) - lastProgressUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - // Check if we've exceeded the timeout - if time.Now().After(deadline) { - return fmt.Errorf("timeout after %s", timeout) - } - - // Get current invalidation status - input := &cloudfront.GetInvalidationInput{ - DistributionId: aws.String(distributionID), - Id: aws.String(invalidationID), - } - - output, err := conn.GetInvalidation(ctx, input) - if err != nil { - return fmt.Errorf("getting invalidation status: %w", err) - } - - currentStatus := aws.ToString(output.Invalidation.Status) - - // Send progress update every 60 seconds - if time.Since(lastProgressUpdate) >= progressInterval { - resp.SendProgress(action.InvokeProgressEvent{ - Message: fmt.Sprintf("Invalidation %s is currently '%s', continuing to wait for completion...", invalidationID, currentStatus), - }) - lastProgressUpdate = time.Now() - } - - // Check if we've reached completion - if aws.ToString(output.Invalidation.Status) == "Completed" { - return nil - } - - // Check if we're in an unexpected state - 
validStatuses := []string{ - "InProgress", - } - if !slices.Contains(validStatuses, currentStatus) && currentStatus != "Completed" { - return fmt.Errorf("invalidation entered unexpected status: %s", currentStatus) - } - - // Wait before next poll - time.Sleep(pollInterval) - } -} diff --git a/internal/service/codebuild/start_build_action.go b/internal/service/codebuild/start_build_action.go index 9d823cb64ef4..a67229b96030 100644 --- a/internal/service/codebuild/start_build_action.go +++ b/internal/service/codebuild/start_build_action.go @@ -5,6 +5,8 @@ package codebuild import ( "context" + "errors" + "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -14,6 +16,8 @@ import ( "github.com/hashicorp/terraform-plugin-framework/action/schema" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" + "github.com/hashicorp/terraform-provider-aws/internal/actionwait" + "github.com/hashicorp/terraform-provider-aws/internal/backoff" "github.com/hashicorp/terraform-provider-aws/internal/framework" fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex" fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" @@ -131,65 +135,53 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, Message: "Build started, waiting for completion...", }) - // Poll for build completion - deadline := time.Now().Add(timeout) - pollInterval := 30 * time.Second - progressInterval := 2 * time.Minute - lastProgressUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - resp.Diagnostics.AddError("Build monitoring cancelled", "Context was cancelled") - return - default: + // Poll for build completion using actionwait with backoff strategy + // Use backoff since builds can take a long time and status changes less frequently + // as the build progresses - start with frequent polling then back off + _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) 
(actionwait.FetchResult[*awstypes.Build], error) { + input := codebuild.BatchGetBuildsInput{Ids: []string{buildID}} + batch, berr := conn.BatchGetBuilds(ctx, &input) + if berr != nil { + return actionwait.FetchResult[*awstypes.Build]{}, berr } - - if time.Now().After(deadline) { - resp.Diagnostics.AddError("Build timeout", "Build did not complete within the specified timeout") - return - } - - input := codebuild.BatchGetBuildsInput{ - Ids: []string{buildID}, - } - batchGetBuildsOutput, err := conn.BatchGetBuilds(ctx, &input) - if err != nil { - resp.Diagnostics.AddError("Getting build status", err.Error()) - return - } - - if len(batchGetBuildsOutput.Builds) == 0 { - resp.Diagnostics.AddError("Build not found", "Build was not found in BatchGetBuilds response") - return - } - - build := batchGetBuildsOutput.Builds[0] - status := build.BuildStatus - - if time.Since(lastProgressUpdate) >= progressInterval { - resp.SendProgress(action.InvokeProgressEvent{ - Message: "Build currently in state: " + string(status), - }) - lastProgressUpdate = time.Now() + if len(batch.Builds) == 0 { + return actionwait.FetchResult[*awstypes.Build]{}, fmt.Errorf("build not found in BatchGetBuilds response") } - - switch status { - case awstypes.StatusTypeSucceeded: - resp.SendProgress(action.InvokeProgressEvent{ - Message: "Build completed successfully", - }) - return - case awstypes.StatusTypeFailed, awstypes.StatusTypeFault, awstypes.StatusTypeStopped, awstypes.StatusTypeTimedOut: - resp.Diagnostics.AddError("Build failed", "Build completed with status: "+string(status)) - return - case awstypes.StatusTypeInProgress: - // Continue polling - default: - resp.Diagnostics.AddError("Unexpected build status", "Received unexpected build status: "+string(status)) - return + b := batch.Builds[0] + return actionwait.FetchResult[*awstypes.Build]{Status: actionwait.Status(b.BuildStatus), Value: &b}, nil + }, actionwait.Options[*awstypes.Build]{ + Timeout: timeout, + Interval: 
actionwait.WithBackoffDelay(backoff.DefaultSDKv2HelperRetryCompatibleDelay()), + ProgressInterval: 2 * time.Minute, + SuccessStates: []actionwait.Status{actionwait.Status(awstypes.StatusTypeSucceeded)}, + TransitionalStates: []actionwait.Status{ + actionwait.Status(awstypes.StatusTypeInProgress), + }, + FailureStates: []actionwait.Status{ + actionwait.Status(awstypes.StatusTypeFailed), + actionwait.Status(awstypes.StatusTypeFault), + actionwait.Status(awstypes.StatusTypeStopped), + actionwait.Status(awstypes.StatusTypeTimedOut), + }, + ProgressSink: func(fr actionwait.FetchResult[any], meta actionwait.ProgressMeta) { + resp.SendProgress(action.InvokeProgressEvent{Message: "Build currently in state: " + string(fr.Status)}) + }, + }) + if err != nil { + var timeoutErr *actionwait.TimeoutError + var failureErr *actionwait.FailureStateError + var unexpectedErr *actionwait.UnexpectedStateError + if errors.As(err, &timeoutErr) { + resp.Diagnostics.AddError("Build timeout", "Build did not complete within the specified timeout") + } else if errors.As(err, &failureErr) { + resp.Diagnostics.AddError("Build failed", "Build completed with status: "+err.Error()) + } else if errors.As(err, &unexpectedErr) { + resp.Diagnostics.AddError("Unexpected build status", err.Error()) + } else { + resp.Diagnostics.AddError("Error waiting for build", err.Error()) } - - time.Sleep(pollInterval) + return } + + resp.SendProgress(action.InvokeProgressEvent{Message: "Build completed successfully"}) } diff --git a/internal/service/ec2/ec2_stop_instance_action.go b/internal/service/ec2/ec2_stop_instance_action.go index a9e9a84146a8..ae234daac1ac 100644 --- a/internal/service/ec2/ec2_stop_instance_action.go +++ b/internal/service/ec2/ec2_stop_instance_action.go @@ -5,8 +5,8 @@ package ec2 import ( "context" + "errors" "fmt" - "slices" "time" "github.com/YakDriver/regexache" @@ -21,10 +21,14 @@ import ( "github.com/hashicorp/terraform-plugin-framework/schema/validator" 
"github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" + "github.com/hashicorp/terraform-provider-aws/internal/actionwait" "github.com/hashicorp/terraform-provider-aws/internal/framework" "github.com/hashicorp/terraform-provider-aws/names" ) +// stopInstancePollInterval defines polling cadence for stop instance action. +const stopInstancePollInterval = 10 * time.Second + // @Action(aws_ec2_stop_instance, name="Stop Instance") func newStopInstanceAction(_ context.Context) (action.ActionWithConfigure, error) { return &stopInstanceAction{}, nil @@ -180,13 +184,49 @@ func (a *stopInstanceAction) Invoke(ctx context.Context, req action.InvokeReques }) } - // Wait for instance to stop with periodic progress updates - err = a.waitForInstanceStopped(ctx, conn, instanceID, timeout, resp) + // Wait for instance to stop with periodic progress updates using actionwait + // Use fixed interval since EC2 instance state transitions are predictable and + // relatively quick - consistent polling every 10s is optimal for this operation + _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { + instance, derr := findInstanceByID(ctx, conn, instanceID) + if derr != nil { + return actionwait.FetchResult[struct{}]{}, fmt.Errorf("describing instance: %w", derr) + } + state := string(instance.State.Name) + return actionwait.FetchResult[struct{}]{Status: actionwait.Status(state)}, nil + }, actionwait.Options[struct{}]{ + Timeout: timeout, + Interval: actionwait.FixedInterval(stopInstancePollInterval), + ProgressInterval: 30 * time.Second, + SuccessStates: []actionwait.Status{actionwait.Status(awstypes.InstanceStateNameStopped)}, + TransitionalStates: []actionwait.Status{ + actionwait.Status(awstypes.InstanceStateNameRunning), + actionwait.Status(awstypes.InstanceStateNameStopping), + actionwait.Status(awstypes.InstanceStateNameShuttingDown), + }, + ProgressSink: func(fr 
actionwait.FetchResult[any], meta actionwait.ProgressMeta) { + resp.SendProgress(action.InvokeProgressEvent{Message: fmt.Sprintf("EC2 instance %s is currently in state '%s', continuing to wait for 'stopped'...", instanceID, fr.Status)}) + }, + }) if err != nil { - resp.Diagnostics.AddError( - "Timeout Waiting for Instance to Stop", - fmt.Sprintf("EC2 instance %s did not stop within %s: %s", instanceID, timeout, err), - ) + var timeoutErr *actionwait.TimeoutError + var unexpectedErr *actionwait.UnexpectedStateError + if errors.As(err, &timeoutErr) { + resp.Diagnostics.AddError( + "Timeout Waiting for Instance to Stop", + fmt.Sprintf("EC2 instance %s did not stop within %s: %s", instanceID, timeout, err), + ) + } else if errors.As(err, &unexpectedErr) { + resp.Diagnostics.AddError( + "Unexpected Instance State", + fmt.Sprintf("EC2 instance %s entered unexpected state while stopping: %s", instanceID, err), + ) + } else { + resp.Diagnostics.AddError( + "Error Waiting for Instance to Stop", + fmt.Sprintf("Error while waiting for EC2 instance %s to stop: %s", instanceID, err), + ) + } return } @@ -209,61 +249,3 @@ func canStopInstance(state awstypes.InstanceStateName) bool { return false } } - -// waitForInstanceStopped waits for an instance to reach the stopped state with progress updates -func (a *stopInstanceAction) waitForInstanceStopped(ctx context.Context, conn *ec2.Client, instanceID string, timeout time.Duration, resp *action.InvokeResponse) error { - const ( - pollInterval = 10 * time.Second - progressInterval = 30 * time.Second - ) - - deadline := time.Now().Add(timeout) - lastProgressUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - // Check if we've exceeded the timeout - if time.Now().After(deadline) { - return fmt.Errorf("timeout after %s", timeout) - } - - // Get current instance state - instance, err := findInstanceByID(ctx, conn, instanceID) - if err != nil { - return fmt.Errorf("describing instance: 
%w", err) - } - - currentState := string(instance.State.Name) - - // Send progress update every 30 seconds - if time.Since(lastProgressUpdate) >= progressInterval { - resp.SendProgress(action.InvokeProgressEvent{ - Message: fmt.Sprintf("EC2 instance %s is currently in state '%s', continuing to wait for 'stopped'...", instanceID, currentState), - }) - lastProgressUpdate = time.Now() - } - - // Check if we've reached the target state - if instance.State.Name == awstypes.InstanceStateNameStopped { - return nil - } - - // Check if we're in an unexpected state - validStates := []awstypes.InstanceStateName{ - awstypes.InstanceStateNameRunning, - awstypes.InstanceStateNameStopping, - awstypes.InstanceStateNameShuttingDown, - } - if !slices.Contains(validStates, instance.State.Name) { - return fmt.Errorf("instance entered unexpected state: %s", currentState) - } - - // Wait before next poll - time.Sleep(pollInterval) - } -}