From 5035e95ad381c5722e9549814ce386ba7754e30e Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:00:08 -0700 Subject: [PATCH 01/17] actionwait: add polling library for actions --- internal/actionwait/wait.go | 207 ++++++++++++++++++++++ internal/actionwait/wait_test.go | 291 +++++++++++++++++++++++++++++++ 2 files changed, 498 insertions(+) create mode 100644 internal/actionwait/wait.go create mode 100644 internal/actionwait/wait_test.go diff --git a/internal/actionwait/wait.go b/internal/actionwait/wait.go new file mode 100644 index 000000000000..23cbaeeb8cb9 --- /dev/null +++ b/internal/actionwait/wait.go @@ -0,0 +1,207 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +// Package actionwait provides a lightweight, action-focused polling helper +// for imperative Terraform actions which need to await asynchronous AWS +// operation completion with periodic user progress events. +package actionwait + +import ( + "context" + "errors" + "slices" + "time" +) + +// Status represents a string status value returned from a polled API. +type Status string + +// FetchResult wraps the latest status (and optional value) from a poll attempt. +// Value may be a richer SDK structure (pointer) or zero for simple cases. +type FetchResult[T any] struct { + Status Status + Value T +} + +// FetchFunc retrieves the latest state of an asynchronous operation. +// It should be side-effect free aside from the remote read. +type FetchFunc[T any] func(context.Context) (FetchResult[T], error) + +// IntervalStrategy allows pluggable poll interval behavior (fixed, backoff, etc.). +type IntervalStrategy interface { // nolint:interfacebloat // single method interface + NextPoll(attempt uint) time.Duration +} + +// FixedInterval implements IntervalStrategy with a constant delay. +type FixedInterval time.Duration + +// NextPoll returns the fixed duration. +func (fi FixedInterval) NextPoll(uint) time.Duration { return time.Duration(fi) } + +// Options configure the WaitForStatus loop. +type Options[T any] struct { + Timeout time.Duration // Required total timeout. + Interval IntervalStrategy // Poll interval strategy (default: 30s fixed). + ProgressInterval time.Duration // Throttle for ProgressSink (default: disabled if <=0). + SuccessStates []Status // Required (>=1) terminal success states. + TransitionalStates []Status // Optional allowed in-flight states. + FailureStates []Status // Optional explicit failure states. + ConsecutiveSuccess int // Number of consecutive successes required (default 1). + ProgressSink func(fr FetchResult[any], meta ProgressMeta) +} + +// ProgressMeta supplies metadata for progress callbacks. +type ProgressMeta struct { + Attempt uint + Elapsed time.Duration + Remaining time.Duration + Deadline time.Time + NextPollIn time.Duration +} + +// ErrTimeout is returned when the operation does not reach a success state within Timeout. +type ErrTimeout struct { + LastStatus Status + Timeout time.Duration +} + +func (e *ErrTimeout) Error() string { + return "timeout waiting for target status after " + e.Timeout.String() +} + +// ErrFailureState indicates the operation entered a declared failure state. +type ErrFailureState struct { + Status Status +} + +func (e *ErrFailureState) Error() string { + return "operation entered failure state: " + string(e.Status) +} + +// ErrUnexpectedState indicates the operation entered a state outside success/transitional/failure sets. +type ErrUnexpectedState struct { + Status Status + Allowed []Status +} + +func (e *ErrUnexpectedState) Error() string { + return "operation entered unexpected state: " + string(e.Status) +} + +// sentinel errors helpers +var ( + _ error = (*ErrTimeout)(nil) + _ error = (*ErrFailureState)(nil) + _ error = (*ErrUnexpectedState)(nil) +) + +// WaitForStatus polls using fetch until a success state, failure state, timeout, unexpected state, +// context cancellation, or fetch error occurs. +// On success, the final FetchResult is returned with nil error. +func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[T]) (FetchResult[T], error) { // nolint:cyclop + var zero FetchResult[T] + + if opts.Timeout <= 0 { + return zero, errors.New("actionwait: Timeout must be > 0") + } + if len(opts.SuccessStates) == 0 { + return zero, errors.New("actionwait: at least one SuccessState required") + } + if opts.ConsecutiveSuccess <= 0 { + opts.ConsecutiveSuccess = 1 + } + if opts.Interval == nil { + opts.Interval = FixedInterval(30 * time.Second) + } + + start := time.Now() + deadline := start.Add(opts.Timeout) + var lastProgress time.Time + var attempt uint + var successStreak int + var last FetchResult[T] + + // Precompute allowed states for unexpected classification (success + transitional + failure) + // Failure states are excluded from Allowed to ensure they classify distinctly. + allowedTransient := append([]Status{}, opts.SuccessStates...) + allowedTransient = append(allowedTransient, opts.TransitionalStates...) + + for { + if ctx.Err() != nil { + return last, ctx.Err() + } + now := time.Now() + if now.After(deadline) { + return last, &ErrTimeout{LastStatus: last.Status, Timeout: opts.Timeout} + } + + fr, err := fetch(ctx) + if err != nil { + return fr, err + } + last = fr + + // Classification precedence: failure -> success -> transitional -> unexpected + if contains(opts.FailureStates, fr.Status) { + return fr, &ErrFailureState{Status: fr.Status} + } + if contains(opts.SuccessStates, fr.Status) { + successStreak++ + if successStreak >= opts.ConsecutiveSuccess { + return fr, nil + } + } else { + successStreak = 0 + if len(opts.TransitionalStates) > 0 { + if !contains(opts.TransitionalStates, fr.Status) { + return fr, &ErrUnexpectedState{Status: fr.Status, Allowed: allowedTransient} + } + } + } + + // Progress callback throttling + if opts.ProgressSink != nil && opts.ProgressInterval > 0 { + if lastProgress.IsZero() || time.Since(lastProgress) >= opts.ProgressInterval { + nextPoll := opts.Interval.NextPoll(attempt) + opts.ProgressSink(anyFetchResult(fr), ProgressMeta{ + Attempt: attempt, + Elapsed: time.Since(start), + Remaining: maxDuration(0, time.Until(deadline)), // time.Until for clarity + Deadline: deadline, + NextPollIn: nextPoll, + }) + lastProgress = time.Now() + } + } + + // Sleep until next attempt + sleep := opts.Interval.NextPoll(attempt) + if sleep > 0 { + timer := time.NewTimer(sleep) + select { + case <-ctx.Done(): + timer.Stop() + return last, ctx.Err() + case <-timer.C: + } + } + attempt++ + } +} + +// anyFetchResult converts a typed FetchResult[T] into FetchResult[any] for ProgressSink. +func anyFetchResult[T any](fr FetchResult[T]) FetchResult[any] { + return FetchResult[any]{Status: fr.Status, Value: any(fr.Value)} +} + +// contains tests membership in a slice of Status. +func contains(haystack []Status, needle Status) bool { + return slices.Contains(haystack, needle) +} + +func maxDuration(a, b time.Duration) time.Duration { + if a > b { + return a + } + return b +} diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go new file mode 100644 index 000000000000..3799737ddcca --- /dev/null +++ b/internal/actionwait/wait_test.go @@ -0,0 +1,291 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package actionwait + +import ( + "context" + "errors" + "sync/atomic" + "testing" + "time" +) + +// fastFixedInterval returns a very small fixed interval to speed tests. +const fastFixedInterval = 5 * time.Millisecond + +// makeCtx creates a context with generous overall test timeout safeguard. +func makeCtx(t *testing.T) context.Context { //nolint:revive // test helper + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + t.Cleanup(cancel) + return ctx +} + +func TestWaitForStatus_ValidationErrors(t *testing.T) { + // No t.Parallel here since we rely on wall clock; subtests parallelized individually. + cases := map[string]Options[struct{}]{ + "missing timeout": {SuccessStates: []Status{"ok"}}, + "missing success": {Timeout: time.Second}, + } + + for name, opts := range cases { + opts := opts + //nolint:paralleltest // simple and quick + t.Run(name, func(t *testing.T) { + ctx := makeCtx(t) + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "irrelevant"}, nil + }, opts) + if err == nil { + t.Fatalf("expected validation error") + } + }) + } +} + +func TestWaitForStatus_SuccessImmediate(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "DONE", Value: 42}, nil + }, Options[int]{ + Timeout: 250 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Value != 42 || fr.Status != "DONE" { + t.Fatalf("unexpected result: %#v", fr) + } +} + +func TestWaitForStatus_SuccessAfterTransitions(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var calls int32 + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[string], error) { + c := atomic.AddInt32(&calls, 1) + switch c { + case 1, 2: + return FetchResult[string]{Status: "IN_PROGRESS", Value: "step"}, nil + default: + return FetchResult[string]{Status: "COMPLETE", Value: "done"}, nil + } + }, Options[string]{ + Timeout: 500 * time.Millisecond, + SuccessStates: []Status{"COMPLETE"}, + TransitionalStates: []Status{"IN_PROGRESS"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Status != "COMPLETE" || fr.Value != "done" { + t.Fatalf("unexpected final result: %#v", fr) + } +} + +func TestWaitForStatus_FailureState(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "FAILED"}, nil + }, Options[struct{}]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"SUCCEEDED"}, + FailureStates: []Status{"FAILED"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err == nil { + t.Fatal("expected failure error") + } + if _, ok := err.(*ErrFailureState); !ok { //nolint:errorlint + t.Fatalf("expected ErrFailureState, got %T", err) + } + if fr.Status != "FAILED" { + t.Fatalf("unexpected status: %v", fr.Status) + } +} + +func TestWaitForStatus_UnexpectedState_WithTransitional(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "UNKNOWN"}, nil + }, Options[int]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"OK"}, + TransitionalStates: []Status{"PENDING"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err == nil { + t.Fatal("expected unexpected state error") + } + if _, ok := err.(*ErrUnexpectedState); !ok { //nolint:errorlint + t.Fatalf("expected ErrUnexpectedState, got %T", err) + } +} + +func TestWaitForStatus_NoTransitionalListAllowsAnyUntilTimeout(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + start := time.Now() + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "WHATEVER"}, nil + }, Options[struct{}]{ + Timeout: 50 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(10 * time.Millisecond), + }) + if err == nil { + t.Fatal("expected timeout error") + } + if _, ok := err.(*ErrTimeout); !ok { //nolint:errorlint + t.Fatalf("expected ErrTimeout, got %T", err) + } + if time.Since(start) < 40*time.Millisecond { // sanity that we actually waited + t.Fatalf("timeout returned too early") + } +} + +func TestWaitForStatus_ContextCancel(t *testing.T) { + t.Parallel() + ctx, cancel := context.WithCancel(makeCtx(t)) + go func() { + time.Sleep(20 * time.Millisecond) + cancel() + }() + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "PENDING"}, nil + }, Options[struct{}]{ + Timeout: 500 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(fastFixedInterval), + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("expected context.Canceled, got %v", err) + } +} + +func TestWaitForStatus_FetchErrorPropagation(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + testErr := errors.New("boom") + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{}, testErr + }, Options[int]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"OK"}, + Interval: FixedInterval(fastFixedInterval), + }) + if !errors.Is(err, testErr) { + t.Fatalf("expected fetch error, got %v", err) + } +} + +func TestWaitForStatus_ConsecutiveSuccess(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var toggle int32 + // alternate success / transitional until two consecutive successes happen + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[string], error) { + n := atomic.AddInt32(&toggle, 1) + // Pattern: BUILDING, READY, READY, READY ... ensures at least two consecutive successes by third attempt + if n == 1 { + return FetchResult[string]{Status: "BUILDING", Value: "val"}, nil + } + return FetchResult[string]{Status: "READY", Value: "val"}, nil + }, Options[string]{ + Timeout: 750 * time.Millisecond, + SuccessStates: []Status{"READY"}, + TransitionalStates: []Status{"BUILDING"}, + ConsecutiveSuccess: 2, + Interval: FixedInterval(2 * time.Millisecond), + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Status != "READY" { + t.Fatalf("expected READY, got %v", fr.Status) + } + if atomic.LoadInt32(&toggle) < 3 { // at least three fetches required (BUILDING, READY, READY) + t.Fatalf("expected multiple attempts, got %d", toggle) + } +} + +func TestWaitForStatus_ProgressSinkThrottling(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var progressCalls int32 + var fetchCalls int32 + _, _ = WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + atomic.AddInt32(&fetchCalls, 1) + if fetchCalls >= 5 { + return FetchResult[int]{Status: "DONE"}, nil + } + return FetchResult[int]{Status: "WORKING"}, nil + }, Options[int]{ + Timeout: 500 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + TransitionalStates: []Status{"WORKING"}, + Interval: FixedInterval(5 * time.Millisecond), + ProgressInterval: 15 * time.Millisecond, // should group roughly 3 polls + ProgressSink: func(fr FetchResult[any], meta ProgressMeta) { + atomic.AddInt32(&progressCalls, 1) + if fr.Status != "WORKING" && fr.Status != "DONE" { + t.Fatalf("unexpected status in progress sink: %v", fr.Status) + } + if meta.NextPollIn <= 0 { + t.Fatalf("expected positive NextPollIn") + } + }, + }) + // With 5 fetch calls and 15ms progress vs 5ms poll, expect fewer progress events than fetches + if progressCalls <= 1 || progressCalls >= fetchCalls { + t.Fatalf("unexpected progress call count: %d (fetches %d)", progressCalls, fetchCalls) + } +} + +func TestWaitForStatus_ConsecutiveSuccessDefault(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + fr, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { + return FetchResult[struct{}]{Status: "READY"}, nil + }, Options[struct{}]{ + Timeout: 100 * time.Millisecond, + SuccessStates: []Status{"READY"}, + Interval: FixedInterval(fastFixedInterval), + // ConsecutiveSuccess left zero to trigger defaulting logic + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if fr.Status != "READY" { + t.Fatalf("unexpected status: %v", fr.Status) + } +} + +func TestWaitForStatus_ProgressSinkDisabled(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + var progressCalls int32 + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "DONE"}, nil + }, Options[int]{ + Timeout: 100 * time.Millisecond, + SuccessStates: []Status{"DONE"}, + Interval: FixedInterval(fastFixedInterval), + ProgressInterval: 0, // disabled + ProgressSink: func(FetchResult[any], ProgressMeta) { + atomic.AddInt32(&progressCalls, 1) + }, + }) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if progressCalls != 0 { // should not be invoked when ProgressInterval <= 0 + t.Fatalf("expected zero progress sink calls, got %d", progressCalls) + } +} From cf036b9bb9336fb39fdcdfd210ded148b876076a Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:01:50 -0700 Subject: [PATCH 02/17] Update stop instance action to use actionwait --- .../service/ec2/ec2_stop_instance_action.go | 105 +++++++----------- 1 file changed, 41 insertions(+), 64 deletions(-) diff --git a/internal/service/ec2/ec2_stop_instance_action.go b/internal/service/ec2/ec2_stop_instance_action.go index a9e9a84146a8..a390c76fd2c8 100644 --- a/internal/service/ec2/ec2_stop_instance_action.go +++ b/internal/service/ec2/ec2_stop_instance_action.go @@ -6,7 +6,6 @@ package ec2 import ( "context" "fmt" - "slices" "time" "github.com/YakDriver/regexache" @@ -21,6 +20,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/schema/validator" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" + "github.com/hashicorp/terraform-provider-aws/internal/actionwait" "github.com/hashicorp/terraform-provider-aws/internal/framework" "github.com/hashicorp/terraform-provider-aws/names" ) @@ -180,13 +180,46 @@ func (a *stopInstanceAction) Invoke(ctx context.Context, req action.InvokeReques }) } - // Wait for instance to stop with periodic progress updates - err = a.waitForInstanceStopped(ctx, conn, instanceID, timeout, resp) + // Wait for instance to stop with periodic progress updates using actionwait + _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { + instance, derr := findInstanceByID(ctx, conn, instanceID) + if derr != nil { + return actionwait.FetchResult[struct{}]{}, fmt.Errorf("describing instance: %w", derr) + } + state := string(instance.State.Name) + return actionwait.FetchResult[struct{}]{Status: actionwait.Status(state)}, nil + }, actionwait.Options[struct{}]{ + Timeout: timeout, + Interval: actionwait.FixedInterval(10 * time.Second), + ProgressInterval: 30 * time.Second, + SuccessStates: []actionwait.Status{actionwait.Status(awstypes.InstanceStateNameStopped)}, + TransitionalStates: []actionwait.Status{ + actionwait.Status(awstypes.InstanceStateNameRunning), + actionwait.Status(awstypes.InstanceStateNameStopping), + actionwait.Status(awstypes.InstanceStateNameShuttingDown), + }, + ProgressSink: func(fr actionwait.FetchResult[any], meta actionwait.ProgressMeta) { + resp.SendProgress(action.InvokeProgressEvent{Message: fmt.Sprintf("EC2 instance %s is currently in state '%s', continuing to wait for 'stopped'...", instanceID, fr.Status)}) + }, + }) if err != nil { - resp.Diagnostics.AddError( - "Timeout Waiting for Instance to Stop", - fmt.Sprintf("EC2 instance %s did not stop within %s: %s", instanceID, timeout, err), - ) + switch err.(type) { + case *actionwait.ErrTimeout: + resp.Diagnostics.AddError( + "Timeout Waiting for Instance to Stop", + fmt.Sprintf("EC2 instance %s did not stop within %s: %s", instanceID, timeout, err), + ) + case *actionwait.ErrUnexpectedState: + resp.Diagnostics.AddError( + "Unexpected Instance State", + fmt.Sprintf("EC2 instance %s entered unexpected state while stopping: %s", instanceID, err), + ) + default: + resp.Diagnostics.AddError( + "Error Waiting for Instance to Stop", + fmt.Sprintf("Error while waiting for EC2 instance %s to stop: %s", instanceID, err), + ) + } return } @@ -210,60 +243,4 @@ func canStopInstance(state awstypes.InstanceStateName) bool { } } -// waitForInstanceStopped waits for an instance to reach the stopped state with progress updates -func (a *stopInstanceAction) waitForInstanceStopped(ctx context.Context, conn *ec2.Client, instanceID string, timeout time.Duration, resp *action.InvokeResponse) error { - const ( - pollInterval = 10 * time.Second - progressInterval = 30 * time.Second - ) - - deadline := time.Now().Add(timeout) - lastProgressUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - // Check if we've exceeded the timeout - if time.Now().After(deadline) { - return fmt.Errorf("timeout after %s", timeout) - } - - // Get current instance state - instance, err := findInstanceByID(ctx, conn, instanceID) - if err != nil { - return fmt.Errorf("describing instance: %w", err) - } - - currentState := string(instance.State.Name) - - // Send progress update every 30 seconds - if time.Since(lastProgressUpdate) >= progressInterval { - resp.SendProgress(action.InvokeProgressEvent{ - Message: fmt.Sprintf("EC2 instance %s is currently in state '%s', continuing to wait for 'stopped'...", instanceID, currentState), - }) - lastProgressUpdate = time.Now() - } - - // Check if we've reached the target state - if instance.State.Name == awstypes.InstanceStateNameStopped { - return nil - } - - // Check if we're in an unexpected state - validStates := []awstypes.InstanceStateName{ - awstypes.InstanceStateNameRunning, - awstypes.InstanceStateNameStopping, - awstypes.InstanceStateNameShuttingDown, - } - if !slices.Contains(validStates, instance.State.Name) { - return fmt.Errorf("instance entered unexpected state: %s", currentState) - } - - // Wait before next poll - time.Sleep(pollInterval) - } -} +// Legacy polling helper removed; replaced with actionwait. From 1a770687f9822173118d5801e820c8df128f1301 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:02:18 -0700 Subject: [PATCH 03/17] Update start build action to use actionwait --- .../service/codebuild/start_build_action.go | 97 ++++++++----------- 1 file changed, 41 insertions(+), 56 deletions(-) diff --git a/internal/service/codebuild/start_build_action.go b/internal/service/codebuild/start_build_action.go index 9d823cb64ef4..ca3f0c9e45da 100644 --- a/internal/service/codebuild/start_build_action.go +++ b/internal/service/codebuild/start_build_action.go @@ -5,6 +5,7 @@ package codebuild import ( "context" + "fmt" "time" "github.com/aws/aws-sdk-go-v2/aws" @@ -14,6 +15,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/action/schema" "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" + "github.com/hashicorp/terraform-provider-aws/internal/actionwait" "github.com/hashicorp/terraform-provider-aws/internal/framework" fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex" fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" @@ -131,65 +133,48 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, Message: "Build started, waiting for completion...", }) - // Poll for build completion - deadline := time.Now().Add(timeout) - pollInterval := 30 * time.Second - progressInterval := 2 * time.Minute - lastProgressUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - resp.Diagnostics.AddError("Build monitoring cancelled", "Context was cancelled") - return - default: - } - - if time.Now().After(deadline) { - resp.Diagnostics.AddError("Build timeout", "Build did not complete within the specified timeout") - return - } - - input := codebuild.BatchGetBuildsInput{ - Ids: []string{buildID}, + // Poll for build completion using actionwait + _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[*awstypes.Build], error) { + batch, berr := conn.BatchGetBuilds(ctx, &codebuild.BatchGetBuildsInput{Ids: []string{buildID}}) + if berr != nil { + return actionwait.FetchResult[*awstypes.Build]{}, berr } - batchGetBuildsOutput, err := conn.BatchGetBuilds(ctx, &input) - if err != nil { - resp.Diagnostics.AddError("Getting build status", err.Error()) - return - } - - if len(batchGetBuildsOutput.Builds) == 0 { - resp.Diagnostics.AddError("Build not found", "Build was not found in BatchGetBuilds response") - return - } - - build := batchGetBuildsOutput.Builds[0] - status := build.BuildStatus - - if time.Since(lastProgressUpdate) >= progressInterval { - resp.SendProgress(action.InvokeProgressEvent{ - Message: "Build currently in state: " + string(status), - }) - lastProgressUpdate = time.Now() + if len(batch.Builds) == 0 { + return actionwait.FetchResult[*awstypes.Build]{}, fmt.Errorf("build not found in BatchGetBuilds response") } - - switch status { - case awstypes.StatusTypeSucceeded: - resp.SendProgress(action.InvokeProgressEvent{ - Message: "Build completed successfully", - }) - return - case awstypes.StatusTypeFailed, awstypes.StatusTypeFault, awstypes.StatusTypeStopped, awstypes.StatusTypeTimedOut: - resp.Diagnostics.AddError("Build failed", "Build completed with status: "+string(status)) - return - case awstypes.StatusTypeInProgress: - // Continue polling + b := batch.Builds[0] + return actionwait.FetchResult[*awstypes.Build]{Status: actionwait.Status(b.BuildStatus), Value: &b}, nil + }, actionwait.Options[*awstypes.Build]{ + Timeout: timeout, + Interval: actionwait.FixedInterval(30 * time.Second), + ProgressInterval: 2 * time.Minute, + SuccessStates: []actionwait.Status{actionwait.Status(awstypes.StatusTypeSucceeded)}, + TransitionalStates: []actionwait.Status{ + actionwait.Status(awstypes.StatusTypeInProgress), + }, + FailureStates: []actionwait.Status{ + actionwait.Status(awstypes.StatusTypeFailed), + actionwait.Status(awstypes.StatusTypeFault), + actionwait.Status(awstypes.StatusTypeStopped), + actionwait.Status(awstypes.StatusTypeTimedOut), + }, + ProgressSink: func(fr actionwait.FetchResult[any], meta actionwait.ProgressMeta) { + resp.SendProgress(action.InvokeProgressEvent{Message: "Build currently in state: " + string(fr.Status)}) + }, + }) + if err != nil { + switch err.(type) { + case *actionwait.ErrTimeout: + resp.Diagnostics.AddError("Build timeout", "Build did not complete within the specified timeout") + case *actionwait.ErrFailureState: + resp.Diagnostics.AddError("Build failed", "Build completed with status: "+err.Error()) + case *actionwait.ErrUnexpectedState: + resp.Diagnostics.AddError("Unexpected build status", err.Error()) default: - resp.Diagnostics.AddError("Unexpected build status", "Received unexpected build status: "+string(status)) - return + resp.Diagnostics.AddError("Error waiting for build", err.Error()) } - - time.Sleep(pollInterval) + return } + + resp.SendProgress(action.InvokeProgressEvent{Message: "Build completed successfully"}) } From 94cc26367f0e744a918ab791e0ea0f088250f9a1 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:02:34 -0700 Subject: [PATCH 04/17] Update create invalidation action to use actionwait --- .../cloudfront/create_invalidation_action.go | 109 +++++++----------- 1 file changed, 42 insertions(+), 67 deletions(-) diff --git a/internal/service/cloudfront/create_invalidation_action.go b/internal/service/cloudfront/create_invalidation_action.go index 0cd271486ef5..48e553d02cf1 100644 --- a/internal/service/cloudfront/create_invalidation_action.go +++ b/internal/service/cloudfront/create_invalidation_action.go @@ -6,7 +6,6 @@ package cloudfront import ( "context" "fmt" - "slices" "time" "github.com/YakDriver/regexache" @@ -23,6 +22,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/hashicorp/terraform-plugin-sdk/v2/helper/id" + "github.com/hashicorp/terraform-provider-aws/internal/actionwait" "github.com/hashicorp/terraform-provider-aws/internal/framework" fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" "github.com/hashicorp/terraform-provider-aws/names" @@ -216,13 +216,47 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke Message: fmt.Sprintf("Invalidation %s created, waiting for completion...", invalidationID), }) - // Wait for invalidation to complete with periodic progress updates - err = a.waitForInvalidationComplete(ctx, conn, distributionID, invalidationID, timeout, resp) + // Wait for invalidation to complete with periodic progress updates using actionwait + _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { + output, gerr := conn.GetInvalidation(ctx, &cloudfront.GetInvalidationInput{ + DistributionId: aws.String(distributionID), + Id: aws.String(invalidationID), + }) + if gerr != nil { + return actionwait.FetchResult[struct{}]{}, fmt.Errorf("getting invalidation status: %w", gerr) + } + status := aws.ToString(output.Invalidation.Status) + return actionwait.FetchResult[struct{}]{Status: actionwait.Status(status)}, nil + }, actionwait.Options[struct{}]{ + Timeout: timeout, + Interval: actionwait.FixedInterval(30 * time.Second), + ProgressInterval: 60 * time.Second, + SuccessStates: []actionwait.Status{"Completed"}, + TransitionalStates: []actionwait.Status{ + "InProgress", + }, + ProgressSink: func(fr actionwait.FetchResult[any], meta actionwait.ProgressMeta) { + resp.SendProgress(action.InvokeProgressEvent{Message: fmt.Sprintf("Invalidation %s is currently '%s', continuing to wait for completion...", invalidationID, fr.Status)}) + }, + }) if err != nil { - resp.Diagnostics.AddError( - "Timeout Waiting for Invalidation to Complete", - fmt.Sprintf("CloudFront invalidation %s did not complete within %s: %s", invalidationID, timeout, err), - ) + switch err.(type) { + case *actionwait.ErrTimeout: + resp.Diagnostics.AddError( + "Timeout Waiting for Invalidation to Complete", + fmt.Sprintf("CloudFront invalidation %s did not complete within %s: %s", invalidationID, timeout, err), + ) + case *actionwait.ErrUnexpectedState: + resp.Diagnostics.AddError( + "Invalid Invalidation State", + fmt.Sprintf("CloudFront invalidation %s entered unexpected state: %s", invalidationID, err), + ) + default: + resp.Diagnostics.AddError( + "Failed While Waiting for Invalidation", + fmt.Sprintf("Error waiting for CloudFront invalidation %s: %s", invalidationID, err), + ) + } return } @@ -238,63 +272,4 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke }) } -// waitForInvalidationComplete waits for an invalidation to complete with progress updates -func (a *createInvalidationAction) waitForInvalidationComplete(ctx context.Context, conn *cloudfront.Client, distributionID, invalidationID string, timeout time.Duration, resp *action.InvokeResponse) error { - const ( - pollInterval = 30 * time.Second - progressInterval = 60 * time.Second - ) - - deadline := time.Now().Add(timeout) - lastProgressUpdate := time.Now() - - for { - select { - case <-ctx.Done(): - return ctx.Err() - default: - } - - // Check if we've exceeded the timeout - if time.Now().After(deadline) { - return fmt.Errorf("timeout after %s", timeout) - } - - // Get current invalidation status - input := &cloudfront.GetInvalidationInput{ - DistributionId: aws.String(distributionID), - Id: aws.String(invalidationID), - } - - output, err := conn.GetInvalidation(ctx, input) - if err != nil { - return fmt.Errorf("getting invalidation status: %w", err) - } - - currentStatus := aws.ToString(output.Invalidation.Status) - - // Send progress update every 60 seconds - if time.Since(lastProgressUpdate) >= progressInterval { - resp.SendProgress(action.InvokeProgressEvent{ - Message: fmt.Sprintf("Invalidation %s is currently '%s', continuing to wait for completion...", invalidationID, currentStatus), - }) - lastProgressUpdate = time.Now() - } - - // Check if we've reached completion - if aws.ToString(output.Invalidation.Status) == "Completed" { - return nil - } - - // Check if we're in an unexpected state - validStatuses := []string{ - "InProgress", - } - if !slices.Contains(validStatuses, currentStatus) && currentStatus != "Completed" { - return fmt.Errorf("invalidation entered unexpected status: %s", currentStatus) - } - - // Wait before next poll - time.Sleep(pollInterval) - } -} +// Legacy helper removed; polling now centralized in actionwait. From 15e8770b1cac8da132d38c7c677e46e3433778d3 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:10:35 -0700 Subject: [PATCH 05/17] Remove comments --- internal/service/cloudfront/create_invalidation_action.go | 2 -- internal/service/ec2/ec2_stop_instance_action.go | 2 -- 2 files changed, 4 deletions(-) diff --git a/internal/service/cloudfront/create_invalidation_action.go b/internal/service/cloudfront/create_invalidation_action.go index 48e553d02cf1..b5cf4de81973 100644 --- a/internal/service/cloudfront/create_invalidation_action.go +++ b/internal/service/cloudfront/create_invalidation_action.go @@ -271,5 +271,3 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke "paths": paths, }) } - -// Legacy helper removed; polling now centralized in actionwait. diff --git a/internal/service/ec2/ec2_stop_instance_action.go b/internal/service/ec2/ec2_stop_instance_action.go index a390c76fd2c8..b03e47262014 100644 --- a/internal/service/ec2/ec2_stop_instance_action.go +++ b/internal/service/ec2/ec2_stop_instance_action.go @@ -242,5 +242,3 @@ func canStopInstance(state awstypes.InstanceStateName) bool { return false } } - -// Legacy polling helper removed; replaced with actionwait. From 684dbb3ddd48485ed0084c3ce3379def3d32ade0 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:37:31 -0700 Subject: [PATCH 06/17] Linting --- internal/actionwait/wait.go | 39 ++++++++++--------- internal/actionwait/wait_test.go | 21 +++++----- .../cloudfront/create_invalidation_action.go | 12 +++--- .../service/codebuild/start_build_action.go | 15 ++++--- .../service/ec2/ec2_stop_instance_action.go | 15 ++++--- 5 files changed, 58 insertions(+), 44 deletions(-) diff --git a/internal/actionwait/wait.go b/internal/actionwait/wait.go index 23cbaeeb8cb9..3030e0cb94d3 100644 --- a/internal/actionwait/wait.go +++ b/internal/actionwait/wait.go @@ -13,6 +13,9 @@ import ( "time" ) +// DefaultPollInterval is the default fixed polling interval used when no custom IntervalStrategy is provided. +const DefaultPollInterval = 30 * time.Second + // Status represents a string status value returned from a polled API. type Status string @@ -28,7 +31,7 @@ type FetchResult[T any] struct { type FetchFunc[T any] func(context.Context) (FetchResult[T], error) // IntervalStrategy allows pluggable poll interval behavior (fixed, backoff, etc.). -type IntervalStrategy interface { // nolint:interfacebloat // single method interface +type IntervalStrategy interface { //nolint:interfacebloat // single method interface (tiny intentional interface) NextPoll(attempt uint) time.Duration } @@ -59,46 +62,46 @@ type ProgressMeta struct { NextPollIn time.Duration } -// ErrTimeout is returned when the operation does not reach a success state within Timeout. -type ErrTimeout struct { +// TimeoutError is returned when the operation does not reach a success state within Timeout. +type TimeoutError struct { LastStatus Status Timeout time.Duration } -func (e *ErrTimeout) Error() string { +func (e *TimeoutError) Error() string { return "timeout waiting for target status after " + e.Timeout.String() } -// ErrFailureState indicates the operation entered a declared failure state. -type ErrFailureState struct { +// FailureStateError indicates the operation entered a declared failure state. +type FailureStateError struct { Status Status } -func (e *ErrFailureState) Error() string { +func (e *FailureStateError) Error() string { return "operation entered failure state: " + string(e.Status) } -// ErrUnexpectedState indicates the operation entered a state outside success/transitional/failure sets. -type ErrUnexpectedState struct { +// UnexpectedStateError indicates the operation entered a state outside success/transitional/failure sets. +type UnexpectedStateError struct { Status Status Allowed []Status } -func (e *ErrUnexpectedState) Error() string { +func (e *UnexpectedStateError) Error() string { return "operation entered unexpected state: " + string(e.Status) } // sentinel errors helpers var ( - _ error = (*ErrTimeout)(nil) - _ error = (*ErrFailureState)(nil) - _ error = (*ErrUnexpectedState)(nil) + _ error = (*TimeoutError)(nil) + _ error = (*FailureStateError)(nil) + _ error = (*UnexpectedStateError)(nil) ) // WaitForStatus polls using fetch until a success state, failure state, timeout, unexpected state, // context cancellation, or fetch error occurs. // On success, the final FetchResult is returned with nil error. -func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[T]) (FetchResult[T], error) { // nolint:cyclop +func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[T]) (FetchResult[T], error) { //nolint:cyclop // complexity driven by classification/state machine; readability preferred var zero FetchResult[T] if opts.Timeout <= 0 { @@ -111,7 +114,7 @@ func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[ opts.ConsecutiveSuccess = 1 } if opts.Interval == nil { - opts.Interval = FixedInterval(30 * time.Second) + opts.Interval = FixedInterval(DefaultPollInterval) } start := time.Now() @@ -132,7 +135,7 @@ func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[ } now := time.Now() if now.After(deadline) { - return last, &ErrTimeout{LastStatus: last.Status, Timeout: opts.Timeout} + return last, &TimeoutError{LastStatus: last.Status, Timeout: opts.Timeout} } fr, err := fetch(ctx) @@ -143,7 +146,7 @@ func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[ // Classification precedence: failure -> success -> transitional -> unexpected if contains(opts.FailureStates, fr.Status) { - return fr, &ErrFailureState{Status: fr.Status} + return fr, &FailureStateError{Status: fr.Status} } if contains(opts.SuccessStates, fr.Status) { successStreak++ @@ -154,7 +157,7 @@ func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[ successStreak = 0 if len(opts.TransitionalStates) > 0 { if !contains(opts.TransitionalStates, fr.Status) { - return fr, &ErrUnexpectedState{Status: fr.Status, Allowed: allowedTransient} + return fr, &UnexpectedStateError{Status: fr.Status, Allowed: allowedTransient} } } } diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go index 3799737ddcca..7f5b1a172309 100644 --- a/internal/actionwait/wait_test.go +++ b/internal/actionwait/wait_test.go @@ -15,23 +15,24 @@ import ( const fastFixedInterval = 5 * time.Millisecond // makeCtx creates a context with generous overall test timeout safeguard. -func makeCtx(t *testing.T) context.Context { //nolint:revive // test helper +func makeCtx(t *testing.T) context.Context { // test helper ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) t.Cleanup(cancel) return ctx } func TestWaitForStatus_ValidationErrors(t *testing.T) { - // No t.Parallel here since we rely on wall clock; subtests parallelized individually. + t.Parallel() + // Subtests parallelized; each uses its own context with timeout. cases := map[string]Options[struct{}]{ "missing timeout": {SuccessStates: []Status{"ok"}}, "missing success": {Timeout: time.Second}, } - for name, opts := range cases { + for name, opts := range cases { // Go 1.22+ copyloopvar: explicit copy not needed opts := opts - //nolint:paralleltest // simple and quick t.Run(name, func(t *testing.T) { + t.Parallel() ctx := makeCtx(t) _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[struct{}], error) { return FetchResult[struct{}]{Status: "irrelevant"}, nil @@ -101,8 +102,8 @@ func TestWaitForStatus_FailureState(t *testing.T) { if err == nil { t.Fatal("expected failure error") } - if _, ok := err.(*ErrFailureState); !ok { //nolint:errorlint - t.Fatalf("expected ErrFailureState, got %T", err) + if _, ok := err.(*FailureStateError); !ok { //nolint:errorlint // direct type assertion adequate in tests + t.Fatalf("expected FailureStateError, got %T", err) } if fr.Status != "FAILED" { t.Fatalf("unexpected status: %v", fr.Status) @@ -123,8 +124,8 @@ func TestWaitForStatus_UnexpectedState_WithTransitional(t *testing.T) { if err == nil { t.Fatal("expected unexpected state error") } - if _, ok := err.(*ErrUnexpectedState); !ok { //nolint:errorlint - t.Fatalf("expected ErrUnexpectedState, got %T", err) + if _, ok := err.(*UnexpectedStateError); !ok { //nolint:errorlint // direct type assertion adequate in tests + t.Fatalf("expected UnexpectedStateError, got %T", err) } } @@ -142,8 +143,8 @@ func TestWaitForStatus_NoTransitionalListAllowsAnyUntilTimeout(t *testing.T) { if err == nil { t.Fatal("expected timeout error") } - if _, ok := err.(*ErrTimeout); !ok { //nolint:errorlint - t.Fatalf("expected ErrTimeout, got %T", err) + if _, ok := err.(*TimeoutError); !ok { //nolint:errorlint // direct type assertion adequate in tests + t.Fatalf("expected TimeoutError, got %T", err) } if time.Since(start) < 40*time.Millisecond { // sanity that we actually waited t.Fatalf("timeout returned too early") diff --git a/internal/service/cloudfront/create_invalidation_action.go b/internal/service/cloudfront/create_invalidation_action.go index b5cf4de81973..469d01ea912f 100644 --- a/internal/service/cloudfront/create_invalidation_action.go +++ b/internal/service/cloudfront/create_invalidation_action.go @@ -5,6 +5,7 @@ package cloudfront import ( "context" + "errors" "fmt" "time" @@ -229,7 +230,7 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke return actionwait.FetchResult[struct{}]{Status: actionwait.Status(status)}, nil }, actionwait.Options[struct{}]{ Timeout: timeout, - Interval: actionwait.FixedInterval(30 * time.Second), + Interval: actionwait.FixedInterval(actionwait.DefaultPollInterval), ProgressInterval: 60 * time.Second, SuccessStates: []actionwait.Status{"Completed"}, TransitionalStates: []actionwait.Status{ @@ -240,18 +241,19 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke }, }) if err != nil { - switch err.(type) { - case *actionwait.ErrTimeout: + var timeoutErr *actionwait.TimeoutError + var unexpectedErr *actionwait.UnexpectedStateError + if errors.As(err, &timeoutErr) { resp.Diagnostics.AddError( "Timeout Waiting for Invalidation to Complete", fmt.Sprintf("CloudFront invalidation %s did not complete within %s: %s", invalidationID, timeout, err), ) - case *actionwait.ErrUnexpectedState: + } else if errors.As(err, &unexpectedErr) { resp.Diagnostics.AddError( "Invalid Invalidation State", fmt.Sprintf("CloudFront invalidation %s entered unexpected state: %s", invalidationID, err), ) - default: + } else { resp.Diagnostics.AddError( "Failed While Waiting for Invalidation", fmt.Sprintf("Error waiting for CloudFront invalidation %s: %s", invalidationID, err), diff --git a/internal/service/codebuild/start_build_action.go b/internal/service/codebuild/start_build_action.go index ca3f0c9e45da..ac7252ab352c 100644 --- a/internal/service/codebuild/start_build_action.go +++ b/internal/service/codebuild/start_build_action.go @@ -5,6 +5,7 @@ package codebuild import ( "context" + "errors" "fmt" "time" @@ -146,7 +147,7 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, return actionwait.FetchResult[*awstypes.Build]{Status: actionwait.Status(b.BuildStatus), Value: &b}, nil }, actionwait.Options[*awstypes.Build]{ Timeout: timeout, - Interval: actionwait.FixedInterval(30 * time.Second), + Interval: actionwait.FixedInterval(actionwait.DefaultPollInterval), ProgressInterval: 2 * time.Minute, SuccessStates: []actionwait.Status{actionwait.Status(awstypes.StatusTypeSucceeded)}, TransitionalStates: []actionwait.Status{ @@ -163,14 +164,16 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, }, }) if err != nil { - switch err.(type) { - case *actionwait.ErrTimeout: + var timeoutErr *actionwait.TimeoutError + var failureErr *actionwait.FailureStateError + var unexpectedErr *actionwait.UnexpectedStateError + if errors.As(err, &timeoutErr) { resp.Diagnostics.AddError("Build timeout", "Build did not complete within the specified timeout") - case *actionwait.ErrFailureState: + } else if errors.As(err, &failureErr) { resp.Diagnostics.AddError("Build failed", "Build completed with status: "+err.Error()) - case *actionwait.ErrUnexpectedState: + } else if errors.As(err, &unexpectedErr) { resp.Diagnostics.AddError("Unexpected build status", err.Error()) - default: + } else { resp.Diagnostics.AddError("Error waiting for build", err.Error()) } return diff --git a/internal/service/ec2/ec2_stop_instance_action.go b/internal/service/ec2/ec2_stop_instance_action.go index b03e47262014..95eb642eb8ed 100644 --- a/internal/service/ec2/ec2_stop_instance_action.go +++ b/internal/service/ec2/ec2_stop_instance_action.go @@ -5,6 +5,7 @@ package ec2 import ( "context" + "errors" "fmt" "time" @@ -25,6 +26,9 @@ import ( "github.com/hashicorp/terraform-provider-aws/names" ) +// ec2StopInstancePollInterval defines polling cadence for EC2 stop instance action. +const ec2StopInstancePollInterval = 10 * time.Second + // @Action(aws_ec2_stop_instance, name="Stop Instance") func newStopInstanceAction(_ context.Context) (action.ActionWithConfigure, error) { return &stopInstanceAction{}, nil @@ -190,7 +194,7 @@ func (a *stopInstanceAction) Invoke(ctx context.Context, req action.InvokeReques return actionwait.FetchResult[struct{}]{Status: actionwait.Status(state)}, nil }, actionwait.Options[struct{}]{ Timeout: timeout, - Interval: actionwait.FixedInterval(10 * time.Second), + Interval: actionwait.FixedInterval(ec2StopInstancePollInterval), ProgressInterval: 30 * time.Second, SuccessStates: []actionwait.Status{actionwait.Status(awstypes.InstanceStateNameStopped)}, TransitionalStates: []actionwait.Status{ @@ -203,18 +207,19 @@ func (a *stopInstanceAction) Invoke(ctx context.Context, req action.InvokeReques }, }) if err != nil { - switch err.(type) { - case *actionwait.ErrTimeout: + var timeoutErr *actionwait.TimeoutError + var unexpectedErr *actionwait.UnexpectedStateError + if errors.As(err, &timeoutErr) { resp.Diagnostics.AddError( "Timeout Waiting for Instance to Stop", fmt.Sprintf("EC2 instance %s did not stop within %s: %s", instanceID, timeout, err), ) - case *actionwait.ErrUnexpectedState: + } else if errors.As(err, &unexpectedErr) { resp.Diagnostics.AddError( "Unexpected Instance State", fmt.Sprintf("EC2 instance %s entered unexpected state while stopping: %s", instanceID, err), ) - default: + } else { resp.Diagnostics.AddError( "Error Waiting for Instance to Stop", fmt.Sprintf("Error while waiting for EC2 instance %s to stop: %s", instanceID, err), From 15968943c609405734bc74b211b8b28d5f3a3de6 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:40:59 -0700 Subject: [PATCH 07/17] Modern --- internal/actionwait/wait_test.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go index 7f5b1a172309..9f29bcac798b 100644 --- a/internal/actionwait/wait_test.go +++ b/internal/actionwait/wait_test.go @@ -29,8 +29,7 @@ func TestWaitForStatus_ValidationErrors(t *testing.T) { "missing success": {Timeout: time.Second}, } - for name, opts := range cases { // Go 1.22+ copyloopvar: explicit copy not needed - opts := opts + for name, opts := range cases { t.Run(name, func(t *testing.T) { t.Parallel() ctx := makeCtx(t) From 235adbab2929964d4beab3688eab5ec4e4cf566f Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 21:44:20 -0700 Subject: [PATCH 08/17] Service name in var --- internal/service/ec2/ec2_stop_instance_action.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/internal/service/ec2/ec2_stop_instance_action.go b/internal/service/ec2/ec2_stop_instance_action.go index 95eb642eb8ed..d4b348c90a96 100644 --- a/internal/service/ec2/ec2_stop_instance_action.go +++ b/internal/service/ec2/ec2_stop_instance_action.go @@ -26,8 +26,8 @@ import ( "github.com/hashicorp/terraform-provider-aws/names" ) -// ec2StopInstancePollInterval defines polling cadence for EC2 stop instance action. -const ec2StopInstancePollInterval = 10 * time.Second +// stopInstancePollInterval defines polling cadence for stop instance action. +const stopInstancePollInterval = 10 * time.Second // @Action(aws_ec2_stop_instance, name="Stop Instance") func newStopInstanceAction(_ context.Context) (action.ActionWithConfigure, error) { @@ -194,7 +194,7 @@ func (a *stopInstanceAction) Invoke(ctx context.Context, req action.InvokeReques return actionwait.FetchResult[struct{}]{Status: actionwait.Status(state)}, nil }, actionwait.Options[struct{}]{ Timeout: timeout, - Interval: actionwait.FixedInterval(ec2StopInstancePollInterval), + Interval: actionwait.FixedInterval(stopInstancePollInterval), ProgressInterval: 30 * time.Second, SuccessStates: []actionwait.Status{actionwait.Status(awstypes.InstanceStateNameStopped)}, TransitionalStates: []actionwait.Status{ From b181b5b8236ffee24e1add9d36ae7259fe4bfd2c Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 22:04:46 -0700 Subject: [PATCH 09/17] Stack/heap --- internal/service/cloudfront/create_invalidation_action.go | 5 +++-- internal/service/codebuild/start_build_action.go | 3 ++- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/internal/service/cloudfront/create_invalidation_action.go b/internal/service/cloudfront/create_invalidation_action.go index 469d01ea912f..61e81a20265f 100644 --- a/internal/service/cloudfront/create_invalidation_action.go +++ b/internal/service/cloudfront/create_invalidation_action.go @@ -219,10 +219,11 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke // Wait for invalidation to complete with periodic progress updates using actionwait _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { - output, gerr := conn.GetInvalidation(ctx, &cloudfront.GetInvalidationInput{ + input := cloudfront.GetInvalidationInput{ DistributionId: aws.String(distributionID), Id: aws.String(invalidationID), - }) + } + output, gerr := conn.GetInvalidation(ctx, &input) if gerr != nil { return actionwait.FetchResult[struct{}]{}, fmt.Errorf("getting invalidation status: %w", gerr) } diff --git a/internal/service/codebuild/start_build_action.go b/internal/service/codebuild/start_build_action.go index ac7252ab352c..57e85e714ab5 100644 --- a/internal/service/codebuild/start_build_action.go +++ b/internal/service/codebuild/start_build_action.go @@ -136,7 +136,8 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, // Poll for build completion using actionwait _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[*awstypes.Build], error) { - batch, berr := conn.BatchGetBuilds(ctx, &codebuild.BatchGetBuildsInput{Ids: []string{buildID}}) + input := codebuild.BatchGetBuildsInput{Ids: []string{buildID}} + batch, berr := conn.BatchGetBuilds(ctx, &input) if berr != nil { return actionwait.FetchResult[*awstypes.Build]{}, berr } From 76b288a2ce4cd90e50a594472a6ba4c23750d78b Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 25 Sep 2025 22:19:22 -0700 Subject: [PATCH 10/17] make gen --- .teamcity/scripts/provider_tests/acceptance_tests.sh | 1 + .teamcity/scripts/provider_tests/unit_tests.sh | 1 + 2 files changed, 2 insertions(+) diff --git a/.teamcity/scripts/provider_tests/acceptance_tests.sh b/.teamcity/scripts/provider_tests/acceptance_tests.sh index f42d935c248d..539ee5a129f8 100644 --- a/.teamcity/scripts/provider_tests/acceptance_tests.sh +++ b/.teamcity/scripts/provider_tests/acceptance_tests.sh @@ -34,6 +34,7 @@ fi TF_ACC=1 go test \ ./internal/acctest/... \ + ./internal/actionwait/... \ ./internal/attrmap/... \ ./internal/backoff/... \ ./internal/conns/... \ diff --git a/.teamcity/scripts/provider_tests/unit_tests.sh b/.teamcity/scripts/provider_tests/unit_tests.sh index 2f0072dc45e0..c3bf8454caf7 100644 --- a/.teamcity/scripts/provider_tests/unit_tests.sh +++ b/.teamcity/scripts/provider_tests/unit_tests.sh @@ -6,6 +6,7 @@ set -euo pipefail go test \ ./internal/acctest/... \ + ./internal/actionwait/... \ ./internal/attrmap/... \ ./internal/backoff/... \ ./internal/conns/... \ From 939bdf3348c560d661874d3de3f402a0f78b2f13 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Tue, 30 Sep 2025 16:56:05 -0400 Subject: [PATCH 11/17] Improve early returns --- internal/actionwait/wait.go | 185 ++++++++++++++++++++++++------------ 1 file changed, 123 insertions(+), 62 deletions(-) diff --git a/internal/actionwait/wait.go b/internal/actionwait/wait.go index 3030e0cb94d3..0711b638908c 100644 --- a/internal/actionwait/wait.go +++ b/internal/actionwait/wait.go @@ -10,6 +10,7 @@ import ( "context" "errors" "slices" + "strings" "time" ) @@ -88,7 +89,15 @@ type UnexpectedStateError struct { } func (e *UnexpectedStateError) Error() string { - return "operation entered unexpected state: " + string(e.Status) + if len(e.Allowed) == 0 { + return "operation entered unexpected state: " + string(e.Status) + } + allowedStr := make([]string, len(e.Allowed)) + for i, s := range e.Allowed { + allowedStr[i] = string(s) + } + return "operation entered unexpected state: " + string(e.Status) + " (allowed: " + + strings.Join(allowedStr, ", ") + ")" } // sentinel errors helpers @@ -102,21 +111,13 @@ var ( // context cancellation, or fetch error occurs. // On success, the final FetchResult is returned with nil error. func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[T]) (FetchResult[T], error) { //nolint:cyclop // complexity driven by classification/state machine; readability preferred - var zero FetchResult[T] - - if opts.Timeout <= 0 { - return zero, errors.New("actionwait: Timeout must be > 0") - } - if len(opts.SuccessStates) == 0 { - return zero, errors.New("actionwait: at least one SuccessState required") - } - if opts.ConsecutiveSuccess <= 0 { - opts.ConsecutiveSuccess = 1 - } - if opts.Interval == nil { - opts.Interval = FixedInterval(DefaultPollInterval) + if err := validateOptions(opts); err != nil { + var zero FetchResult[T] + return zero, err } + normalizeOptions(&opts) + start := time.Now() deadline := start.Add(opts.Timeout) var lastProgress time.Time @@ -130,64 +131,37 @@ func WaitForStatus[T any](ctx context.Context, fetch FetchFunc[T], opts Options[ allowedTransient = append(allowedTransient, opts.TransitionalStates...) for { + // Early return: context cancelled if ctx.Err() != nil { return last, ctx.Err() } - now := time.Now() - if now.After(deadline) { + + // Early return: timeout exceeded + if time.Now().After(deadline) { return last, &TimeoutError{LastStatus: last.Status, Timeout: opts.Timeout} } + // Fetch current status fr, err := fetch(ctx) if err != nil { - return fr, err + return fr, err // Early return: fetch error } last = fr - // Classification precedence: failure -> success -> transitional -> unexpected - if contains(opts.FailureStates, fr.Status) { - return fr, &FailureStateError{Status: fr.Status} - } - if contains(opts.SuccessStates, fr.Status) { - successStreak++ - if successStreak >= opts.ConsecutiveSuccess { - return fr, nil - } - } else { - successStreak = 0 - if len(opts.TransitionalStates) > 0 { - if !contains(opts.TransitionalStates, fr.Status) { - return fr, &UnexpectedStateError{Status: fr.Status, Allowed: allowedTransient} - } - } + // Classify status and determine if we should terminate + isTerminal, classifyErr := classifyStatus(fr, opts, &successStreak, allowedTransient) + if isTerminal { + return fr, classifyErr // Early return: terminal state (success or failure) } - // Progress callback throttling - if opts.ProgressSink != nil && opts.ProgressInterval > 0 { - if lastProgress.IsZero() || time.Since(lastProgress) >= opts.ProgressInterval { - nextPoll := opts.Interval.NextPoll(attempt) - opts.ProgressSink(anyFetchResult(fr), ProgressMeta{ - Attempt: attempt, - Elapsed: time.Since(start), - Remaining: maxDuration(0, time.Until(deadline)), // time.Until for clarity - Deadline: deadline, - NextPollIn: nextPoll, - }) - lastProgress = time.Now() - } - } + // Handle progress reporting + handleProgressReport(opts, fr, start, deadline, attempt, &lastProgress) - // Sleep until next attempt - sleep := opts.Interval.NextPoll(attempt) - if sleep > 0 { - timer := time.NewTimer(sleep) - select { - case <-ctx.Done(): - timer.Stop() - return last, ctx.Err() - case <-timer.C: - } + // Sleep until next attempt, with context cancellation check + if err := sleepWithContext(ctx, opts.Interval.NextPoll(attempt)); err != nil { + return last, err // Early return: context cancelled during sleep } + attempt++ } } @@ -197,14 +171,101 @@ func anyFetchResult[T any](fr FetchResult[T]) FetchResult[any] { return FetchResult[any]{Status: fr.Status, Value: any(fr.Value)} } -// contains tests membership in a slice of Status. -func contains(haystack []Status, needle Status) bool { - return slices.Contains(haystack, needle) -} - func maxDuration(a, b time.Duration) time.Duration { if a > b { return a } return b } + +// validateOptions performs early validation of required options. +func validateOptions[T any](opts Options[T]) error { + if opts.Timeout <= 0 { + return errors.New("actionwait: Timeout must be > 0") + } + if len(opts.SuccessStates) == 0 { + return errors.New("actionwait: at least one SuccessState required") + } + if opts.ConsecutiveSuccess < 0 { + return errors.New("actionwait: ConsecutiveSuccess cannot be negative") + } + if opts.ProgressInterval < 0 { + return errors.New("actionwait: ProgressInterval cannot be negative") + } + return nil +} + +// normalizeOptions sets defaults for optional configuration. +func normalizeOptions[T any](opts *Options[T]) { + if opts.ConsecutiveSuccess <= 0 { + opts.ConsecutiveSuccess = 1 + } + if opts.Interval == nil { + opts.Interval = FixedInterval(DefaultPollInterval) + } +} + +// classifyStatus determines the next action based on the current status. +// Returns: (isTerminal, error) - if isTerminal is true, polling should stop. +func classifyStatus[T any](fr FetchResult[T], opts Options[T], successStreak *int, allowedTransient []Status) (bool, error) { + // Classification precedence: failure -> success -> transitional -> unexpected + if slices.Contains(opts.FailureStates, fr.Status) { + return true, &FailureStateError{Status: fr.Status} + } + + if slices.Contains(opts.SuccessStates, fr.Status) { + *successStreak++ + if *successStreak >= opts.ConsecutiveSuccess { + return true, nil // Success! + } + return false, nil // Continue polling for consecutive successes + } + + // Not a success state, reset streak + *successStreak = 0 + + // Check if transitional state is allowed + // If TransitionalStates is specified, status must be in that list + // If TransitionalStates is empty, any non-success/non-failure state is allowed + if len(opts.TransitionalStates) > 0 && !slices.Contains(opts.TransitionalStates, fr.Status) { + return true, &UnexpectedStateError{Status: fr.Status, Allowed: allowedTransient} + } + + return false, nil // Continue polling +} + +// handleProgressReport sends progress updates if conditions are met. +func handleProgressReport[T any](opts Options[T], fr FetchResult[T], start time.Time, deadline time.Time, attempt uint, lastProgress *time.Time) { + if opts.ProgressSink == nil || opts.ProgressInterval <= 0 { + return + } + + if lastProgress.IsZero() || time.Since(*lastProgress) >= opts.ProgressInterval { + nextPoll := opts.Interval.NextPoll(attempt) + opts.ProgressSink(anyFetchResult(fr), ProgressMeta{ + Attempt: attempt, + Elapsed: time.Since(start), + Remaining: maxDuration(0, time.Until(deadline)), + Deadline: deadline, + NextPollIn: nextPoll, + }) + *lastProgress = time.Now() + } +} + +// sleepWithContext sleeps for the specified duration while respecting context cancellation. +func sleepWithContext(ctx context.Context, duration time.Duration) error { + if duration <= 0 { + return nil + } + + timer := time.NewTimer(duration) + defer timer.Stop() + + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} From cd48e50c93039442f27485c0b80b4aa722906d51 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Tue, 30 Sep 2025 16:56:20 -0400 Subject: [PATCH 12/17] Fix tests --- internal/actionwait/wait_test.go | 37 ++++++++++++++++++++++++++++++-- 1 file changed, 35 insertions(+), 2 deletions(-) diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go index 9f29bcac798b..f85abb74b538 100644 --- a/internal/actionwait/wait_test.go +++ b/internal/actionwait/wait_test.go @@ -6,6 +6,7 @@ package actionwait import ( "context" "errors" + "strings" "sync/atomic" "testing" "time" @@ -25,8 +26,10 @@ func TestWaitForStatus_ValidationErrors(t *testing.T) { t.Parallel() // Subtests parallelized; each uses its own context with timeout. cases := map[string]Options[struct{}]{ - "missing timeout": {SuccessStates: []Status{"ok"}}, - "missing success": {Timeout: time.Second}, + "missing timeout": {SuccessStates: []Status{"ok"}}, + "missing success": {Timeout: time.Second}, + "negative consecutive": {Timeout: time.Second, SuccessStates: []Status{"ok"}, ConsecutiveSuccess: -1}, + "negative progress interval": {Timeout: time.Second, SuccessStates: []Status{"ok"}, ProgressInterval: -time.Second}, } for name, opts := range cases { @@ -289,3 +292,33 @@ func TestWaitForStatus_ProgressSinkDisabled(t *testing.T) { t.Fatalf("expected zero progress sink calls, got %d", progressCalls) } } + +func TestWaitForStatus_UnexpectedStateErrorMessage(t *testing.T) { + t.Parallel() + ctx := makeCtx(t) + _, err := WaitForStatus(ctx, func(context.Context) (FetchResult[int], error) { + return FetchResult[int]{Status: "UNKNOWN"}, nil + }, Options[int]{ + Timeout: 200 * time.Millisecond, + SuccessStates: []Status{"OK"}, + TransitionalStates: []Status{"PENDING", "IN_PROGRESS"}, + Interval: FixedInterval(fastFixedInterval), + }) + if err == nil { + t.Fatal("expected unexpected state error") + } + unexpectedErr, ok := err.(*UnexpectedStateError) + if !ok { //nolint:errorlint // direct type assertion adequate in tests + t.Fatalf("expected UnexpectedStateError, got %T", err) + } + errMsg := unexpectedErr.Error() + if !strings.Contains(errMsg, "UNKNOWN") { + t.Errorf("error message should contain status 'UNKNOWN', got: %s", errMsg) + } + if !strings.Contains(errMsg, "allowed:") { + t.Errorf("error message should list allowed states, got: %s", errMsg) + } + if !strings.Contains(errMsg, "PENDING") { + t.Errorf("error message should contain allowed state 'PENDING', got: %s", errMsg) + } +} From 5cdef321dba9e677e79b18f3612fb81c034f9791 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Tue, 30 Sep 2025 17:14:48 -0400 Subject: [PATCH 13/17] Modern errors --- internal/actionwait/wait_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go index f85abb74b538..b4e73f226ae4 100644 --- a/internal/actionwait/wait_test.go +++ b/internal/actionwait/wait_test.go @@ -307,8 +307,8 @@ func TestWaitForStatus_UnexpectedStateErrorMessage(t *testing.T) { if err == nil { t.Fatal("expected unexpected state error") } - unexpectedErr, ok := err.(*UnexpectedStateError) - if !ok { //nolint:errorlint // direct type assertion adequate in tests + var unexpectedErr *UnexpectedStateError + if !errors.As(err, &unexpectedErr) { t.Fatalf("expected UnexpectedStateError, got %T", err) } errMsg := unexpectedErr.Error() From 63624453f6c7ee6d0e9987fda6f0f2c0fbf7bfa9 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 2 Oct 2025 18:11:33 -0400 Subject: [PATCH 14/17] Split errors from wait --- internal/actionwait/errors.go | 70 +++++++ internal/actionwait/errors_test.go | 283 +++++++++++++++++++++++++++++ internal/actionwait/wait.go | 45 ----- 3 files changed, 353 insertions(+), 45 deletions(-) create mode 100644 internal/actionwait/errors.go create mode 100644 internal/actionwait/errors_test.go diff --git a/internal/actionwait/errors.go b/internal/actionwait/errors.go new file mode 100644 index 000000000000..58e440763b88 --- /dev/null +++ b/internal/actionwait/errors.go @@ -0,0 +1,70 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package actionwait + +import ( + "errors" + "strings" + "time" +) + +// TimeoutError is returned when the operation does not reach a success state within Timeout. +type TimeoutError struct { + LastStatus Status + Timeout time.Duration +} + +func (e *TimeoutError) Error() string { + return "timeout waiting for target status after " + e.Timeout.String() +} + +// FailureStateError indicates the operation entered a declared failure state. +type FailureStateError struct { + Status Status +} + +func (e *FailureStateError) Error() string { + return "operation entered failure state: " + string(e.Status) +} + +// UnexpectedStateError indicates the operation entered a state outside success/transitional/failure sets. +type UnexpectedStateError struct { + Status Status + Allowed []Status +} + +func (e *UnexpectedStateError) Error() string { + if len(e.Allowed) == 0 { + return "operation entered unexpected state: " + string(e.Status) + } + allowedStr := make([]string, len(e.Allowed)) + for i, s := range e.Allowed { + allowedStr[i] = string(s) + } + return "operation entered unexpected state: " + string(e.Status) + " (allowed: " + + strings.Join(allowedStr, ", ") + ")" +} + +// Error type assertions for compile-time verification +var ( + _ error = (*TimeoutError)(nil) + _ error = (*FailureStateError)(nil) + _ error = (*UnexpectedStateError)(nil) +) + +// Helper functions for error type checking +func IsTimeout(err error) bool { + var timeoutErr *TimeoutError + return errors.As(err, &timeoutErr) +} + +func IsFailureState(err error) bool { + var failureErr *FailureStateError + return errors.As(err, &failureErr) +} + +func IsUnexpectedState(err error) bool { + var unexpectedErr *UnexpectedStateError + return errors.As(err, &unexpectedErr) +} diff --git a/internal/actionwait/errors_test.go b/internal/actionwait/errors_test.go new file mode 100644 index 000000000000..c7379df86521 --- /dev/null +++ b/internal/actionwait/errors_test.go @@ -0,0 +1,283 @@ +// Copyright (c) HashiCorp, Inc. +// SPDX-License-Identifier: MPL-2.0 + +package actionwait + +import ( + "errors" + "strings" + "testing" + "time" +) + +func TestTimeoutError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *TimeoutError + wantMsg string + wantType string + }{ + { + name: "with last status", + err: &TimeoutError{ + LastStatus: "CREATING", + Timeout: 5 * time.Minute, + }, + wantMsg: "timeout waiting for target status after 5m0s", + wantType: "*actionwait.TimeoutError", + }, + { + name: "with empty status", + err: &TimeoutError{ + LastStatus: "", + Timeout: 30 * time.Second, + }, + wantMsg: "timeout waiting for target status after 30s", + wantType: "*actionwait.TimeoutError", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.err.Error(); got != tt.wantMsg { + t.Errorf("TimeoutError.Error() = %q, want %q", got, tt.wantMsg) + } + + // Verify it implements error interface + var err error = tt.err + if got := err.Error(); got != tt.wantMsg { + t.Errorf("TimeoutError as error.Error() = %q, want %q", got, tt.wantMsg) + } + }) + } +} + +func TestFailureStateError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *FailureStateError + wantMsg string + }{ + { + name: "with status", + err: &FailureStateError{ + Status: "FAILED", + }, + wantMsg: "operation entered failure state: FAILED", + }, + { + name: "with empty status", + err: &FailureStateError{ + Status: "", + }, + wantMsg: "operation entered failure state: ", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.err.Error(); got != tt.wantMsg { + t.Errorf("FailureStateError.Error() = %q, want %q", got, tt.wantMsg) + } + }) + } +} + +func TestUnexpectedStateError(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + err *UnexpectedStateError + wantMsg string + }{ + { + name: "no allowed states", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: nil, + }, + wantMsg: "operation entered unexpected state: UNKNOWN", + }, + { + name: "empty allowed states", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: []Status{}, + }, + wantMsg: "operation entered unexpected state: UNKNOWN", + }, + { + name: "single allowed state", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: []Status{"AVAILABLE"}, + }, + wantMsg: "operation entered unexpected state: UNKNOWN (allowed: AVAILABLE)", + }, + { + name: "multiple allowed states", + err: &UnexpectedStateError{ + Status: "UNKNOWN", + Allowed: []Status{"CREATING", "AVAILABLE", "UPDATING"}, + }, + wantMsg: "operation entered unexpected state: UNKNOWN (allowed: CREATING, AVAILABLE, UPDATING)", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := tt.err.Error(); got != tt.wantMsg { + t.Errorf("UnexpectedStateError.Error() = %q, want %q", got, tt.wantMsg) + } + }) + } +} + +func TestErrorTypeChecking(t *testing.T) { + t.Parallel() + + // Create instances of each error type + timeoutErr := &TimeoutError{LastStatus: "CREATING", Timeout: time.Minute} + failureErr := &FailureStateError{Status: "FAILED"} + unexpectedErr := &UnexpectedStateError{Status: "UNKNOWN", Allowed: []Status{"AVAILABLE"}} + genericErr := errors.New("generic error") + + tests := []struct { + name string + err error + wantIsTimeout bool + wantIsFailure bool + wantIsUnexpected bool + }{ + { + name: "TimeoutError", + err: timeoutErr, + wantIsTimeout: true, + wantIsFailure: false, + wantIsUnexpected: false, + }, + { + name: "FailureStateError", + err: failureErr, + wantIsTimeout: false, + wantIsFailure: true, + wantIsUnexpected: false, + }, + { + name: "UnexpectedStateError", + err: unexpectedErr, + wantIsTimeout: false, + wantIsFailure: false, + wantIsUnexpected: true, + }, + { + name: "generic error", + err: genericErr, + wantIsTimeout: false, + wantIsFailure: false, + wantIsUnexpected: false, + }, + { + name: "nil error", + err: nil, + wantIsTimeout: false, + wantIsFailure: false, + wantIsUnexpected: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + if got := IsTimeout(tt.err); got != tt.wantIsTimeout { + t.Errorf("IsTimeout(%v) = %v, want %v", tt.err, got, tt.wantIsTimeout) + } + + if got := IsFailureState(tt.err); got != tt.wantIsFailure { + t.Errorf("IsFailureState(%v) = %v, want %v", tt.err, got, tt.wantIsFailure) + } + + if got := IsUnexpectedState(tt.err); got != tt.wantIsUnexpected { + t.Errorf("IsUnexpectedState(%v) = %v, want %v", tt.err, got, tt.wantIsUnexpected) + } + }) + } +} + +func TestWrappedErrors(t *testing.T) { + t.Parallel() + + // Test that error type checking works with wrapped errors + baseErr := &TimeoutError{LastStatus: "CREATING", Timeout: time.Minute} + wrappedErr := errors.New("wrapped: " + baseErr.Error()) + + // Direct error should be detected + if !IsTimeout(baseErr) { + t.Errorf("IsTimeout should detect direct TimeoutError") + } + + // Wrapped string error should NOT be detected (this is expected behavior) + if IsTimeout(wrappedErr) { + t.Errorf("IsTimeout should not detect string-wrapped error") + } + + // But wrapped with errors.Join should work + joinedErr := errors.Join(baseErr, errors.New("additional context")) + if !IsTimeout(joinedErr) { + t.Errorf("IsTimeout should detect error in errors.Join") + } +} + +func TestErrorMessages(t *testing.T) { + t.Parallel() + + // Verify error messages contain expected components for debugging + timeoutErr := &TimeoutError{ + LastStatus: "PENDING", + Timeout: 2 * time.Minute, + } + + msg := timeoutErr.Error() + if !strings.Contains(msg, "timeout") { + t.Errorf("TimeoutError message should contain 'timeout', got: %q", msg) + } + if !strings.Contains(msg, "2m0s") { + t.Errorf("TimeoutError message should contain timeout duration, got: %q", msg) + } + + failureErr := &FailureStateError{Status: "ERROR"} + msg = failureErr.Error() + if !strings.Contains(msg, "failure state") { + t.Errorf("FailureStateError message should contain 'failure state', got: %q", msg) + } + if !strings.Contains(msg, "ERROR") { + t.Errorf("FailureStateError message should contain status, got: %q", msg) + } + + unexpectedErr := &UnexpectedStateError{ + Status: "WEIRD", + Allowed: []Status{"GOOD", "BETTER"}, + } + msg = unexpectedErr.Error() + if !strings.Contains(msg, "unexpected state") { + t.Errorf("UnexpectedStateError message should contain 'unexpected state', got: %q", msg) + } + if !strings.Contains(msg, "WEIRD") { + t.Errorf("UnexpectedStateError message should contain actual status, got: %q", msg) + } + if !strings.Contains(msg, "GOOD, BETTER") { + t.Errorf("UnexpectedStateError message should contain allowed states, got: %q", msg) + } +} diff --git a/internal/actionwait/wait.go b/internal/actionwait/wait.go index 0711b638908c..6da34b8e4ed3 100644 --- a/internal/actionwait/wait.go +++ b/internal/actionwait/wait.go @@ -10,7 +10,6 @@ import ( "context" "errors" "slices" - "strings" "time" ) @@ -63,50 +62,6 @@ type ProgressMeta struct { NextPollIn time.Duration } -// TimeoutError is returned when the operation does not reach a success state within Timeout. -type TimeoutError struct { - LastStatus Status - Timeout time.Duration -} - -func (e *TimeoutError) Error() string { - return "timeout waiting for target status after " + e.Timeout.String() -} - -// FailureStateError indicates the operation entered a declared failure state. -type FailureStateError struct { - Status Status -} - -func (e *FailureStateError) Error() string { - return "operation entered failure state: " + string(e.Status) -} - -// UnexpectedStateError indicates the operation entered a state outside success/transitional/failure sets. -type UnexpectedStateError struct { - Status Status - Allowed []Status -} - -func (e *UnexpectedStateError) Error() string { - if len(e.Allowed) == 0 { - return "operation entered unexpected state: " + string(e.Status) - } - allowedStr := make([]string, len(e.Allowed)) - for i, s := range e.Allowed { - allowedStr[i] = string(s) - } - return "operation entered unexpected state: " + string(e.Status) + " (allowed: " + - strings.Join(allowedStr, ", ") + ")" -} - -// sentinel errors helpers -var ( - _ error = (*TimeoutError)(nil) - _ error = (*FailureStateError)(nil) - _ error = (*UnexpectedStateError)(nil) -) - // WaitForStatus polls using fetch until a success state, failure state, timeout, unexpected state, // context cancellation, or fetch error occurs. // On success, the final FetchResult is returned with nil error. From a70a78a5917183019e3e692126975b0783bd18c3 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 2 Oct 2025 18:18:44 -0400 Subject: [PATCH 15/17] Add backoff strategy --- internal/actionwait/wait.go | 28 +++++++++ internal/actionwait/wait_test.go | 102 +++++++++++++++++++++++++++++++ 2 files changed, 130 insertions(+) diff --git a/internal/actionwait/wait.go b/internal/actionwait/wait.go index 6da34b8e4ed3..1f095f7deef0 100644 --- a/internal/actionwait/wait.go +++ b/internal/actionwait/wait.go @@ -11,6 +11,8 @@ import ( "errors" "slices" "time" + + "github.com/hashicorp/terraform-provider-aws/internal/backoff" ) // DefaultPollInterval is the default fixed polling interval used when no custom IntervalStrategy is provided. @@ -41,6 +43,32 @@ type FixedInterval time.Duration // NextPoll returns the fixed duration. func (fi FixedInterval) NextPoll(uint) time.Duration { return time.Duration(fi) } +// BackoffInterval implements IntervalStrategy using a backoff.Delay strategy. +// This allows actionwait to leverage sophisticated backoff algorithms while +// maintaining the declarative status-based polling approach. +type BackoffInterval struct { + delay backoff.Delay +} + +// NextPoll returns the next polling interval using the wrapped backoff delay strategy. +func (bi BackoffInterval) NextPoll(attempt uint) time.Duration { + return bi.delay.Next(attempt) +} + +// WithBackoffDelay creates an IntervalStrategy that uses the provided backoff.Delay. +// This bridges actionwait's IntervalStrategy interface with the backoff package's +// delay strategies (fixed, exponential, SDK-compatible, etc.). +// +// Example usage: +// +// opts := actionwait.Options[MyType]{ +// Interval: actionwait.WithBackoffDelay(backoff.FixedDelay(time.Second)), +// // ... other options +// } +func WithBackoffDelay(delay backoff.Delay) IntervalStrategy { + return BackoffInterval{delay: delay} +} + // Options configure the WaitForStatus loop. type Options[T any] struct { Timeout time.Duration // Required total timeout. diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go index b4e73f226ae4..02d3532d7036 100644 --- a/internal/actionwait/wait_test.go +++ b/internal/actionwait/wait_test.go @@ -10,6 +10,8 @@ import ( "sync/atomic" "testing" "time" + + "github.com/hashicorp/terraform-provider-aws/internal/backoff" ) // fastFixedInterval returns a very small fixed interval to speed tests. @@ -322,3 +324,103 @@ func TestWaitForStatus_UnexpectedStateErrorMessage(t *testing.T) { t.Errorf("error message should contain allowed state 'PENDING', got: %s", errMsg) } } + +func TestBackoffInterval(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + delay backoff.Delay + attempts []uint + expectedDurations []time.Duration + }{ + { + name: "fixed delay", + delay: backoff.FixedDelay(100 * time.Millisecond), + attempts: []uint{0, 1, 2, 3}, + expectedDurations: []time.Duration{0, 100 * time.Millisecond, 100 * time.Millisecond, 100 * time.Millisecond}, + }, + { + name: "zero delay", + delay: backoff.ZeroDelay, + attempts: []uint{0, 1, 2}, + expectedDurations: []time.Duration{0, 0, 0}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + interval := BackoffInterval{delay: tt.delay} + + for i, attempt := range tt.attempts { + got := interval.NextPoll(attempt) + want := tt.expectedDurations[i] + if got != want { + t.Errorf("NextPoll(%d) = %v, want %v", attempt, got, want) + } + } + }) + } +} + +func TestWithBackoffDelay(t *testing.T) { + t.Parallel() + + delay := backoff.FixedDelay(50 * time.Millisecond) + interval := WithBackoffDelay(delay) + + // Verify it implements IntervalStrategy + var _ IntervalStrategy = interval + + // Test that it wraps the delay correctly + if got := interval.NextPoll(0); got != 0 { + t.Errorf("NextPoll(0) = %v, want 0", got) + } + if got := interval.NextPoll(1); got != 50*time.Millisecond { + t.Errorf("NextPoll(1) = %v, want 50ms", got) + } +} + +func TestBackoffIntegration(t *testing.T) { + t.Parallel() + + ctx := makeCtx(t) + + var callCount atomic.Int32 + fetch := func(context.Context) (FetchResult[string], error) { + count := callCount.Add(1) + switch count { + case 1: + return FetchResult[string]{Status: "CREATING", Value: "attempt1"}, nil + case 2: + return FetchResult[string]{Status: "AVAILABLE", Value: "success"}, nil + default: + t.Errorf("unexpected call count: %d", count) + return FetchResult[string]{}, errors.New("too many calls") + } + } + + opts := Options[string]{ + Timeout: 2 * time.Second, + Interval: WithBackoffDelay(backoff.FixedDelay(fastFixedInterval)), + SuccessStates: []Status{"AVAILABLE"}, + TransitionalStates: []Status{"CREATING"}, + } + + result, err := WaitForStatus(ctx, fetch, opts) + if err != nil { + t.Fatalf("WaitForStatus() error = %v", err) + } + + if result.Status != "AVAILABLE" { + t.Errorf("result.Status = %q, want %q", result.Status, "AVAILABLE") + } + if result.Value != "success" { + t.Errorf("result.Value = %q, want %q", result.Value, "success") + } + if callCount.Load() != 2 { + t.Errorf("expected 2 fetch calls, got %d", callCount.Load()) + } +} From 1cac7735df6058a9c1b47e904687adff8fc5c3d8 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Thu, 2 Oct 2025 18:28:33 -0400 Subject: [PATCH 16/17] Switch codebuild to backoff --- internal/service/cloudfront/create_invalidation_action.go | 2 ++ internal/service/codebuild/start_build_action.go | 7 +++++-- internal/service/ec2/ec2_stop_instance_action.go | 2 ++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/internal/service/cloudfront/create_invalidation_action.go b/internal/service/cloudfront/create_invalidation_action.go index 61e81a20265f..1f8275ce5dd1 100644 --- a/internal/service/cloudfront/create_invalidation_action.go +++ b/internal/service/cloudfront/create_invalidation_action.go @@ -218,6 +218,8 @@ func (a *createInvalidationAction) Invoke(ctx context.Context, req action.Invoke }) // Wait for invalidation to complete with periodic progress updates using actionwait + // Use fixed interval since CloudFront invalidations have predictable timing and + // don't benefit from exponential backoff - status changes are infrequent and consistent _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { input := cloudfront.GetInvalidationInput{ DistributionId: aws.String(distributionID), diff --git a/internal/service/codebuild/start_build_action.go b/internal/service/codebuild/start_build_action.go index 57e85e714ab5..a67229b96030 100644 --- a/internal/service/codebuild/start_build_action.go +++ b/internal/service/codebuild/start_build_action.go @@ -17,6 +17,7 @@ import ( "github.com/hashicorp/terraform-plugin-framework/types" "github.com/hashicorp/terraform-plugin-log/tflog" "github.com/hashicorp/terraform-provider-aws/internal/actionwait" + "github.com/hashicorp/terraform-provider-aws/internal/backoff" "github.com/hashicorp/terraform-provider-aws/internal/framework" fwflex "github.com/hashicorp/terraform-provider-aws/internal/framework/flex" fwtypes "github.com/hashicorp/terraform-provider-aws/internal/framework/types" @@ -134,7 +135,9 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, Message: "Build started, waiting for completion...", }) - // Poll for build completion using actionwait + // Poll for build completion using actionwait with backoff strategy + // Use backoff since builds can take a long time and status changes less frequently + // as the build progresses - start with frequent polling then back off _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[*awstypes.Build], error) { input := codebuild.BatchGetBuildsInput{Ids: []string{buildID}} batch, berr := conn.BatchGetBuilds(ctx, &input) @@ -148,7 +151,7 @@ func (a *startBuildAction) Invoke(ctx context.Context, req action.InvokeRequest, return actionwait.FetchResult[*awstypes.Build]{Status: actionwait.Status(b.BuildStatus), Value: &b}, nil }, actionwait.Options[*awstypes.Build]{ Timeout: timeout, - Interval: actionwait.FixedInterval(actionwait.DefaultPollInterval), + Interval: actionwait.WithBackoffDelay(backoff.DefaultSDKv2HelperRetryCompatibleDelay()), ProgressInterval: 2 * time.Minute, SuccessStates: []actionwait.Status{actionwait.Status(awstypes.StatusTypeSucceeded)}, TransitionalStates: []actionwait.Status{ diff --git a/internal/service/ec2/ec2_stop_instance_action.go b/internal/service/ec2/ec2_stop_instance_action.go index d4b348c90a96..ae234daac1ac 100644 --- a/internal/service/ec2/ec2_stop_instance_action.go +++ b/internal/service/ec2/ec2_stop_instance_action.go @@ -185,6 +185,8 @@ func (a *stopInstanceAction) Invoke(ctx context.Context, req action.InvokeReques } // Wait for instance to stop with periodic progress updates using actionwait + // Use fixed interval since EC2 instance state transitions are predictable and + // relatively quick - consistent polling every 10s is optimal for this operation _, err = actionwait.WaitForStatus(ctx, func(ctx context.Context) (actionwait.FetchResult[struct{}], error) { instance, derr := findInstanceByID(ctx, conn, instanceID) if derr != nil { From a9c99d888831ce041ee0a8d3b0c57f563bc69426 Mon Sep 17 00:00:00 2001 From: Dirk Avery Date: Fri, 3 Oct 2025 10:37:53 -0400 Subject: [PATCH 17/17] Remove explicit implementation check --- internal/actionwait/wait_test.go | 3 --- 1 file changed, 3 deletions(-) diff --git a/internal/actionwait/wait_test.go b/internal/actionwait/wait_test.go index 02d3532d7036..22096adb77f3 100644 --- a/internal/actionwait/wait_test.go +++ b/internal/actionwait/wait_test.go @@ -371,9 +371,6 @@ func TestWithBackoffDelay(t *testing.T) { delay := backoff.FixedDelay(50 * time.Millisecond) interval := WithBackoffDelay(delay) - // Verify it implements IntervalStrategy - var _ IntervalStrategy = interval - // Test that it wraps the delay correctly if got := interval.NextPoll(0); got != 0 { t.Errorf("NextPoll(0) = %v, want 0", got)