diff --git a/pkg/epp/flowcontrol/controller/controller.go b/pkg/epp/flowcontrol/controller/controller.go index aeb0bdd87..c1f9a3004 100644 --- a/pkg/epp/flowcontrol/controller/controller.go +++ b/pkg/epp/flowcontrol/controller/controller.go @@ -32,6 +32,7 @@ import ( "time" "github.com/go-logr/logr" + k8srand "k8s.io/apimachinery/pkg/util/rand" "k8s.io/utils/clock" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" @@ -57,10 +58,11 @@ type shardProcessor interface { // shardProcessorFactory defines the signature for a function that creates a `shardProcessor`. // This enables dependency injection for testing. type shardProcessorFactory func( + ctx context.Context, shard contracts.RegistryShard, saturationDetector contracts.SaturationDetector, - clock clock.Clock, - expiryCleanupInterval time.Duration, + clock clock.WithTicker, + cleanupSweepInterval time.Duration, enqueueChannelBufferSize int, logger logr.Logger, ) shardProcessor @@ -79,6 +81,14 @@ type managedWorker struct { // // The controller's `Run` loop executes periodically, acting as a garbage collector that keeps the pool of running // workers synchronized with the dynamic shard topology of the `FlowRegistry`. +// +// Request Lifecycle Management: +// +// 1. Asynchronous Finalization (Controller-Owned): The Controller actively monitors the request Context +// (TTL/Cancellation) in EnqueueAndWait. If the Context expires, the Controller immediately Finalizes the item and +// unblocks the caller. +// 2. Synchronous Finalization (Processor-Owned): The Processor handles Dispatch, Capacity Rejection, and Shutdown. +// 3. Cleanup (Processor-Owned): The Processor periodically sweeps externally finalized items to reclaim capacity. type FlowController struct { // --- Immutable dependencies (set at construction) --- @@ -129,18 +139,20 @@ func NewFlowController( // Use the real shard processor implementation by default. fc.shardProcessorFactory = func( + ctx context.Context, shard contracts.RegistryShard, saturationDetector contracts.SaturationDetector, - clock clock.Clock, - expiryCleanupInterval time.Duration, + clock clock.WithTicker, + cleanupSweepInterval time.Duration, enqueueChannelBufferSize int, logger logr.Logger, ) shardProcessor { return internal.NewShardProcessor( + ctx, shard, saturationDetector, clock, - expiryCleanupInterval, + cleanupSweepInterval, enqueueChannelBufferSize, logger) } @@ -189,63 +201,162 @@ func (fc *FlowController) run(ctx context.Context) { // stack and its `context.Context`. The system only needs to signal this specific goroutine to unblock it. // - Direct Backpressure: If queues are full, `EnqueueAndWait` returns an error immediately, providing direct // backpressure to the caller. -func (fc *FlowController) EnqueueAndWait(req types.FlowControlRequest) (types.QueueOutcome, error) { +func (fc *FlowController) EnqueueAndWait( + ctx context.Context, + req types.FlowControlRequest, +) (types.QueueOutcome, error) { if req == nil { return types.QueueOutcomeRejectedOther, errors.New("request cannot be nil") } - effectiveTTL := req.InitialEffectiveTTL() - if effectiveTTL <= 0 { - effectiveTTL = fc.config.DefaultRequestTTL - } - enqueueTime := fc.clock.Now() + // 1. Create the derived context that governs this request's lifecycle (Parent Cancellation + TTL). + reqCtx, cancel, enqueueTime := fc.createRequestContext(ctx, req) + defer cancel() + + // 2. Enter the distribution loop to find a home for the request. + // This loop is responsible for retrying on ErrShardDraining. 
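
(Editorial aside, illustrating the lifecycle split documented above rather than adding to the patch: the TTL becomes a deadline on the derived request context, and the caller then simply blocks on whichever fires first, that context or the item's finalization channel. A minimal, self-contained sketch follows; errTTLExpired and the done channel are stand-ins for types.ErrTTLExpired and item.Done(), not code from this change.)

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

var errTTLExpired = errors.New("request TTL expired") // stand-in for types.ErrTTLExpired

func main() {
	// Stand-in for item.Done(): the processor would signal here when it finalizes the item.
	done := make(chan struct{})

	// The TTL is enforced as a context deadline whose *cause* is the domain error.
	ctx, cancel := context.WithDeadlineCause(context.Background(), time.Now().Add(50*time.Millisecond), errTTLExpired)
	defer cancel()

	select {
	case <-ctx.Done():
		// ctx.Err() only reports DeadlineExceeded; context.Cause recovers the domain-specific reason.
		fmt.Println(ctx.Err())
		fmt.Println(errors.Is(context.Cause(ctx), errTTLExpired)) // true
	case <-done:
		fmt.Println("finalized by the processor first")
	}
}

(The important property is that the deadline's cause survives, which is what lets the controller surface a TTL-specific error to the caller rather than a bare DeadlineExceeded.)
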
for { - select { + + select { // Non-blocking check on controller lifecycle. case <-fc.parentCtx.Done(): return types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerNotRunning) default: - // The controller is running, proceed. } - // We must create a fresh `FlowItem` on each attempt since finalization is idempotent. - // However, we use the original, preserved `enqueueTime`. - item := internal.NewItem(req, effectiveTTL, enqueueTime) - if outcome, err := fc.distributeRequest(item); err != nil { - return outcome, fmt.Errorf("%w: %w", types.ErrRejected, err) + // Attempt to distribute the request once. + item, err := fc.tryDistribution(reqCtx, req, enqueueTime) + if err != nil { + // Distribution failed terminally (e.g., no shards, context cancelled during blocking submit). + // The item has already been finalized by tryDistribution. + finalState := item.FinalState() + return finalState.Outcome, finalState.Err } - // Block until the request is finalized (dispatched, rejected, or evicted). - // The finalization logic internally monitors for context cancellation and TTL expiry. - finalState := <-item.Done() - if errors.Is(finalState.Err, contracts.ErrShardDraining) { - fc.logger.V(logutil.DEBUG).Info("Shard is draining, retrying request", "requestID", req.ID()) - // Benign race with the chosen `contracts.RegistryShard` becoming Draining post selection but before the item was - // enqueued into its respective `contracts.ManagedQueue`. Simply try again. + // Distribution was successful; ownership of the item has been transferred to a processor. + // Now, we block here in awaitFinalization until the request is finalized by either the processor (e.g., dispatched, + // rejected) or the controller itself (e.g., caller's context cancelled/TTL expired). + outcome, err := fc.awaitFinalization(reqCtx, item) + if errors.Is(err, contracts.ErrShardDraining) { + // This is a benign race condition where the chosen shard started draining after acceptance. + fc.logger.V(logutil.DEBUG).Info("Selected shard is Draining, retrying request distribution", + "flowKey", req.FlowKey(), "requestID", req.ID()) + // Introduce a small, randomized delay (1-10ms) to prevent tight spinning loops and thundering herds during retry + // scenarios (e.g., shard draining) + // TODO: Replace this with a more sophisticated backoff strategy when our data parallelism story matures. + // For now, this is more than sufficient. + jitterMs := k8srand.Intn(10) + 1 + fc.clock.Sleep(time.Duration(jitterMs) * time.Millisecond) continue } + // The outcome is terminal (Dispatched, Evicted, or a non-retriable rejection). + return outcome, err + } +} + +var errNoShards = errors.New("no viable active shards available") + +// tryDistribution handles a single attempt to select a shard and submit a request. +// If this function returns an error, it guarantees that the provided `item` has been finalized. +func (fc *FlowController) tryDistribution( + reqCtx context.Context, + req types.FlowControlRequest, + enqueueTime time.Time, +) (*internal.FlowItem, error) { + // Calculate effective TTL for item initialization (reqCtx is the enforcement mechanism). + effectiveTTL := fc.config.DefaultRequestTTL + if deadline, ok := reqCtx.Deadline(); ok { + if ttl := deadline.Sub(enqueueTime); ttl > 0 { + effectiveTTL = ttl + } + } + + // We must create a fresh FlowItem on each attempt as finalization is per-lifecycle. 
+	item := internal.NewItem(req, effectiveTTL, enqueueTime)
+
+	candidates, err := fc.selectDistributionCandidates(item.OriginalRequest().FlowKey())
+	if err != nil {
+		outcome := types.QueueOutcomeRejectedOther
+		if errors.Is(err, errNoShards) {
+			outcome = types.QueueOutcomeRejectedCapacity
+		}
+		finalErr := fmt.Errorf("%w: request not accepted: %w", types.ErrRejected, err)
+		item.FinalizeWithOutcome(outcome, finalErr)
+		return item, finalErr
+	}
+
+	outcome, err := fc.distributeRequest(reqCtx, item, candidates)
+	if err == nil {
+		// Success: Ownership of the item has been transferred to the processor.
+		return item, nil
+	}
+
+	// For any distribution error, the controller retains ownership and must finalize the item.
+	var finalErr error
+	if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
+		// We propagate the original context error here; EnqueueAndWait will rely on item.FinalState().Err.
+		finalErr = err
+		item.Finalize(context.Cause(reqCtx))
+	} else { // e.g., the processor shut down during the blocking handoff.
+		finalErr = fmt.Errorf("%w: request not accepted: %w", types.ErrRejected, err)
+		item.FinalizeWithOutcome(outcome, finalErr)
+	}
+	return item, finalErr
+}
+
+// awaitFinalization blocks until an item is finalized, either by the processor (synchronously) or by the controller
+// itself due to context expiry (asynchronously).
+func (fc *FlowController) awaitFinalization(
+	reqCtx context.Context,
+	item *internal.FlowItem,
+) (types.QueueOutcome, error) {
+	select {
+	case <-reqCtx.Done():
+		// Asynchronous Finalization (Controller-initiated):
+		// The request Context expired (Cancellation/TTL) while the item was being processed.
+		cause := context.Cause(reqCtx)
+		item.Finalize(cause)
+
+		// The processor will eventually discard this "zombie" item during its cleanup sweep.
+		finalState := item.FinalState()
+		return finalState.Outcome, finalState.Err
+
+	case finalState := <-item.Done():
+		// Synchronous Finalization (Processor-initiated):
+		// The processor finalized the item (Dispatch, Reject, Shutdown).
 		return finalState.Outcome, finalState.Err
 	}
 }
 
-// distributeRequest implements a flow-aware, two-phase "Join-Shortest-Queue-by-Bytes" (JSQ-Bytes) distribution strategy
-// with graceful backpressure. It selects the optimal worker for a given item and attempts to submit it.
-//
-// The algorithm operates as follows:
-//  1. Candidate Selection: It identifies all Active shards for the item's flow and ranks them by the current byte size
-//     of that flow's queue, from least to most loaded.
-//  2. Phase 1 (Non-blocking Fast Failover): It iterates through the ranked candidates and attempts a non-blocking
-//     submission. The first successful submission wins.
-//  3. Phase 2 (Blocking Fallback): If all non-blocking attempts fail, it performs a single blocking submission to the
-//     least-loaded candidate, providing backpressure.
-func (fc *FlowController) distributeRequest(item *internal.FlowItem) (types.QueueOutcome, error) {
-	key := item.OriginalRequest().FlowKey()
-	reqID := item.OriginalRequest().ID()
-	type candidate struct {
-		processor shardProcessor
-		shardID   string
-		byteSize  uint64
+// createRequestContext derives the context that governs a request's lifecycle, enforcing the TTL deadline.
+func (fc *FlowController) createRequestContext( + ctx context.Context, + req types.FlowControlRequest, +) (context.Context, context.CancelFunc, time.Time) { + enqueueTime := fc.clock.Now() + effectiveTTL := req.InitialEffectiveTTL() + if effectiveTTL <= 0 { + effectiveTTL = fc.config.DefaultRequestTTL } + + if effectiveTTL > 0 { + reqCtx, cancel := context.WithDeadlineCause(ctx, enqueueTime.Add(effectiveTTL), types.ErrTTLExpired) + return reqCtx, cancel, enqueueTime + } + reqCtx, cancel := context.WithCancel(ctx) + return reqCtx, cancel, enqueueTime +} + +// candidate holds the information needed to evaluate a shard as a potential target for a request. +type candidate struct { + processor shardProcessor + shardID string + byteSize uint64 +} + +// selectDistributionCandidates identifies all Active shards for the item's flow and ranks them by the current byte size +// of that flow's queue, from least to most loaded. +func (fc *FlowController) selectDistributionCandidates(key types.FlowKey) ([]candidate, error) { var candidates []candidate err := fc.registry.WithConnection(key, func(conn contracts.ActiveFlowConnection) error { shards := conn.ActiveShards() @@ -262,41 +373,58 @@ func (fc *FlowController) distributeRequest(item *internal.FlowItem) (types.Queu return nil }) if err != nil { - return types.QueueOutcomeRejectedOther, fmt.Errorf("failed to acquire lease for request %q (flow %s): %w", - reqID, key, err) + return nil, fmt.Errorf("failed to acquire lease for flow %s: %w", key, err) } if len(candidates) == 0 { - return types.QueueOutcomeRejectedCapacity, fmt.Errorf("no viable Active shards available for request %q (flow %s)", - reqID, key) + return nil, fmt.Errorf("%w for flow %s", errNoShards, key) } slices.SortFunc(candidates, func(a, b candidate) int { return cmp.Compare(a.byteSize, b.byteSize) }) - // --- Phase 1: Fast, non-blocking failover attempt --- + return candidates, nil +} + +// distributeRequest implements a flow-aware, two-phase "Join-Shortest-Queue-by-Bytes" (JSQ-Bytes) distribution strategy +// with graceful backpressure. It attempts to submit an item to the best-ranked candidate from the provided list. +// +// The algorithm operates as follows: +// 1. Phase 1 (Non-blocking Fast Failover): It iterates through the ranked candidates and attempts a non-blocking +// submission. The first successful submission wins. +// 2. Phase 2 (Blocking Fallback): If all non-blocking attempts fail, it performs a single blocking submission to the +// least-loaded candidate, providing backpressure. +// +// The provided context (ctx) is used for the blocking submission phase (SubmitOrBlock). +// +// Ownership Contract: +// - Returns nil: Success. Ownership transferred to Processor. +// - Returns error: Failure (Context expiry, shutdown,, etc.). +// Ownership retained by Controller. The Controller MUST finalize the item. +func (fc *FlowController) distributeRequest( + ctx context.Context, + item *internal.FlowItem, + candidates []candidate, +) (types.QueueOutcome, error) { + reqID := item.OriginalRequest().ID() for _, c := range candidates { if err := c.processor.Submit(item); err == nil { - return types.QueueOutcomeNotYetFinalized, nil // Success + return types.QueueOutcomeNotYetFinalized, nil } - fc.logger.V(logutil.DEBUG).Info("Processor busy during fast failover, trying next candidate", + fc.logger.V(logutil.TRACE).Info("Processor busy during fast failover, trying next candidate", "shardID", c.shardID, "requestID", reqID) } - // --- Phase 2: All processors busy. 
Attempt a single blocking send to the best candidate. --- + // All processors are busy. Attempt a single blocking submission to the least-loaded candidate. bestCandidate := candidates[0] - fc.logger.V(logutil.DEBUG).Info("All processors busy, attempting blocking submit to best candidate", - "shardID", bestCandidate.shardID, "requestID", reqID, "queueByteSize", bestCandidate.byteSize) - - err = bestCandidate.processor.SubmitOrBlock(item.OriginalRequest().Context(), item) + fc.logger.V(logutil.TRACE).Info("All processors busy, attempting blocking submit to best candidate", + "shardID", bestCandidate.shardID, "requestID", reqID) + err := bestCandidate.processor.SubmitOrBlock(ctx, item) if err != nil { - // If even the blocking attempt fails (e.g., context cancelled or processor shut down), the request is definitively - // rejected. - return types.QueueOutcomeRejectedCapacity, fmt.Errorf( - "all viable shard processors are at capacity for request %q (flow %s): %w", reqID, key, err) + return types.QueueOutcomeRejectedOther, fmt.Errorf("%w: request not accepted: %w", types.ErrRejected, err) } - return types.QueueOutcomeNotYetFinalized, nil + return types.QueueOutcomeNotYetFinalized, nil // Success, ownership transferred. } // getOrStartWorker implements the lazy-loading and startup of shard processors. @@ -311,6 +439,7 @@ func (fc *FlowController) getOrStartWorker(shard contracts.RegistryShard) *manag // Construct a new worker, but do not start its processor goroutine yet. processorCtx, cancel := context.WithCancel(fc.parentCtx) processor := fc.shardProcessorFactory( + processorCtx, shard, fc.saturationDetector, fc.clock, diff --git a/pkg/epp/flowcontrol/controller/controller_test.go b/pkg/epp/flowcontrol/controller/controller_test.go index c832e2d6d..74802f2df 100644 --- a/pkg/epp/flowcontrol/controller/controller_test.go +++ b/pkg/epp/flowcontrol/controller/controller_test.go @@ -14,6 +14,12 @@ See the License for the specific language governing permissions and limitations under the License. */ +// Note on Time-Based Lifecycle Tests: +// Tests validating the controller's handling of request TTLs (e.g., OnReqCtxTimeout*) rely on real-time timers +// (context.WithDeadline). The injected testclock.FakeClock is used to control the timing of internal loops (like +// reconciliation), but it cannot manipulate the timers used by the standard context package. Therefore, these specific +// tests use time.Sleep or assertions on real-time durations. + package controller import ( @@ -68,23 +74,40 @@ func withShardProcessorFactory(factory shardProcessorFactory) flowControllerOpti // testHarness holds the `FlowController` and its dependencies under test. type testHarness struct { - fc *FlowController - cfg Config - mockRegistry *mockRegistryClient - mockDetector *mocks.MockSaturationDetector + fc *FlowController + cfg Config + // clock is the clock interface used by the controller. + clock clock.WithTicker + mockRegistry *mockRegistryClient + mockDetector *mocks.MockSaturationDetector + // mockClock provides access to FakeClock methods (Step, HasWaiters) if and only if the underlying clock is a + // FakeClock. mockClock *testclock.FakeClock mockProcessorFactory *mockShardProcessorFactory } // newUnitHarness creates a test environment with a mock processor factory, suitable for focused unit tests of the -// controller's logic. It starts the controller's run loop and returns a cancel function to stop it. +// controller's logic. 
It starts the controller's run loop using the provided context for lifecycle management.
 func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *mockRegistryClient) *testHarness {
 	t.Helper()
 	mockDetector := &mocks.MockSaturationDetector{}
+
+	// Initialize the FakeClock with the current system time.
+	// The controller implementation uses the injected clock to calculate the deadline timestamp, but uses the standard
+	// context.WithDeadline (which relies on the system clock) to enforce it.
+	// If the FakeClock's time is far from the system time, deadlines calculated based on the FakeClock might already be
+	// expired according to the system clock, causing immediate TTL failures.
 	mockClock := testclock.NewFakeClock(time.Now())
+
 	mockProcessorFactory := &mockShardProcessorFactory{
 		processors: make(map[string]*mockShardProcessor),
 	}
+
+	// Default the registry if nil, simplifying tests that don't focus on registry interaction.
+	if registry == nil {
+		registry = &mockRegistryClient{}
+	}
+
 	opts := []flowControllerOption{
 		withRegistryClient(registry),
 		withClock(mockClock),
@@ -96,6 +119,7 @@ func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *moc
 	h := &testHarness{
 		fc:                   fc,
 		cfg:                  cfg,
+		clock:                mockClock,
 		mockRegistry:         registry,
 		mockDetector:         mockDetector,
 		mockClock:            mockClock,
@@ -109,7 +133,13 @@ func newUnitHarness(t *testing.T, ctx context.Context, cfg Config, registry *moc
 func newIntegrationHarness(t *testing.T, ctx context.Context, cfg Config, registry *mockRegistryClient) *testHarness {
 	t.Helper()
 	mockDetector := &mocks.MockSaturationDetector{}
+	// Align FakeClock with system time. See explanation in newUnitHarness.
+	mockClock := testclock.NewFakeClock(time.Now())
+	if registry == nil {
+		registry = &mockRegistryClient{}
+	}
+
 	opts := []flowControllerOption{
 		withRegistryClient(registry),
 		withClock(mockClock),
@@ -120,6 +150,7 @@ func newIntegrationHarness(t *testing.T, ctx context.Context, cfg Config, regist
 	h := &testHarness{
 		fc:                   fc,
 		cfg:                  cfg,
+		clock:                mockClock,
 		mockRegistry:         registry,
 		mockDetector:         mockDetector,
 		mockClock:            mockClock,
@@ -166,9 +197,11 @@ func (m *mockRegistryClient) ShardStats() []contracts.ShardStats {
 type mockShardProcessor struct {
 	SubmitFunc        func(item *internal.FlowItem) error
 	SubmitOrBlockFunc func(ctx context.Context, item *internal.FlowItem) error
-	runCtx            context.Context
-	runCtxMu          sync.RWMutex
-	runStarted        chan struct{}
+	// runCtx captures the context provided to the Run method for lifecycle assertions.
+	runCtx   context.Context
+	runCtxMu sync.RWMutex
+	// runStarted is closed when the Run method is called, allowing tests to synchronize with worker startup.
+	runStarted chan struct{}
 }
 
 func (m *mockShardProcessor) Submit(item *internal.FlowItem) error {
@@ -192,9 +225,11 @@ func (m *mockShardProcessor) Run(ctx context.Context) {
 	if m.runStarted != nil {
 		close(m.runStarted)
 	}
+	// Block until the context is cancelled, simulating a running worker.
 	<-ctx.Done()
 }
 
+// Context returns the context captured during the Run method call.
 func (m *mockShardProcessor) Context() context.Context {
 	m.runCtxMu.RLock()
 	defer m.runCtxMu.RUnlock()
@@ -207,10 +242,12 @@ type mockShardProcessorFactory struct {
 	processors map[string]*mockShardProcessor
 }
 
+// new is the factory function conforming to the `shardProcessorFactory` signature.
 func (f *mockShardProcessorFactory) new(
+	_ context.Context, // The factory does not use the lifecycle context; it's passed to the processor's Run method later.
shard contracts.RegistryShard, _ contracts.SaturationDetector, - _ clock.Clock, + _ clock.WithTicker, _ time.Duration, _ int, _ logr.Logger, @@ -220,7 +257,7 @@ func (f *mockShardProcessorFactory) new( if proc, ok := f.processors[shard.ID()]; ok { return proc } - // Return a default mock processor if one is not registered. + // Return a default mock processor if one is not explicitly registered by the test. return &mockShardProcessor{} } @@ -262,9 +299,8 @@ func (b *mockShardBuilder) build() contracts.RegistryShard { var defaultFlowKey = types.FlowKey{ID: "test-flow", Priority: 100} -func newTestRequest(ctx context.Context, key types.FlowKey) *typesmocks.MockFlowControlRequest { +func newTestRequest(key types.FlowKey) *typesmocks.MockFlowControlRequest { return &typesmocks.MockFlowControlRequest{ - Ctx: ctx, FlowKeyV: key, ByteSizeV: 100, IDV: "req-" + key.ID, @@ -273,6 +309,8 @@ func newTestRequest(ctx context.Context, key types.FlowKey) *typesmocks.MockFlow // --- Test Cases --- +// TestFlowController_EnqueueAndWait covers the primary API entry point, focusing on validation, distribution logic, +// retries, and the request lifecycle (including post-distribution cancellation/timeout). func TestFlowController_EnqueueAndWait(t *testing.T) { t.Parallel() @@ -281,37 +319,80 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { t.Run("OnNilRequest", func(t *testing.T) { t.Parallel() - h := newUnitHarness(t, t.Context(), Config{}, &mockRegistryClient{}) + h := newUnitHarness(t, t.Context(), Config{}, nil) - outcome, err := h.fc.EnqueueAndWait(nil) + outcome, err := h.fc.EnqueueAndWait(context.Background(), nil) require.Error(t, err, "EnqueueAndWait must reject a nil request") - assert.Equal(t, "request cannot be nil", err.Error()) - assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "outcome should be QueueOutcomeRejectedOther") + assert.Equal(t, "request cannot be nil", err.Error(), "error message must be specific") + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, + "outcome should be QueueOutcomeRejectedOther for invalid inputs") }) + t.Run("OnReqCtxExpiredBeforeDistribution", func(t *testing.T) { + t.Parallel() + // Test that if the request context provided to EnqueueAndWait is already expired, it returns immediately. + h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 1 * time.Minute}, nil) + + // Configure registry to return a shard. + shardA := newMockShard("shard-A").build() + h.mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(_ contracts.ActiveFlowConnection) error) error { + return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA}}) + } + // Configure processor to block until context expiry. + h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ + SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, + SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error { + <-ctx.Done() // Wait for the context to be done. + return context.Cause(ctx) // Return the cause. + }, + } + + req := newTestRequest(defaultFlowKey) + // Use a context with a deadline in the past. 
+ reqCtx, cancel := context.WithDeadlineCause( + context.Background(), + h.clock.Now().Add(-1*time.Second), + types.ErrTTLExpired) + defer cancel() + + outcome, err := h.fc.EnqueueAndWait(reqCtx, req) + require.Error(t, err, "EnqueueAndWait must fail if request context deadline is exceeded") + assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected") + assert.ErrorIs(t, err, types.ErrTTLExpired, "error should wrap types.ErrTTLExpired from the context cause") + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "outcome should be QueueOutcomeRejectedOther") + }) t.Run("OnControllerShutdown", func(t *testing.T) { t.Parallel() + // Create a context specifically for the controller's lifecycle. ctx, cancel := context.WithCancel(t.Context()) - h := newUnitHarness(t, ctx, Config{}, &mockRegistryClient{}) + h := newUnitHarness(t, ctx, Config{}, nil) cancel() // Immediately stop the controller. - req := newTestRequest(context.Background(), defaultFlowKey) - outcome, err := h.fc.EnqueueAndWait(req) + // Wait for the controller's run loop and all workers (none in this case) to exit. + // We need to wait because the shutdown process is asynchronous. + h.fc.wg.Wait() + + req := newTestRequest(defaultFlowKey) + // The request context is valid, but the controller itself is stopped. + outcome, err := h.fc.EnqueueAndWait(context.Background(), req) require.Error(t, err, "EnqueueAndWait must reject requests if controller is not running") assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected") assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, "error should wrap ErrFlowControllerNotRunning") - assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "outcome should be QueueOutcomeRejectedOther") + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, + "outcome should be QueueOutcomeRejectedOther on shutdown") }) t.Run("OnNoShardsAvailable", func(t *testing.T) { t.Parallel() - h := newUnitHarness(t, t.Context(), Config{}, &mockRegistryClient{}) + // The default mockRegistryClient returns an empty list of ActiveShards. + h := newUnitHarness(t, t.Context(), Config{}, nil) - req := newTestRequest(context.Background(), defaultFlowKey) - outcome, err := h.fc.EnqueueAndWait(req) + req := newTestRequest(defaultFlowKey) + outcome, err := h.fc.EnqueueAndWait(context.Background(), req) require.Error(t, err, "EnqueueAndWait must reject requests if no shards are available") assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected") - assert.Equal(t, types.QueueOutcomeRejectedCapacity, outcome, "outcome should be QueueOutcomeRejectedCapacity") + assert.Equal(t, types.QueueOutcomeRejectedCapacity, outcome, + "outcome should be QueueOutcomeRejectedCapacity when no shards exist for the flow") }) t.Run("OnRegistryConnectionError", func(t *testing.T) { @@ -319,7 +400,8 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { mockRegistry := &mockRegistryClient{} h := newUnitHarness(t, t.Context(), Config{}, mockRegistry) - expectedErr := errors.New("connection failed") + expectedErr := errors.New("simulated connection failure") + // Configure the registry to fail when attempting to retrieve ActiveFlowConnection. 
mockRegistry.WithConnectionFunc = func( _ types.FlowKey, _ func(conn contracts.ActiveFlowConnection) error, @@ -327,23 +409,27 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { return expectedErr } - req := newTestRequest(context.Background(), defaultFlowKey) - outcome, err := h.fc.EnqueueAndWait(req) + req := newTestRequest(defaultFlowKey) + outcome, err := h.fc.EnqueueAndWait(context.Background(), req) require.Error(t, err, "EnqueueAndWait must reject requests if registry connection fails") assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected") assert.ErrorIs(t, err, expectedErr, "error should wrap the underlying connection error") - assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "outcome should be QueueOutcomeRejectedOther") + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, + "outcome should be QueueOutcomeRejectedOther for transient registry errors") }) + // This test validates the documented invariant handling in distributeRequest. t.Run("PanicsOnManagedQueueError", func(t *testing.T) { t.Parallel() mockRegistry := &mockRegistryClient{} h := newUnitHarness(t, t.Context(), Config{}, mockRegistry) + // Create a faulty shard that successfully leases the flow but fails to return the + // ManagedQueue. faultyShard := &mocks.MockRegistryShard{ IDFunc: func() string { return "faulty-shard" }, ManagedQueueFunc: func(_ types.FlowKey) (contracts.ManagedQueue, error) { - return nil, errors.New("queue retrieval failed") + return nil, errors.New("invariant violation: queue retrieval failed") }, } mockRegistry.WithConnectionFunc = func( @@ -353,22 +439,30 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{faultyShard}}) } - req := newTestRequest(context.Background(), defaultFlowKey) + req := newTestRequest(defaultFlowKey) assert.Panics(t, func() { - _, _ = h.fc.EnqueueAndWait(req) - }, "EnqueueAndWait did not panic as expected on a ManagedQueue error") + _, _ = h.fc.EnqueueAndWait(context.Background(), req) + }, "EnqueueAndWait must panic when a registry implementation violates the ManagedQueue contract") }) }) + // Distribution tests validate the JSQ-Bytes algorithm, the two-phase submission strategy, and error handling during + // the handoff, including time-based failures during blocking fallback. t.Run("Distribution", func(t *testing.T) { t.Parallel() + // Define a long default TTL to prevent unexpected timeouts unless a test case explicitly sets a shorter one. + const defaultTestTTL = 5 * time.Second + testCases := []struct { name string shards []contracts.RegistryShard setupProcessors func(t *testing.T, h *testHarness) + // requestTTL overrides the default TTL for time-sensitive tests. + requestTTL time.Duration expectedOutcome types.QueueOutcome expectErr bool + expectErrIs error }{ { name: "SubmitSucceeds_NonBlocking_WithSingleActiveShard", @@ -376,7 +470,8 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { setupProcessors: func(t *testing.T, h *testHarness) { h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ SubmitFunc: func(item *internal.FlowItem) error { - go item.Finalize(types.QueueOutcomeDispatched, nil) + // Simulate asynchronous processing and successful dispatch. 
+ go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) return nil }, } @@ -386,19 +481,20 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { { name: "DistributesToLeastLoadedShard_WithMultipleActiveShards", shards: []contracts.RegistryShard{ - newMockShard("shard-A").withByteSize(1000).build(), - newMockShard("shard-B").withByteSize(100).build(), + newMockShard("shard-A").withByteSize(1000).build(), // More loaded + newMockShard("shard-B").withByteSize(100).build(), // Least loaded }, setupProcessors: func(t *testing.T, h *testHarness) { h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ SubmitFunc: func(_ *internal.FlowItem) error { - t.Error("Submit was called on the more loaded shard (shard-A), which is incorrect") + t.Error("Submit was called on the more loaded shard (shard-A); JSQ-Bytes algorithm failed") return internal.ErrProcessorBusy }, } h.mockProcessorFactory.processors["shard-B"] = &mockShardProcessor{ SubmitFunc: func(item *internal.FlowItem) error { - go item.Finalize(types.QueueOutcomeDispatched, nil) + item.SetHandle(&typesmocks.MockQueueItemHandle{}) + go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) return nil }, } @@ -412,13 +508,16 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { newMockShard("shard-B").withByteSize(100).build(), }, setupProcessors: func(t *testing.T, h *testHarness) { + // Both processors reject the initial non-blocking Submit. h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, } + // Shard-B is the least loaded, so it should receive the blocking fallback (SubmitOrBlock). h.mockProcessorFactory.processors["shard-B"] = &mockShardProcessor{ SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, SubmitOrBlockFunc: func(_ context.Context, item *internal.FlowItem) error { - go item.Finalize(types.QueueOutcomeDispatched, nil) + // The blocking call succeeds. + go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) return nil }, } @@ -426,24 +525,78 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { expectedOutcome: types.QueueOutcomeDispatched, }, { - name: "Rejects_AfterBlocking_WithAllProcessorsRemainingBusy", + // Validates the scenario where the request's TTL expires while the controller is blocked waiting for capacity. + // NOTE: This relies on real time passing, as context.WithDeadline timers cannot be controlled by FakeClock. + name: "Rejects_AfterBlocking_WhenTTL_Expires", + shards: []contracts.RegistryShard{newMockShard("shard-A").build()}, + requestTTL: 50 * time.Millisecond, // Short TTL to keep the test fast. + setupProcessors: func(t *testing.T, h *testHarness) { + h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ + // Reject the non-blocking attempt. + SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, + // Block the fallback attempt until the context (carrying the TTL deadline) expires. + SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error { + <-ctx.Done() + return ctx.Err() + }, + } + }, + // No runActions needed; we rely on the real-time timer to expire. + // When the blocking call fails due to context expiry, the outcome is RejectedOther. + expectedOutcome: types.QueueOutcomeRejectedOther, + expectErr: true, + // The error must reflect the specific cause of the context cancellation (ErrTTLExpired). 
+ expectErrIs: types.ErrTTLExpired, + }, + { + name: "Rejects_OnProcessorShutdownDuringSubmit", shards: []contracts.RegistryShard{newMockShard("shard-A").build()}, setupProcessors: func(t *testing.T, h *testHarness) { h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ - SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, - SubmitOrBlockFunc: func(_ context.Context, _ *internal.FlowItem) error { return context.DeadlineExceeded }, + // Simulate the processor shutting down during the non-blocking handoff. + SubmitFunc: func(_ *internal.FlowItem) error { return types.ErrFlowControllerNotRunning }, + SubmitOrBlockFunc: func(_ context.Context, _ *internal.FlowItem) error { + return types.ErrFlowControllerNotRunning + }, } }, - expectedOutcome: types.QueueOutcomeRejectedCapacity, + expectedOutcome: types.QueueOutcomeRejectedOther, expectErr: true, + expectErrIs: types.ErrFlowControllerNotRunning, + }, + { + name: "Rejects_OnProcessorShutdownDuringSubmitOrBlock", + shards: []contracts.RegistryShard{newMockShard("shard-A").build()}, + setupProcessors: func(t *testing.T, h *testHarness) { + h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ + SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, + // Simulate the processor shutting down during the blocking handoff. + SubmitOrBlockFunc: func(_ context.Context, _ *internal.FlowItem) error { + return types.ErrFlowControllerNotRunning + }, + } + }, + expectedOutcome: types.QueueOutcomeRejectedOther, + expectErr: true, + expectErrIs: types.ErrFlowControllerNotRunning, }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() + + // Arrange mockRegistry := &mockRegistryClient{} - h := newUnitHarness(t, t.Context(), Config{}, mockRegistry) + + // Configure the harness with the appropriate TTL. + harnessConfig := Config{DefaultRequestTTL: defaultTestTTL} + if tc.requestTTL > 0 { + harnessConfig.DefaultRequestTTL = tc.requestTTL + } + h := newUnitHarness(t, t.Context(), harnessConfig, mockRegistry) + + // Configure the registry to return the specified shards. mockRegistry.WithConnectionFunc = func( _ types.FlowKey, fn func(conn contracts.ActiveFlowConnection) error, @@ -451,11 +604,33 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { return fn(&mockActiveFlowConnection{ActiveShardsV: tc.shards}) } tc.setupProcessors(t, h) - outcome, err := h.fc.EnqueueAndWait(newTestRequest(context.Background(), defaultFlowKey)) + + // Act + var outcome types.QueueOutcome + var err error + + startTime := time.Now() // Capture real start time for duration checks. + // Use a background context for the parent; the request lifecycle is governed by the config/derived context. + outcome, err = h.fc.EnqueueAndWait(context.Background(), newTestRequest(defaultFlowKey)) + + // Assert if tc.expectErr { - require.Error(t, err, "expected an error but got nil") + require.Error(t, err, "expected an error during EnqueueAndWait but got nil") + assert.ErrorIs(t, err, tc.expectErrIs, "error should wrap the expected underlying cause") + // All failures during the distribution phase (capacity, timeout, shutdown) should result in a rejection. + assert.ErrorIs(t, err, types.ErrRejected, "rejection errors must wrap types.ErrRejected") + + // Specific assertion for real-time TTL tests. + if errors.Is(tc.expectErrIs, types.ErrTTLExpired) { + duration := time.Since(startTime) + // Ensure the test didn't return instantly. Use a tolerance for CI environments. 
+ // This validates that the real-time wait actually occurred. + assert.GreaterOrEqual(t, duration, tc.requestTTL-30*time.Millisecond, + "EnqueueAndWait returned faster than the TTL allows, indicating the timer did not function correctly") + } + } else { - require.NoError(t, err, "expected no error but got: %v", err) + require.NoError(t, err, "expected no error during EnqueueAndWait but got: %v", err) } assert.Equal(t, tc.expectedOutcome, outcome, "outcome did not match expected value") }) @@ -465,6 +640,8 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { t.Run("Retry", func(t *testing.T) { t.Parallel() + // This test specifically validates the behavior when the request context is cancelled externally while the + // controller is blocked in the SubmitOrBlock phase. t.Run("Rejects_OnRequestContextCancelledWhileBlocking", func(t *testing.T) { t.Parallel() mockRegistry := &mockRegistryClient{ @@ -477,22 +654,34 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { }) }, } - h := newUnitHarness(t, t.Context(), Config{}, mockRegistry) + // Use a long TTL to ensure the failure is due to cancellation, not timeout. + h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 10 * time.Second}, mockRegistry) h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ - SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, - SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error { <-ctx.Done(); return ctx.Err() }, + // Reject non-blocking attempt. + SubmitFunc: func(_ *internal.FlowItem) error { return internal.ErrProcessorBusy }, + // Block the fallback attempt until the context is cancelled. + SubmitOrBlockFunc: func(ctx context.Context, _ *internal.FlowItem) error { + <-ctx.Done() + return ctx.Err() + }, } + + // Create a cancellable context for the request. reqCtx, cancelReq := context.WithCancel(context.Background()) - go func() { time.Sleep(50 * time.Millisecond); cancelReq() }() + // Cancel the request shortly after starting the operation. + // We use real time sleep here as we are testing external cancellation signals interacting with the context. + go func() { time.Sleep(10 * time.Millisecond); cancelReq() }() - outcome, err := h.fc.EnqueueAndWait(newTestRequest(reqCtx, defaultFlowKey)) + outcome, err := h.fc.EnqueueAndWait(reqCtx, newTestRequest(defaultFlowKey)) require.Error(t, err, "EnqueueAndWait must fail when context is cancelled during a blocking submit") assert.ErrorIs(t, err, types.ErrRejected, "error should wrap ErrRejected") - assert.ErrorIs(t, err, context.Canceled, "error should wrap the underlying ctx.Err()") - assert.Equal(t, types.QueueOutcomeRejectedCapacity, outcome, "outcome should be QueueOutcomeRejectedCapacity") + assert.ErrorIs(t, err, context.Canceled, "error should wrap the underlying ctx.Err() (context.Canceled)") + assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, + "outcome should be QueueOutcomeRejectedOther when cancelled during distribution") }) + // This test validates the retry mechanism when a processor reports that its shard is draining. t.Run("RetriesAndSucceeds_OnProcessorReportsShardDraining", func(t *testing.T) { t.Parallel() var callCount atomic.Int32 @@ -504,106 +693,278 @@ func TestFlowController_EnqueueAndWait(t *testing.T) { attempt := callCount.Add(1) shardA := newMockShard("shard-A").withByteSize(100).build() shardB := newMockShard("shard-B").withByteSize(1000).build() + if attempt == 1 { + // Attempt 1: Shard A is the least loaded and is selected. 
return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA, shardB}}) } + // Attempt 2 (Retry): Assume Shard A is now draining and removed from the active set by the registry. return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardB}}) }, } - h := newUnitHarness(t, t.Context(), Config{}, mockRegistry) + // Use a long TTL to ensure retries don't time out. + h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 10 * time.Second}, mockRegistry) + + // Configure Shard A's processor to reject the request due to draining. h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ SubmitFunc: func(item *internal.FlowItem) error { - go item.Finalize(types.QueueOutcomeRejectedOther, contracts.ErrShardDraining) + // The processor accepts the item but then asynchronously finalizes it with ErrShardDraining. + item.SetHandle(&typesmocks.MockQueueItemHandle{}) + go item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, contracts.ErrShardDraining) return nil }, } + // Configure Shard B's processor to successfully dispatch the request on the retry. h.mockProcessorFactory.processors["shard-B"] = &mockShardProcessor{ SubmitFunc: func(item *internal.FlowItem) error { - go item.Finalize(types.QueueOutcomeDispatched, nil) + go item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) return nil }, } - outcome, err := h.fc.EnqueueAndWait(newTestRequest(context.Background(), defaultFlowKey)) + // Act + outcome, err := h.fc.EnqueueAndWait(context.Background(), newTestRequest(defaultFlowKey)) + + // Assert require.NoError(t, err, "EnqueueAndWait must succeed after retrying on a healthy shard") assert.Equal(t, types.QueueOutcomeDispatched, outcome, "outcome should be QueueOutcomeDispatched") assert.Equal(t, int32(2), callCount.Load(), "registry must be consulted for Active shards on each retry attempt") }) }) + + // Lifecycle covers the post-distribution phase, focusing on how the controller handles context cancellation and TTL + // expiry while the request is buffered or queued by the processor (Asynchronous Finalization). + t.Run("Lifecycle", func(t *testing.T) { + t.Parallel() + + // Validates that the controller correctly initiates asynchronous finalization when the request context is cancelled + // after ownership has been transferred to the processor. + t.Run("OnReqCtxCancelledAfterDistribution", func(t *testing.T) { + t.Parallel() + // Use a long TTL to ensure the failure is due to cancellation. + h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: 10 * time.Second}, nil) + + shardA := newMockShard("shard-A").build() + h.mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(_ contracts.ActiveFlowConnection) error) error { + return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA}}) + } + + // Channel for synchronization. + itemSubmitted := make(chan *internal.FlowItem, 1) + + // Configure the processor to accept the item but never finalize it, simulating a queued request. + h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ + SubmitFunc: func(item *internal.FlowItem) error { + item.SetHandle(&typesmocks.MockQueueItemHandle{}) + itemSubmitted <- item + return nil + }, + } + + reqCtx, cancelReq := context.WithCancel(context.Background()) + req := newTestRequest(defaultFlowKey) + + var outcome types.QueueOutcome + var err error + done := make(chan struct{}) + go func() { + outcome, err = h.fc.EnqueueAndWait(reqCtx, req) + close(done) + }() + + // 1. 
Wait for the item to be successfully distributed. + var item *internal.FlowItem + select { + case item = <-itemSubmitted: + // Success. Ownership has transferred. EnqueueAndWait is now in the select loop. + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for item to be submitted to the processor") + } + + // 2. Cancel the request context. + cancelReq() + + // 3. Wait for EnqueueAndWait to return. + select { + case <-done: + // Success. The controller detected the cancellation and unblocked the caller. + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for EnqueueAndWait to return after cancellation") + } + + // 4. Assertions for EnqueueAndWait's return values. + require.Error(t, err, "EnqueueAndWait should return an error when the request is cancelled post-distribution") + // The outcome should be Evicted (as the handle was set). + assert.ErrorIs(t, err, types.ErrEvicted, "error should wrap ErrEvicted") + // The underlying cause must be propagated. + assert.ErrorIs(t, err, types.ErrContextCancelled, "error should wrap ErrContextCancelled") + assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, outcome, "outcome should be EvictedContextCancelled") + + // 5. Assert that the FlowItem itself was indeed finalized by the controller. + finalState := item.FinalState() + require.NotNil(t, finalState, "Item should have been finalized asynchronously by the controller") + assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, finalState.Outcome, + "Item's internal outcome must match the returned outcome") + }) + + // Validates the asynchronous finalization path due to TTL expiry. + // Note: This relies on real time passing, as context.WithDeadline timers cannot be controlled by FakeClock. + t.Run("OnReqCtxTimeoutAfterDistribution", func(t *testing.T) { + t.Parallel() + // Configure a short TTL to keep the test reasonably fast. + const requestTTL = 50 * time.Millisecond + h := newUnitHarness(t, t.Context(), Config{DefaultRequestTTL: requestTTL}, nil) + + shardA := newMockShard("shard-A").build() + h.mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(_ contracts.ActiveFlowConnection) error) error { + return fn(&mockActiveFlowConnection{ActiveShardsV: []contracts.RegistryShard{shardA}}) + } + + itemSubmitted := make(chan *internal.FlowItem, 1) + + // Configure the processor to accept the item but never finalize it. + h.mockProcessorFactory.processors["shard-A"] = &mockShardProcessor{ + SubmitFunc: func(item *internal.FlowItem) error { + item.SetHandle(&typesmocks.MockQueueItemHandle{}) + itemSubmitted <- item + return nil + }, + } + + req := newTestRequest(defaultFlowKey) + // Use a context for the call itself that won't time out independently. + enqueueCtx, enqueueCancel := context.WithTimeout(context.Background(), 2*time.Second) + defer enqueueCancel() + + var outcome types.QueueOutcome + var err error + done := make(chan struct{}) + + startTime := time.Now() // Capture start time to validate duration. + go func() { + outcome, err = h.fc.EnqueueAndWait(enqueueCtx, req) + close(done) + }() + + // 1. Wait for the item to be submitted. + var item *internal.FlowItem + select { + case item = <-itemSubmitted: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for item to be submitted to the processor") + } + + // 2.Wait for the TTL to expire (Real time). We do NOT call Step(). + // Wait for EnqueueAndWait to return due to the TTL expiry. + select { + case <-done: + // Success. Now validate that enough time actually passed. 
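
(Editorial aside on the file-level note about time-based tests: the injected FakeClock cannot drive context.WithDeadline timers, which is why this test waits in real time instead of calling Step. A standalone sketch of that limitation, not part of the patch, assuming testclock is k8s.io/utils/clock/testing as used by these tests:)

package main

import (
	"context"
	"fmt"
	"time"

	testclock "k8s.io/utils/clock/testing"
)

func main() {
	fake := testclock.NewFakeClock(time.Now())
	ctx, cancel := context.WithDeadline(context.Background(), fake.Now().Add(50*time.Millisecond))
	defer cancel()

	fake.Step(time.Hour) // fake time is now far past the deadline...
	select {
	case <-ctx.Done():
		fmt.Println("deadline fired (would require real time to pass)")
	default:
		fmt.Println("context still alive: Step does not drive context timers")
	}

	time.Sleep(60 * time.Millisecond) // only real elapsed time triggers the deadline
	fmt.Println(ctx.Err())            // context.DeadlineExceeded
}
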
+ duration := time.Since(startTime) + assert.GreaterOrEqual(t, duration, requestTTL-30*time.Millisecond, // tolerance for CI environments + "EnqueueAndWait returned faster than the TTL allows, indicating the timer did not function correctly") + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for EnqueueAndWait to return after TTL expiry") + } + + // 4. Assertions for EnqueueAndWait's return values. + require.Error(t, err, "EnqueueAndWait should return an error when TTL expires post-distribution") + assert.ErrorIs(t, err, types.ErrEvicted, "error should wrap ErrEvicted") + assert.ErrorIs(t, err, types.ErrTTLExpired, "error should wrap the underlying cause (types.ErrTTLExpired)") + assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "outcome should be EvictedTTL") + + // 5. Assert FlowItem final state. + finalState := item.FinalState() + require.NotNil(t, finalState, "Item should have been finalized asynchronously by the controller") + assert.Equal(t, types.QueueOutcomeEvictedTTL, finalState.Outcome, + "Item's internal outcome must match the returned outcome") + }) + }) } -func TestFlowController_Lifecycle(t *testing.T) { +// TestFlowController_WorkerManagement covers the lifecycle of the shard processors (workers), including startup, +// reconciliation (garbage collection), and shutdown. +func TestFlowController_WorkerManagement(t *testing.T) { t.Parallel() + // Reconciliation validates that the controller correctly identifies and shuts down workers whose shards no longer + // exist in the registry. t.Run("Reconciliation", func(t *testing.T) { t.Parallel() + // Setup: A registry that initially knows about "shard-A" and "stale-shard", but later only reports "shard-A". mockRegistry := &mockRegistryClient{ - // Configure the mock registry to report the new state without the stale shard. ShardStatsFunc: func() []contracts.ShardStats { + // The current state of the world according to the registry. return []contracts.ShardStats{{ID: "shard-A"}} }} h := newUnitHarness(t, t.Context(), Config{}, mockRegistry) - // Pre-populate the controller with initial workers. + // Pre-populate the controller with initial workers, simulating a previous state. initialShards := []string{"shard-A", "stale-shard"} for _, shardID := range initialShards { currentShardID := shardID + // Initialize the processor mocks with the channel needed to synchronize startup. h.mockProcessorFactory.processors[currentShardID] = &mockShardProcessor{runStarted: make(chan struct{})} shard := &mocks.MockRegistryShard{IDFunc: func() string { return currentShardID }} + // Start the worker using the internal mechanism. h.fc.getOrStartWorker(shard) } require.Len(t, h.mockProcessorFactory.processors, 2, "pre-condition: initial workers not set up correctly") - // Wait for all workers to have started and set their contexts before proceeding with the test. + // Wait for all worker goroutines to have started and captured their contexts. for id, p := range h.mockProcessorFactory.processors { proc := p select { case <-proc.runStarted: - // Success + // Worker is running. case <-time.After(2 * time.Second): t.Fatalf("timed out waiting for worker %s to start", id) } } - // Manually trigger the reconciliation loop logic. + // Act: Manually trigger the reconciliation logic. 
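
(Editorial aside: reconcileProcessors itself is unchanged by this diff, so its body does not appear here. Roughly, it compares the registry's reported shard IDs against the controller's worker map and cancels workers for shards that have disappeared. A minimal sketch of that pattern, with hypothetical types and names:)

package main

import (
	"context"
	"fmt"
	"sync"
)

// worker pairs a processor with the cancel func that stops its goroutine (hypothetical shape).
type worker struct{ cancel context.CancelFunc }

// reconcile drops workers whose shard ID no longer appears in the registry's reported stats.
func reconcile(workers *sync.Map, activeShardIDs []string) {
	active := make(map[string]struct{}, len(activeShardIDs))
	for _, id := range activeShardIDs {
		active[id] = struct{}{}
	}
	workers.Range(func(key, value any) bool {
		if _, ok := active[key.(string)]; !ok {
			value.(*worker).cancel() // signal the stale processor to shut down
			workers.Delete(key)
		}
		return true
	})
}

func main() {
	var workers sync.Map
	for _, id := range []string{"shard-A", "stale-shard"} {
		_, cancel := context.WithCancel(context.Background())
		workers.Store(id, &worker{cancel: cancel})
	}
	reconcile(&workers, []string{"shard-A"})
	workers.Range(func(k, _ any) bool { fmt.Println("kept:", k); return true })
}
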
h.fc.reconcileProcessors() t.Run("StaleWorkerIsCancelled", func(t *testing.T) { staleProc := h.mockProcessorFactory.processors["stale-shard"] require.NotNil(t, staleProc.Context(), "precondition: stale processor context should have been captured") + // The context of the removed worker must be cancelled to signal shutdown. select { case <-staleProc.Context().Done(): - // Success + // Success: Context was cancelled. case <-time.After(100 * time.Millisecond): - t.Error("context of removed worker must be cancelled") + t.Error("context of the stale worker was not cancelled during reconciliation") } }) t.Run("ActiveWorkerIsNotCancelled", func(t *testing.T) { activeProc := h.mockProcessorFactory.processors["shard-A"] require.NotNil(t, activeProc.Context(), "precondition: active processor context should have been captured") + // The context of an active worker must remain open. select { case <-activeProc.Context().Done(): - t.Error("context of remaining worker must not be cancelled") + t.Error("context of the active worker was incorrectly cancelled during reconciliation") default: - // Success + // Success: Context is still active. } }) t.Run("WorkerMapIsUpdated", func(t *testing.T) { + // The stale worker must be removed from the controller's concurrent map. _, ok := h.fc.workers.Load("stale-shard") - assert.False(t, ok, "stale worker must be removed from the controller's map") + assert.False(t, ok, "stale worker must be deleted from the controller's map") _, ok = h.fc.workers.Load("shard-A") assert.True(t, ok, "active worker must remain in the controller's map") }) }) + // Validates that the reconciliation loop runs periodically based on the configured interval. t.Run("Reconciliation_IsTriggeredByTicker", func(t *testing.T) { t.Parallel() - reconciliationInterval := 10 * time.Second + const reconciliationInterval = 10 * time.Second mockRegistry := &mockRegistryClient{} + // Count the number of times the reconciliation logic (which calls ShardStats) runs. var reconcileCount atomic.Int32 mockRegistry.ShardStatsFunc = func() []contracts.ShardStats { reconcileCount.Add(1) @@ -611,45 +972,55 @@ func TestFlowController_Lifecycle(t *testing.T) { } h := newUnitHarness(t, t.Context(), Config{ProcessorReconciliationInterval: reconciliationInterval}, mockRegistry) + // Ensure we are using the FakeClock specifically for this test, as we need Step/HasWaiters. + require.NotNil(t, h.mockClock, "This test requires the harness to be using FakeClock") // Wait for the reconciliation loop to start and create the ticker. - // This prevents a race where the clock is stepped before the ticker is registered. - require.Eventually(t, h.mockClock.HasWaiters, time.Second, 10*time.Millisecond, "ticker was not created") + // This prevents a race where the clock is stepped before the ticker is registered with the FakeClock. + require.Eventually(t, h.mockClock.HasWaiters, time.Second, 10*time.Millisecond, + "reconciliation ticker was not created") - // Advance the clock to trigger the next reconciliation. + // Advance the clock to trigger the first reconciliation. h.mockClock.Step(reconciliationInterval) assert.Eventually(t, func() bool { return reconcileCount.Load() == 1 - }, time.Second, 10*time.Millisecond, "reconciliation was not triggered by the ticker") + }, time.Second, 10*time.Millisecond, "reconciliation was not triggered by the first ticker event") // Advance the clock again to ensure it continues to fire. 
h.mockClock.Step(reconciliationInterval) assert.Eventually(t, func() bool { return reconcileCount.Load() == 2 - }, time.Second, 10*time.Millisecond, "reconciliation did not fire on the second tick") + }, time.Second, 10*time.Millisecond, "reconciliation did not fire on the second ticker event") }) + // Validates the atomicity of worker creation and ensures resource cleanup for the loser of the race. t.Run("WorkerCreationRace", func(t *testing.T) { t.Parallel() - // This test requires manual control over the shard processor factory to deterministically create a race. + // This test orchestrates a deterministic race condition. factoryEntered := make(chan *mockShardProcessor, 2) continueFactory := make(chan struct{}) + // Map to store the construction context for each processor instance, allowing us to verify cleanup. + constructionContexts := sync.Map{} - h := newUnitHarness(t, t.Context(), Config{}, &mockRegistryClient{}) + h := newUnitHarness(t, t.Context(), Config{}, nil) + + // Inject a custom factory to control the timing of worker creation. h.fc.shardProcessorFactory = func( + ctx context.Context, // The context created by getOrStartWorker for the potential new processor. shard contracts.RegistryShard, _ contracts.SaturationDetector, - _ clock.Clock, + _ clock.WithTicker, _ time.Duration, _ int, _ logr.Logger, ) shardProcessor { - // This factory function will be called by `startNewWorker`. - // We use channels to pause execution here, allowing two goroutines to enter this function before one "wins" - // the `LoadOrStore` race. + // This function is called by getOrStartWorker before the LoadOrStore check. proc := &mockShardProcessor{runStarted: make(chan struct{})} + constructionContexts.Store(proc, ctx) // Capture the construction context. + + // Signal entry and then block, allowing another goroutine to enter. factoryEntered <- proc <-continueFactory return proc @@ -669,18 +1040,17 @@ func TestFlowController_Lifecycle(t *testing.T) { h.fc.getOrStartWorker(shard) }() - // Wait for both goroutines to have entered the factory and created a processor. - // This confirms they both missed the initial `workers.Load` check. + // 1. Wait for both goroutines to enter the factory and create their respective processor instances. proc1 := <-factoryEntered proc2 := <-factoryEntered - // Unblock both goroutines, allowing them to race to `workers.LoadOrStore`. + // 2. Unblock both goroutines, allowing them to race to workers.LoadOrStore. close(continueFactory) wg.Wait() - // One processor "won" and was stored, the other "lost" and should have been cancelled. + // 3. Identify the winner and the loser. actual, ok := h.fc.workers.Load("race-shard") - require.True(t, ok, "a worker should have been stored in the map") + require.True(t, ok, "a worker must have been successfully stored in the map") storedWorker := actual.(*managedWorker) winnerProc := storedWorker.processor.(*mockShardProcessor) @@ -692,16 +1062,17 @@ func TestFlowController_Lifecycle(t *testing.T) { loserProc = proc1 } - // Wait for the `Run` method to be called on the winning processor to ensure its context is available. + // 4. Validate the state of the winning processor. + // Wait for the Run method to be called on the winner (only the winner should start). select { case <-winnerProc.runStarted: - // Success. + // Success. 
case <-time.After(1 * time.Second): - t.Fatal("timed out waiting for winning worker to start") + t.Fatal("timed out waiting for the winning worker's Run method to be called") } // The winning processor's context must remain active. - require.NotNil(t, winnerProc.Context(), "winner's context should not be nil") + require.NotNil(t, winnerProc.Context(), "winner's context should not be nil (Run was called)") select { case <-winnerProc.Context().Done(): t.Error("context of the winning worker should not be cancelled") @@ -709,49 +1080,61 @@ func TestFlowController_Lifecycle(t *testing.T) { // Success } - // The losing processor's `Run` method must not be called, and its context should be nil. + // 5. Validate the state of the losing processor and resource cleanup. + // The losing processor's Run method must NOT be called. select { case <-loserProc.runStarted: - t.Error("Run was called on the losing worker, but it should not have been") + t.Error("Run was incorrectly called on the losing worker") default: // Success } - assert.Nil(t, loserProc.Context(), "loser's context should be nil as Run is never called") + + // Verify the context created for the loser during construction was cancelled by getOrStartWorker. + loserCtxRaw, ok := constructionContexts.Load(loserProc) + require.True(t, ok, "loser processor construction context should have been captured") + loserCtx := loserCtxRaw.(context.Context) + + select { + case <-loserCtx.Done(): + // Success: Context was cancelled, preventing resource leaks. + case <-time.After(100 * time.Millisecond): + t.Error("context of the losing worker was not cancelled, this will leak resources") + } }) } -func TestFlowController_Concurrency(t *testing.T) { - const ( - numShards = 4 - numGoroutines = 50 - numRequests = 200 - ) - - // Set up a realistic registry that vends real components to the processor. +// Helper function to create a realistic mock registry environment for integration/concurrency tests. +func setupRegistryForConcurrency(t *testing.T, numShards int, flowKey types.FlowKey) *mockRegistryClient { + t.Helper() mockRegistry := &mockRegistryClient{} shards := make([]contracts.RegistryShard, numShards) - queues := make(map[string]contracts.ManagedQueue) + + // Configure the shards and their dependencies required by the real ShardProcessor implementation. for i := range numShards { + // Capture loop variables for closures. shardID := fmt.Sprintf("shard-%d", i) - queues[shardID] = &mocks.MockManagedQueue{FlowKeyV: defaultFlowKey} // Use the high-fidelity mock queue. + // Use high-fidelity mock queues (MockManagedQueue) that implement the necessary interfaces and synchronization. + currentQueue := &mocks.MockManagedQueue{FlowKeyV: flowKey} + shards[i] = &mocks.MockRegistryShard{ IDFunc: func() string { return shardID }, ManagedQueueFunc: func(_ types.FlowKey) (contracts.ManagedQueue, error) { - return queues[shardID], nil + return currentQueue, nil }, - AllOrderedPriorityLevelsFunc: func() []int { return []int{100} }, + // Configuration required for ShardProcessor initialization and dispatch logic. 
+ AllOrderedPriorityLevelsFunc: func() []int { return []int{flowKey.Priority} }, PriorityBandAccessorFunc: func(priority int) (framework.PriorityBandAccessor, error) { - if priority == 100 { + if priority == flowKey.Priority { return &frameworkmocks.MockPriorityBandAccessor{ - PriorityNameV: "high", - PriorityV: 100, + PriorityV: priority, IterateQueuesFunc: func(f func(framework.FlowQueueAccessor) bool) { - f(queues[shardID].FlowQueueAccessor()) + f(currentQueue.FlowQueueAccessor()) }, }, nil } return nil, fmt.Errorf("unexpected priority %d", priority) }, + // Configure dispatch policies (FIFO). IntraFlowDispatchPolicyFunc: func(_ types.FlowKey) (framework.IntraFlowDispatchPolicy, error) { return &frameworkmocks.MockIntraFlowDispatchPolicy{ SelectItemFunc: func(qa framework.FlowQueueAccessor) (types.QueueItemAccessor, error) { @@ -762,26 +1145,29 @@ func TestFlowController_Concurrency(t *testing.T) { InterFlowDispatchPolicyFunc: func(_ int) (framework.InterFlowDispatchPolicy, error) { return &frameworkmocks.MockInterFlowDispatchPolicy{ SelectQueueFunc: func(band framework.PriorityBandAccessor) (framework.FlowQueueAccessor, error) { - return queues[shardID].FlowQueueAccessor(), nil + return currentQueue.FlowQueueAccessor(), nil }, }, nil }, + // Configure stats reporting based on the live state of the mock queues. StatsFunc: func() contracts.ShardStats { return contracts.ShardStats{ ID: shardID, - TotalLen: uint64(queues[shardID].Len()), - TotalByteSize: queues[shardID].ByteSize(), + TotalLen: uint64(currentQueue.Len()), + TotalByteSize: currentQueue.ByteSize(), PerPriorityBandStats: map[int]contracts.PriorityBandStats{ - 100: { - Len: uint64(queues[shardID].Len()), - ByteSize: queues[shardID].ByteSize(), - CapacityBytes: 1e9, // Effectively unlimited capacity + flowKey.Priority: { + Len: uint64(currentQueue.Len()), + ByteSize: currentQueue.ByteSize(), + CapacityBytes: 1e9, // Effectively unlimited capacity to ensure dispatch success. }, }, } }, } } + + // Configure the registry connection. mockRegistry.WithConnectionFunc = func(_ types.FlowKey, fn func(conn contracts.ActiveFlowConnection) error) error { return fn(&mockActiveFlowConnection{ActiveShardsV: shards}) } @@ -792,25 +1178,121 @@ func TestFlowController_Concurrency(t *testing.T) { } return stats } + return mockRegistry +} + +// TestFlowController_Concurrency_Distribution performs an integration test under high contention, using real +// ShardProcessors. +// It validates the thread-safety of the distribution logic and the overall system throughput. +func TestFlowController_Concurrency_Distribution(t *testing.T) { + const ( + numShards = 4 + numGoroutines = 50 + numRequests = 200 + ) + + // Arrange + mockRegistry := setupRegistryForConcurrency(t, numShards, defaultFlowKey) + + // Initialize the integration harness with real ShardProcessors. h := newIntegrationHarness(t, t.Context(), Config{ - // Use a generous buffer to prevent flakes in the test due to transient queuing delays. + // Use a generous buffer to focus the test on distribution logic rather than backpressure. EnqueueChannelBufferSize: numRequests, - DefaultRequestTTL: 1 * time.Second, + DefaultRequestTTL: 5 * time.Second, ExpiryCleanupInterval: 100 * time.Millisecond, }, mockRegistry) + // Act: Hammer the controller concurrently. 
var wg sync.WaitGroup wg.Add(numGoroutines) outcomes := make(chan types.QueueOutcome, numRequests) - for range numGoroutines { + + for i := range numGoroutines { + goroutineID := i go func() { defer wg.Done() - for range numRequests / numGoroutines { - req := newTestRequest(logr.NewContext(context.Background(), logr.Discard()), defaultFlowKey) - outcome, err := h.fc.EnqueueAndWait(req) + for j := range numRequests / numGoroutines { + req := newTestRequest(defaultFlowKey) + req.IDV = fmt.Sprintf("req-distrib-%d-%d", goroutineID, j) + + // Use a reasonable timeout for the individual request context. + reqCtx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + ctx := logr.NewContext(reqCtx, logr.Discard()) + outcome, err := h.fc.EnqueueAndWait(ctx, req) + if err != nil { + // Use t.Errorf for concurrent tests to report failures without halting execution. + t.Errorf("EnqueueAndWait failed unexpectedly under load: %v", err) + } + outcomes <- outcome + } + }() + } + + // Wait for all requests to complete. + wg.Wait() + close(outcomes) + + // Assert: All requests should be successfully dispatched. + successCount := 0 + for outcome := range outcomes { + if outcome == types.QueueOutcomeDispatched { + successCount++ + } + } + require.Equal(t, numRequests, successCount, + "all concurrent requests must be dispatched successfully without errors or data races") +} + +// TestFlowController_Concurrency_Backpressure specifically targets the blocking submission path (SubmitOrBlock) by +// configuring the processors with zero buffer capacity. +func TestFlowController_Concurrency_Backpressure(t *testing.T) { + if testing.Short() { + t.Skip("Skipping concurrency integration test in short mode.") + } + t.Parallel() + + const ( + numShards = 2 + numGoroutines = 20 + // Fewer requests than the distribution test, as the blocking path is inherently slower. + numRequests = 40 + ) + + // Arrange: Set up the registry environment. + mockRegistry := setupRegistryForConcurrency(t, numShards, defaultFlowKey) + + // Use the integration harness with a configuration designed to induce backpressure. + h := newIntegrationHarness(t, t.Context(), Config{ + // Zero buffer forces immediate use of SubmitOrBlock if the processor loop is busy. + EnqueueChannelBufferSize: 0, + // Generous TTL to ensure timeouts are not the cause of failure. + DefaultRequestTTL: 10 * time.Second, + ExpiryCleanupInterval: 100 * time.Millisecond, + }, mockRegistry) + + // Act: Concurrently submit requests. + var wg sync.WaitGroup + wg.Add(numGoroutines) + outcomes := make(chan types.QueueOutcome, numRequests) + + for i := range numGoroutines { + goroutineID := i + go func() { + defer wg.Done() + for j := range numRequests / numGoroutines { + req := newTestRequest(defaultFlowKey) + req.IDV = fmt.Sprintf("req-backpressure-%d-%d", goroutineID, j) + + // Use a reasonable timeout for the individual request context to ensure the test finishes promptly if a + // deadlock occurs. + reqCtx, cancel := context.WithTimeout(t.Context(), 5*time.Second) + defer cancel() + + outcome, err := h.fc.EnqueueAndWait(logr.NewContext(reqCtx, logr.Discard()), req) if err != nil { - // Use `t.Errorf` for concurrent tests to avoid halting execution on a single failure. 
- t.Errorf("EnqueueAndWait failed unexpectedly: %v", err) + t.Errorf("EnqueueAndWait failed unexpectedly under backpressure for request %s: %v", req.ID(), err) } outcomes <- outcome } @@ -819,11 +1301,13 @@ func TestFlowController_Concurrency(t *testing.T) { wg.Wait() close(outcomes) + // Assert: Verify successful dispatch despite high contention and zero buffer. successCount := 0 for outcome := range outcomes { if outcome == types.QueueOutcomeDispatched { successCount++ } } - require.Equal(t, numRequests, successCount, "all concurrent requests should be dispatched successfully") + require.Equal(t, numRequests, successCount, + "all concurrent requests should be dispatched successfully even under high contention and zero buffer capacity") } diff --git a/pkg/epp/flowcontrol/controller/internal/item.go b/pkg/epp/flowcontrol/controller/internal/item.go index d5bdaaf2e..31a28d473 100644 --- a/pkg/epp/flowcontrol/controller/internal/item.go +++ b/pkg/epp/flowcontrol/controller/internal/item.go @@ -17,97 +17,171 @@ limitations under the License. package internal import ( + "context" + "errors" + "fmt" "sync" + "sync/atomic" "time" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" ) -// FinalState encapsulates the terminal outcome of a `FlowItem`'s lifecycle. -// It is sent over the item's `Done()` channel exactly once. +// FinalState encapsulates the terminal outcome of a FlowItem's lifecycle. type FinalState struct { Outcome types.QueueOutcome Err error } -// FlowItem is the internal representation of a request managed by the `FlowController`. +// FlowItem is the internal representation of a request managed by the Flow Controller. +// +// # Lifecycle Management +// +// Finalization (determining outcome) can be initiated by the Controller (e.g., Context expiry) or the Processor (e.g., +// Dispatch/Reject). It sets the outcome and signals the waiting goroutine. +// +// # Synchronization +// +// Atomic operations synchronize state across the Controller and Processor goroutines: +// - finalState (atomic.Pointer): Safely publishes the outcome. +// - handle (atomic.Pointer): Safely publishes the queue admission status. type FlowItem struct { - // --- Immutable fields (set at creation) --- + // --- Immutable fields during a single lifecycle --- enqueueTime time.Time effectiveTTL time.Duration originalRequest types.FlowControlRequest - handle types.QueueItemHandle - // --- Finalization state (protected by onceFinalize) --- + // --- Synchronized State --- - // done is closed exactly once when the item is finalized. - // The closing of this channel establishes a "happens-before" memory barrier, guaranteeing that writes to `outcome` - // and `err` are visible to any goroutine that has successfully read from `done`. - done chan FinalState + // handle stores the types.QueueItemHandle atomically. + // Written by the Processor (SetHandle) when admitted. + // Read by inferOutcome (called by Finalize) to infer the outcome (Rejected vs. Evicted). + // Distinguishing between pre-admission (Rejection) and post-admission (Eviction) during asynchronous finalization + // relies on whether this handle is nil or non-nil. + handle atomic.Pointer[types.QueueItemHandle] - // finalState is safely visible to any goroutine after it has confirmed the channel is closed. - finalState FinalState + // finalState holds the result of the finalization. Stored atomically once. + // Use FinalState() for safe access. 
+ finalState atomic.Pointer[FinalState] - // onceFinalize ensures the `finalize()` logic is idempotent. + // --- Finalization Signaling --- + + // done is the channel used to signal the completion of the item's lifecycle. + // Buffered to size 1 to prevent Finalize from blocking. + done chan *FinalState + + // onceFinalize ensures the finalization logic runs exactly once per lifecycle. onceFinalize sync.Once } -// ensure FlowItem implements the interface. var _ types.QueueItemAccessor = &FlowItem{} -// NewItem creates a new `FlowItem`. +// NewItem allocates and initializes a new FlowItem for a request lifecycle. func NewItem(req types.FlowControlRequest, effectiveTTL time.Duration, enqueueTime time.Time) *FlowItem { return &FlowItem{ enqueueTime: enqueueTime, effectiveTTL: effectiveTTL, originalRequest: req, - // Buffer to size one, preventing finalizing goroutine (e.g., the dispatcher) from blocking if the waiting - // goroutine has already timed out and is no longer reading. - done: make(chan FinalState, 1), + done: make(chan *FinalState, 1), } } -// EnqueueTime returns the time the item was logically accepted by the `FlowController` for queuing. This is used as the -// basis for TTL calculations. +// EnqueueTime returns the time the item was logically accepted by the FlowController. func (fi *FlowItem) EnqueueTime() time.Time { return fi.enqueueTime } -// EffectiveTTL returns the actual time-to-live assigned to this item by the `FlowController`. +// EffectiveTTL returns the actual time-to-live assigned to this item. func (fi *FlowItem) EffectiveTTL() time.Duration { return fi.effectiveTTL } -// OriginalRequest returns the original, underlying `types.FlowControlRequest` object. +// OriginalRequest returns the original types.FlowControlRequest object. func (fi *FlowItem) OriginalRequest() types.FlowControlRequest { return fi.originalRequest } -// Handle returns the `types.QueueItemHandle` that uniquely identifies this item within a specific queue instance. It -// returns nil if the item has not yet been added to a queue. -func (fi *FlowItem) Handle() types.QueueItemHandle { return fi.handle } +// Done returns a read-only channel that will receive the FinalState pointer exactly once. +func (fi *FlowItem) Done() <-chan *FinalState { return fi.done } -// SetHandle associates a `types.QueueItemHandle` with this item. This method is called by a `framework.SafeQueue` -// implementation immediately after the item is added to the queue. -func (fi *FlowItem) SetHandle(handle types.QueueItemHandle) { fi.handle = handle } +// FinalState returns the FinalState if the item has been finalized, or nil otherwise. +// Safe for concurrent access. +func (fi *FlowItem) FinalState() *FinalState { return fi.finalState.Load() } -// Done returns a channel that is closed when the item has been finalized (e.g., dispatched, rejected, or evicted). -func (fi *FlowItem) Done() <-chan FinalState { - return fi.done +// Handle returns the types.QueueItemHandle for this item within a queue. +// Returns nil if the item is not in a queue. Safe for concurrent access. +func (fi *FlowItem) Handle() types.QueueItemHandle { + ptr := fi.handle.Load() + if ptr == nil { + return nil + } + return *ptr } -// Finalize sets the item's terminal state and signals the waiting goroutine by closing its `done` channel idempotently. -// This method is idempotent and is the single point where an item's lifecycle concludes. -// It is intended to be called only by the component that owns the item's lifecycle, such as a `ShardProcessor`. 
-func (fi *FlowItem) Finalize(outcome types.QueueOutcome, err error) { +// SetHandle associates a types.QueueItemHandle with this item. Called by the queue implementation (via Processor). +// Safe for concurrent access. +func (fi *FlowItem) SetHandle(handle types.QueueItemHandle) { fi.handle.Store(&handle) } + +// Finalize determines the item's terminal state based on the provided cause (e.g., Context error) and the item's +// current admission status (queued or not). +// +// This method is intended for asynchronous finalization initiated by the Controller (e.g., TTL expiry). +// It is idempotent. +func (fi *FlowItem) Finalize(cause error) { fi.onceFinalize.Do(func() { - finalState := FinalState{Outcome: outcome, Err: err} - fi.finalState = finalState - fi.done <- finalState - close(fi.done) + // Atomically load the handle to determine if the item was admitted to a queue. + // This synchronization is critical for correctly inferring the outcome across goroutines. + isQueued := fi.Handle() != nil + outcome, finalErr := inferOutcome(cause, isQueued) + fi.finalizeInternal(outcome, finalErr) }) } -// isFinalized checks if the item has been finalized without blocking or consuming the final state. -// It is a side-effect-free check used by the `ShardProcessor` as a defensive measure to avoid operating on -// already-completed items. -func (fi *FlowItem) isFinalized() bool { - // A buffered channel of size 1 can be safely and non-blockingly checked by its length. - // If the finalize function has run, it will have sent a value, and the length will be 1. - return len(fi.done) > 0 +// FinalizeWithOutcome sets the item's terminal state explicitly. +// +// This method is intended for synchronous finalization by the Processor (Dispatch, Reject) or the Controller +// (Distribution failure). +// It is idempotent. +func (fi *FlowItem) FinalizeWithOutcome(outcome types.QueueOutcome, err error) { + fi.onceFinalize.Do(func() { + fi.finalizeInternal(outcome, err) + }) +} + +// finalizeInternal is the core finalization logic. It must be called within the sync.Once.Do block. +// It captures the state, stores it atomically, and signals the Done channel. +func (fi *FlowItem) finalizeInternal(outcome types.QueueOutcome, err error) { + finalState := &FinalState{ + Outcome: outcome, + Err: err, + } + + // Atomically store the pointer. This is the critical memory barrier that publishes the state safely. + fi.finalState.Store(finalState) + + fi.done <- finalState + close(fi.done) +} + +// inferOutcome determines the correct QueueOutcome and Error based on the cause of finalization and whether the item +// was already admitted to a queue. +func inferOutcome(cause error, isQueued bool) (types.QueueOutcome, error) { + var specificErr error + var outcomeIfEvicted types.QueueOutcome + switch { + case errors.Is(cause, types.ErrTTLExpired) || errors.Is(cause, context.DeadlineExceeded): + specificErr = types.ErrTTLExpired + outcomeIfEvicted = types.QueueOutcomeEvictedTTL + case errors.Is(cause, context.Canceled): + specificErr = fmt.Errorf("%w: %w", types.ErrContextCancelled, cause) + outcomeIfEvicted = types.QueueOutcomeEvictedContextCancelled + default: + // Handle other potential causes (e.g., custom context errors). + specificErr = cause + outcomeIfEvicted = types.QueueOutcomeEvictedOther + } + + if isQueued { + // The item was in the queue when it expired/cancelled. + return outcomeIfEvicted, fmt.Errorf("%w: %w", types.ErrEvicted, specificErr) + } + + // The item was not yet in the queue (e.g., buffered in enqueueChan). 
+ // We treat this as a rejection, as it never formally consumed queue capacity. + return types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, specificErr) } diff --git a/pkg/epp/flowcontrol/controller/internal/item_test.go b/pkg/epp/flowcontrol/controller/internal/item_test.go index 1f713e913..9b7b627c2 100644 --- a/pkg/epp/flowcontrol/controller/internal/item_test.go +++ b/pkg/epp/flowcontrol/controller/internal/item_test.go @@ -31,12 +31,16 @@ import ( func TestFlowItem_New(t *testing.T) { t.Parallel() - req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{}, context.Background()) + req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{}) - item := NewItem(req, time.Minute, time.Now()) + enqueueTime := time.Now() + item := NewItem(req, time.Minute, enqueueTime) require.NotNil(t, item, "NewItem should not return a nil item") - assert.False(t, item.isFinalized(), "a new item must not be in a finalized state") + assert.Equal(t, enqueueTime, item.EnqueueTime(), "EnqueueTime should be populated") + assert.Equal(t, time.Minute, item.EffectiveTTL(), "EffectiveTTL should be populated") + assert.Same(t, req, item.OriginalRequest(), "OriginalRequest should be populated") + assert.Nil(t, item.FinalState(), "a new item must not have a final state") select { case <-item.Done(): t.Fatal("Done() channel for a new item must block, but it was closed") @@ -55,21 +59,178 @@ func TestFlowItem_Handle(t *testing.T) { func TestFlowItem_Finalize_Idempotency(t *testing.T) { t.Parallel() - req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{}, context.Background()) - item := NewItem(req, time.Minute, time.Now()) - expectedErr := errors.New("first-error") + now := time.Now() + req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{}) - item.Finalize(types.QueueOutcomeEvictedTTL, expectedErr) - item.Finalize(types.QueueOutcomeDispatched, nil) // Should take no effect + testCases := []struct { + name string + firstCall func(item *FlowItem) + secondCall func(item *FlowItem) + expectedOutcome types.QueueOutcome + expectedErrIs error + }{ + { + name: "Finalize then Finalize", + firstCall: func(item *FlowItem) { + item.Finalize(types.ErrTTLExpired) + }, + secondCall: func(item *FlowItem) { + item.Finalize(context.Canceled) + }, + expectedOutcome: types.QueueOutcomeRejectedOther, + expectedErrIs: types.ErrTTLExpired, + }, + { + name: "Finalize then FinalizeWithOutcome", + firstCall: func(item *FlowItem) { + item.Finalize(types.ErrTTLExpired) + }, + secondCall: func(item *FlowItem) { + item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) + }, + expectedOutcome: types.QueueOutcomeRejectedOther, + expectedErrIs: types.ErrTTLExpired, + }, + { + name: "FinalizeWithOutcome then FinalizeWithOutcome", + firstCall: func(item *FlowItem) { + item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) + }, + secondCall: func(item *FlowItem) { + item.FinalizeWithOutcome(types.QueueOutcomeRejectedCapacity, errors.New("rejected")) + }, + expectedOutcome: types.QueueOutcomeDispatched, + expectedErrIs: nil, + }, + { + name: "FinalizeWithOutcome then Finalize", + firstCall: func(item *FlowItem) { + item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) + }, + secondCall: func(item *FlowItem) { + item.Finalize(types.ErrTTLExpired) + }, + expectedOutcome: types.QueueOutcomeDispatched, + expectedErrIs: nil, + }, + } - assert.True(t, item.isFinalized(), "item must be in a finalized state after a call to finalize()") - select { - case 
finalState, ok := <-item.Done(): - require.True(t, ok, "Done() channel should be readable with a value, not just closed") - assert.Equal(t, types.QueueOutcomeEvictedTTL, finalState.Outcome, - "the outcome from Done() must match the first finalized outcome") - assert.Equal(t, expectedErr, finalState.Err, "the error from Done() must match the first finalized error") - case <-time.After(50 * time.Millisecond): - t.Fatal("Done() channel must not block after finalization") + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + item := NewItem(req, time.Minute, now) + + // First call + tc.firstCall(item) + + // Second call + tc.secondCall(item) + + // Check FinalState() + finalState := item.FinalState() + require.NotNil(t, finalState, "FinalState should not be nil") + assert.Equal(t, tc.expectedOutcome, finalState.Outcome, "Outcome should match the first call") + if tc.expectedErrIs != nil { + assert.ErrorIs(t, finalState.Err, tc.expectedErrIs, "Error should match the first call") + } else { + assert.NoError(t, finalState.Err, "Error should be nil") + } + + // Check Done channel + select { + case state, ok := <-item.Done(): + require.True(t, ok, "Done channel should be readable") + assert.Equal(t, tc.expectedOutcome, state.Outcome, "Done channel outcome should match the first call") + if tc.expectedErrIs != nil { + assert.ErrorIs(t, state.Err, tc.expectedErrIs, "Done channel error should match the first call") + } else { + assert.NoError(t, state.Err, "Done channel error should be nil") + } + case <-time.After(50 * time.Millisecond): + t.Fatal("Done channel should have received the state") + } + }) + } +} + +func TestFlowItem_Finalize_InferOutcome(t *testing.T) { + t.Parallel() + now := time.Now() + + testCases := []struct { + name string + cause error + isQueued bool + expectOutcome types.QueueOutcome + expectErrIs error + }{ + { + name: "queued TTL expired", + cause: types.ErrTTLExpired, + isQueued: true, + expectOutcome: types.QueueOutcomeEvictedTTL, + expectErrIs: types.ErrTTLExpired, + }, + { + name: "queued context cancelled", + cause: context.Canceled, + isQueued: true, + expectOutcome: types.QueueOutcomeEvictedContextCancelled, + expectErrIs: types.ErrContextCancelled, + }, + { + name: "queued other error", + cause: errors.New("other cause"), + isQueued: true, + expectOutcome: types.QueueOutcomeEvictedOther, + expectErrIs: types.ErrEvicted, + }, + { + name: "not queued TTL expired", + cause: types.ErrTTLExpired, + isQueued: false, + expectOutcome: types.QueueOutcomeRejectedOther, + expectErrIs: types.ErrTTLExpired, + }, + { + name: "not queued context cancelled", + cause: context.Canceled, + isQueued: false, + expectOutcome: types.QueueOutcomeRejectedOther, + expectErrIs: types.ErrContextCancelled, + }, + { + name: "nil cause queued", + cause: nil, + isQueued: true, + expectOutcome: types.QueueOutcomeEvictedOther, + expectErrIs: types.ErrEvicted, + }, + { + name: "nil cause not queued", + cause: nil, + isQueued: false, + expectOutcome: types.QueueOutcomeRejectedOther, + expectErrIs: types.ErrRejected, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + req := typesmocks.NewMockFlowControlRequest(100, "req-1", types.FlowKey{}) + item := NewItem(req, time.Minute, now) + if tc.isQueued { + item.SetHandle(&typesmocks.MockQueueItemHandle{}) + } + + item.Finalize(tc.cause) + + finalState := item.FinalState() + require.NotNil(t, finalState, "FinalState should not be nil") + assert.Equal(t, tc.expectOutcome, 
finalState.Outcome, "Unexpected outcome") + require.Error(t, finalState.Err, "An error should be set") + assert.ErrorIs(t, finalState.Err, tc.expectErrIs, "Unexpected error type") + }) } } diff --git a/pkg/epp/flowcontrol/controller/internal/processor.go b/pkg/epp/flowcontrol/controller/internal/processor.go index 9b6bb705f..e3cccc196 100644 --- a/pkg/epp/flowcontrol/controller/internal/processor.go +++ b/pkg/epp/flowcontrol/controller/internal/processor.go @@ -27,7 +27,6 @@ import ( "github.com/go-logr/logr" "k8s.io/utils/clock" - "sigs.k8s.io/controller-runtime/pkg/log" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" @@ -44,106 +43,116 @@ const maxCleanupWorkers = 4 // This is used as a signal for the `controller.FlowController`'s "fast failover" logic. var ErrProcessorBusy = errors.New("shard processor is busy") -// ShardProcessor is the core worker of the `controller.FlowController`. -// It is paired one-to-one with a `contracts.RegistryShard` instance and is responsible for all request lifecycle -// operations on that shard. It acts as the "data plane" worker that executes against the concurrent-safe state provided -// by its shard. +// ShardProcessor is the core worker of the FlowController. // -// # Concurrency Model: The Single-Writer Actor +// It is paired one-to-one with a RegistryShard instance and is responsible for all request lifecycle operations on that +// shard, from the point an item is successfully submitted to it. // -// To ensure correctness and high performance, the processor uses a single-goroutine, actor-based model. The main `Run` -// loop is the sole "writer" for all state-mutating operations, particularly enqueueing. This makes complex transactions -// inherently atomic without coarse-grained locks. +// # Request Lifecycle Management & Ownership +// +// The ShardProcessor takes ownership of a FlowItem only after it has been successfully sent to its internal enqueueChan +// via Submit or SubmitOrBlock (i.e., when these methods return nil). +// Once the Processor takes ownership, it is solely responsible for ensuring that item.Finalize() or +// item.FinalizeWithOutcome() is called exactly once for that item, under all circumstances (dispatch, rejection, sweep, +// or shutdown). +// +// If Submit or SubmitOrBlock return an error, ownership remains with the caller (the Controller), which must then +// handle the finalization. // -// # Concurrency Guarantees +// # Concurrency Model // -// 1. Safe Enqueueing: The "check-then-act" sequence for capacity is safe because it is only ever performed by the -// single `Run` goroutine. -// 2. Idempotent Finalization: The primary internal race condition is between the main `dispatchCycle` and the -// background `runExpiryCleanup` goroutine, both of which might try to finalize an item. This is resolved by the -// `FlowItem.Finalize` method, which uses `sync.Once` to guarantee that only the first attempt to finalize an item -// succeeds. +// To ensure correctness and high performance, the processor uses a single-goroutine, actor-based model. The main run +// loop is the sole writer for all state-mutating operations. This makes complex transactions (like capacity checks) +// inherently atomic without coarse-grained locks. 
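+//
+// Illustrative hand-off from the caller's side (a simplified sketch; the real call sites live in the controller's
+// distribution logic):
+//
+//	item := NewItem(req, effectiveTTL, enqueueTime)
+//	if err := processor.Submit(item); err != nil {
+//		// Ownership was not transferred; the caller must finalize the item itself.
+//		item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, err))
+//		return
+//	}
+//	// Ownership transferred: the processor now guarantees exactly one finalization
+//	// (dispatch, rejection, sweep, or shutdown).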
type ShardProcessor struct { - shard contracts.RegistryShard - saturationDetector contracts.SaturationDetector - clock clock.Clock - expiryCleanupInterval time.Duration - logger logr.Logger + shard contracts.RegistryShard + saturationDetector contracts.SaturationDetector + clock clock.WithTicker + cleanupSweepInterval time.Duration + logger logr.Logger + + // lifecycleCtx controls the processor's lifetime. Monitored by Submit* methods for safe shutdown. + lifecycleCtx context.Context - // enqueueChan is the entry point for new requests to be processed by this shard's `Run` loop. + // enqueueChan is the entry point for new requests. enqueueChan chan *FlowItem - // wg is used to wait for background tasks like expiry cleanup to complete on shutdown. + + // wg is used to wait for background tasks (cleanup sweep) to complete on shutdown. wg sync.WaitGroup isShuttingDown atomic.Bool shutdownOnce sync.Once } -// NewShardProcessor creates a new `ShardProcessor` instance. +// NewShardProcessor creates a new ShardProcessor instance. func NewShardProcessor( + ctx context.Context, shard contracts.RegistryShard, saturationDetector contracts.SaturationDetector, - clock clock.Clock, - expiryCleanupInterval time.Duration, + clock clock.WithTicker, + cleanupSweepInterval time.Duration, enqueueChannelBufferSize int, logger logr.Logger, ) *ShardProcessor { return &ShardProcessor{ - shard: shard, - saturationDetector: saturationDetector, - clock: clock, - expiryCleanupInterval: expiryCleanupInterval, - logger: logger, - // A buffered channel decouples the processor from the distributor, allowing for a fast, asynchronous handoff of new - // requests. - enqueueChan: make(chan *FlowItem, enqueueChannelBufferSize), + shard: shard, + saturationDetector: saturationDetector, + clock: clock, + cleanupSweepInterval: cleanupSweepInterval, + logger: logger, + lifecycleCtx: ctx, + enqueueChan: make(chan *FlowItem, enqueueChannelBufferSize), } } -// Submit attempts a non-blocking handoff of an item to the processor's internal channel for asynchronous processing. +// Submit attempts a non-blocking handoff of an item to the processor's internal enqueue channel. // -// It returns nil if the item was accepted by the processor, or if the processor is shutting down (in which case the -// item is immediately finalized with a shutdown error). In both cases, a nil return means the item's lifecycle has been -// handled by this processor and the caller should not retry. -// It returns `ErrProcessorBusy` if the processor's channel is momentarily full, signaling that the caller should try -// another processor. +// Ownership Contract: +// - Returns nil: The item was successfully handed off. +// The ShardProcessor takes responsibility for calling Finalize on the item. +// - Returns error: The item was not handed off. +// Ownership of the FlowItem remains with the caller, who is responsible for calling Finalize. +// +// Possible errors: +// - ErrProcessorBusy: The processor's input channel is full. +// - types.ErrFlowControllerNotRunning: The processor is shutting down. func (sp *ShardProcessor) Submit(item *FlowItem) error { if sp.isShuttingDown.Load() { - item.Finalize(types.QueueOutcomeRejectedOther, - fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerNotRunning)) - return nil // Success from the caller's perspective; the item is terminal. + return types.ErrFlowControllerNotRunning } - - select { + select { // The default case makes this select non-blocking. 
case sp.enqueueChan <- item: - return nil // Success + return nil // Ownership transferred. + case <-sp.lifecycleCtx.Done(): + return types.ErrFlowControllerNotRunning default: - // The channel buffer is full, signaling transient backpressure. return ErrProcessorBusy } } -// SubmitOrBlock performs a blocking submission of an item to the processor's internal channel. -// It will wait until either the submission succeeds or the provided context is cancelled. +// SubmitOrBlock performs a blocking handoff of an item to the processor's internal enqueue channel. +// It waits until the item is handed off, the caller's context is cancelled, or the processor shuts down. // -// This method is the fallback used by the distributor when all processors are busy, providing graceful backpressure -// instead of immediate rejection. +// Ownership Contract: +// - Returns nil: The item was successfully handed off. +// The ShardProcessor takes responsibility for calling Finalize on the item. +// - Returns error: The item was not handed off. +// Ownership of the FlowItem remains with the caller, who is responsible for calling Finalize. // -// It returns the `ctx.Err()` if the context is cancelled during the wait. +// Possible errors: +// - ctx.Err(): The provided context was cancelled or its deadline exceeded. +// - types.ErrFlowControllerNotRunning: The processor is shutting down. func (sp *ShardProcessor) SubmitOrBlock(ctx context.Context, item *FlowItem) error { if sp.isShuttingDown.Load() { - // Here, we return an error because the caller, expecting to block, was prevented from doing so by the shutdown. - // This is a failure of the operation. - item.Finalize(types.QueueOutcomeRejectedOther, - fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerNotRunning)) return types.ErrFlowControllerNotRunning } - select { + select { // The absence of a default case makes this call blocking. case sp.enqueueChan <- item: - return nil // Success + return nil // Ownership transferred. case <-ctx.Done(): - // The caller's context was cancelled while we were blocked. return ctx.Err() + case <-sp.lifecycleCtx.Done(): + return types.ErrFlowControllerNotRunning } } @@ -155,7 +164,7 @@ func (sp *ShardProcessor) Run(ctx context.Context) { defer sp.logger.V(logutil.DEFAULT).Info("Shard processor run loop stopped.") sp.wg.Add(1) - go sp.runExpiryCleanup(ctx) + go sp.runCleanupSweep(ctx) // This is the main worker loop. It continuously processes incoming requests and dispatches queued requests until the // context is cancelled. The `select` statement has three cases: @@ -204,70 +213,54 @@ func (sp *ShardProcessor) enqueue(item *FlowItem) { req := item.OriginalRequest() key := req.FlowKey() - logger := log.FromContext(req.Context()).WithName("enqueue").WithValues( - "flowKey", key, - "flowID", key.ID, - "priority", key.Priority, - "reqID", req.ID(), - "reqByteSize", req.ByteSize(), - ) + // --- Optimistic External Finalization Check --- + // Check if the item was finalized by the Controller (due to TTL/cancellation) while it was buffered in enqueueChan. + // This is an optimistic check to avoid unnecessary processing on items already considered dead. + // The ultimate guarantee of cleanup for any races is the runCleanupSweep mechanism. 
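+ // Because such an item was never admitted to a queue, its recorded outcome is a rejection (it never consumed
+ // queue capacity), so it is safe to drop it here with no queue state to roll back.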
+ if finalState := item.FinalState(); finalState != nil { + sp.logger.V(logutil.TRACE).Info("Item finalized externally before processing, discarding.", + "outcome", finalState.Outcome, "err", finalState.Err, "flowKey", key, "reqID", req.ID()) + return + } + // --- Configuration Validation --- managedQ, err := sp.shard.ManagedQueue(key) if err != nil { finalErr := fmt.Errorf("configuration error: failed to get queue for flow key %s: %w", key, err) - logger.Error(finalErr, "Rejecting item.") - item.Finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) + sp.logger.Error(finalErr, "Rejecting item.", "flowKey", key, "reqID", req.ID()) + item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) return } band, err := sp.shard.PriorityBandAccessor(key.Priority) if err != nil { finalErr := fmt.Errorf("configuration error: failed to get priority band for priority %d: %w", key.Priority, err) - logger.Error(finalErr, "Rejecting item.") - item.Finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) + sp.logger.Error(finalErr, "Rejecting item.", "flowKey", key, "reqID", req.ID()) + item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) return } - logger = logger.WithValues("priorityName", band.PriorityName()) + // --- Capacity Check --- + // This check is safe because it is performed by the single-writer Run goroutine. if !sp.hasCapacity(key.Priority, req.ByteSize()) { - // This is an expected outcome, not a system error. Log at the default level with rich context. - stats := sp.shard.Stats() - bandStats := stats.PerPriorityBandStats[key.Priority] - logger.V(logutil.DEFAULT).Info("Rejecting request, queue at capacity", - "outcome", types.QueueOutcomeRejectedCapacity, - "shardTotalBytes", stats.TotalByteSize, - "shardCapacityBytes", stats.TotalCapacityBytes, - "bandTotalBytes", bandStats.ByteSize, - "bandCapacityBytes", bandStats.CapacityBytes, - ) - item.Finalize(types.QueueOutcomeRejectedCapacity, fmt.Errorf("%w: %w", types.ErrRejected, types.ErrQueueAtCapacity)) - return - } - - // This is an optimistic check to prevent a needless add/remove cycle for an item that was finalized (e.g., context - // cancelled) during the handoff to this processor. A race condition still exists where an item can be finalized - // after this check but before the `Add` call completes. - // - // This is considered acceptable because: - // 1. The race window is extremely small. - // 2. The background `runExpiryCleanup` goroutine acts as the ultimate guarantor of correctness, as it will - // eventually find and evict any finalized item that slips through this check and is added to a queue. - if item.isFinalized() { - finalState := item.finalState - outcome, err := finalState.Outcome, finalState.Err - logger.V(logutil.VERBOSE).Info("Item finalized before adding to queue, ignoring.", "outcome", outcome, "err", err) + sp.logger.V(logutil.DEBUG).Info("Rejecting request, queue at capacity", + "flowKey", key, "reqID", req.ID(), "priorityName", band.PriorityName(), "reqByteSize", req.ByteSize()) + item.FinalizeWithOutcome(types.QueueOutcomeRejectedCapacity, fmt.Errorf("%w: %w", + types.ErrRejected, types.ErrQueueAtCapacity)) return } - // This is the point of commitment. After this call, the item is officially in the queue and is the responsibility of - // the dispatch or cleanup loops to finalize. + // --- Commitment Point --- + // The item is admitted. 
The ManagedQueue.Add implementation is responsible for calling item.SetHandle() atomically. if err := managedQ.Add(item); err != nil { finalErr := fmt.Errorf("failed to add item to queue for flow key %s: %w", key, err) - logger.Error(finalErr, "Rejecting item.") - item.Finalize(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) + sp.logger.Error(finalErr, "Rejecting item post-admission.", + "flowKey", key, "reqID", req.ID(), "priorityName", band.PriorityName()) + item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, finalErr)) return } - logger.V(logutil.TRACE).Info("Item enqueued.") + sp.logger.V(logutil.TRACE).Info("Item enqueued.", + "flowKey", key, "reqID", req.ID(), "priorityName", band.PriorityName()) } // hasCapacity checks if the shard and the specific priority band have enough capacity to accommodate an item of a given @@ -414,145 +407,102 @@ func (sp *ShardProcessor) dispatchItem(itemAcc types.QueueItemAccessor, logger l return fmt.Errorf("failed to remove item %q from queue for flow %s: %w", req.ID(), req.FlowKey(), err) } - // Final check for expiry/cancellation right before dispatch. removedItem := removedItemAcc.(*FlowItem) - isExpired, outcome, expiryErr := checkItemExpiry(removedItem, sp.clock.Now()) - if isExpired { - // Ensure we always have a non-nil error to wrap for consistent logging and error handling. - finalErr := expiryErr - if finalErr == nil { - finalErr = errors.New("item finalized before dispatch") - } - logger.V(logutil.VERBOSE).Info("Item expired at time of dispatch, evicting", "outcome", outcome, - "err", finalErr) - removedItem.Finalize(outcome, fmt.Errorf("%w: %w", types.ErrEvicted, finalErr)) - // Return an error to signal that the dispatch did not succeed. - return fmt.Errorf("item %q expired before dispatch: %w", req.ID(), finalErr) - } - - // Finalize the item as dispatched. - removedItem.Finalize(types.QueueOutcomeDispatched, nil) - logger.V(logutil.TRACE).Info("Item dispatched.") + sp.logger.V(logutil.TRACE).Info("Item dispatched.", "flowKey", req.FlowKey(), "reqID", req.ID()) + removedItem.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) return nil } -// checkItemExpiry provides the authoritative check to determine if an item should be evicted due to TTL expiry or -// context cancellation. -// -// It serves as a safeguard against race conditions. Its first action is to check if the item has already been finalized -// by a competing goroutine (e.g., the cleanup loop finalizing an item the dispatch loop is trying to process).= -// This ensures that the final outcome is decided exactly once. -func checkItemExpiry( - itemAcc types.QueueItemAccessor, - now time.Time, -) (isExpired bool, outcome types.QueueOutcome, err error) { - item := itemAcc.(*FlowItem) - - // This check is a critical defense against race conditions. If another goroutine (e.g., the cleanup loop) has - // already finalized this item, we must respect that outcome. - if item.isFinalized() { - finalState := item.finalState - return true, finalState.Outcome, finalState.Err - } - - // Check if the request's context has been cancelled. - if ctxErr := item.OriginalRequest().Context().Err(); ctxErr != nil { - return true, types.QueueOutcomeEvictedContextCancelled, fmt.Errorf("%w: %w", types.ErrContextCancelled, ctxErr) - } - - // Check if the item has outlived its TTL. 
- if item.EffectiveTTL() > 0 && now.Sub(item.EnqueueTime()) > item.EffectiveTTL() { - return true, types.QueueOutcomeEvictedTTL, types.ErrTTLExpired - } - - return false, types.QueueOutcomeNotYetFinalized, nil -} - -// runExpiryCleanup starts a background goroutine that periodically scans all queues on the shard for expired items. -func (sp *ShardProcessor) runExpiryCleanup(ctx context.Context) { +// runCleanupSweep starts a background goroutine that periodically scans all queues for externally finalized items +// ("zombie" items) and removes them in batches. +func (sp *ShardProcessor) runCleanupSweep(ctx context.Context) { defer sp.wg.Done() - logger := sp.logger.WithName("runExpiryCleanup") - logger.V(logutil.DEFAULT).Info("Shard expiry cleanup goroutine starting.") - defer logger.V(logutil.DEFAULT).Info("Shard expiry cleanup goroutine stopped.") + logger := sp.logger.WithName("runCleanupSweep") + logger.V(logutil.DEFAULT).Info("Shard cleanup sweep goroutine starting.") + defer logger.V(logutil.DEFAULT).Info("Shard cleanup sweep goroutine stopped.") - ticker := time.NewTicker(sp.expiryCleanupInterval) + ticker := sp.clock.NewTicker(sp.cleanupSweepInterval) defer ticker.Stop() for { select { case <-ctx.Done(): return - case now := <-ticker.C: - sp.cleanupExpired(now) + case <-ticker.C(): + sp.sweepFinalizedItems() } } } -// cleanupExpired performs a single scan of all queues on the shard, removing and finalizing any items that have -// expired. -func (sp *ShardProcessor) cleanupExpired(now time.Time) { - processFn := func(managedQ contracts.ManagedQueue, queueLogger logr.Logger) { - // This predicate identifies items to be removed by the Cleanup call. - predicate := func(item types.QueueItemAccessor) bool { - isExpired, _, _ := checkItemExpiry(item, now) - return isExpired +// sweepFinalizedItems performs a single scan of all queues, removing finalized items in batch and releasing their +// memory. +func (sp *ShardProcessor) sweepFinalizedItems() { + processFn := func(managedQ contracts.ManagedQueue, logger logr.Logger) { + key := managedQ.FlowQueueAccessor().FlowKey() + predicate := func(itemAcc types.QueueItemAccessor) bool { + return itemAcc.(*FlowItem).FinalState() != nil } - removedItems, err := managedQ.Cleanup(predicate) if err != nil { - queueLogger.Error(err, "Error during ManagedQueue Cleanup") + logger.Error(err, "Error during ManagedQueue Cleanup", "flowKey", key) } - - // Finalize all the items that were removed. - sp.finalizeExpiredItems(removedItems, now, queueLogger) + logger.V(logutil.DEBUG).Info("Swept finalized items and released capacity.", + "flowKey", key, "count", len(removedItems)) } - sp.processAllQueuesConcurrently("cleanupExpired", processFn) + sp.processAllQueuesConcurrently("sweepFinalizedItems", processFn) } -// shutdown handles the graceful termination of the processor, ensuring any pending items in the enqueue channel or in -// the queues are finalized correctly. +// shutdown handles the graceful termination of the processor, ensuring all pending items (in channel and queues) are +// Finalized. func (sp *ShardProcessor) shutdown() { sp.shutdownOnce.Do(func() { - // Set the atomic bool so that any new calls to Enqueue will fail fast. sp.isShuttingDown.Store(true) sp.logger.V(logutil.DEFAULT).Info("Shard processor shutting down.") - // Drain the channel BEFORE closing it. This prevents a panic from any goroutine that is currently blocked trying to - // send to the channel. We read until it's empty. 
- DrainLoop: + DrainLoop: // Drain the enqueueChan to finalize buffered items. for { select { case item := <-sp.enqueueChan: - if item == nil { // This is a safeguard against logic errors in the distributor. + if item == nil { continue } - item.Finalize(types.QueueOutcomeRejectedOther, + // Finalize buffered items. + item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, fmt.Errorf("%w: %w", types.ErrRejected, types.ErrFlowControllerNotRunning)) default: - // The channel is empty, we can now safely close it. break DrainLoop } } - close(sp.enqueueChan) - - // Evict all remaining items from the queues. + // We do not close enqueueChan because external goroutines (Controller) send on it. + // The channel will be garbage collected when the processor terminates. sp.evictAll() }) } -// evictAll drains all queues on the shard and finalizes every item with a shutdown error. +// evictAll drains all queues on the shard, finalizes every item, and releases their memory. func (sp *ShardProcessor) evictAll() { - processFn := func(managedQ contracts.ManagedQueue, queueLogger logr.Logger) { + processFn := func(managedQ contracts.ManagedQueue, logger logr.Logger) { + key := managedQ.FlowQueueAccessor().FlowKey() removedItems, err := managedQ.Drain() if err != nil { - queueLogger.Error(err, "Error during ManagedQueue Drain") + logger.Error(err, "Error during ManagedQueue Drain", "flowKey", key) } - // Finalize all the items that were removed. - getOutcome := func(_ types.QueueItemAccessor) (types.QueueOutcome, error) { - return types.QueueOutcomeEvictedOther, fmt.Errorf("%w: %w", types.ErrEvicted, types.ErrFlowControllerNotRunning) + outcome := types.QueueOutcomeEvictedOther + errShutdown := fmt.Errorf("%w: %w", types.ErrEvicted, types.ErrFlowControllerNotRunning) + for _, i := range removedItems { + item, ok := i.(*FlowItem) + if !ok { + logger.Error(fmt.Errorf("internal error: unexpected type %T", i), + "Panic condition detected during shutdown", "flowKey", key) + continue + } + + // Finalization is idempotent; safe to call even if already finalized externally. + item.FinalizeWithOutcome(outcome, errShutdown) + logger.V(logutil.TRACE).Info("Item evicted during shutdown.", + "flowKey", key, "reqID", item.OriginalRequest().ID()) } - sp.finalizeItems(removedItems, queueLogger, getOutcome) } sp.processAllQueuesConcurrently("evictAll", processFn) } @@ -620,38 +570,3 @@ func (sp *ShardProcessor) processAllQueuesConcurrently( close(tasks) // Close the channel to signal workers to exit. wg.Wait() // Wait for all workers to finish. } - -// finalizeItems is a helper to iterate over a slice of items, safely cast them, and finalize them with an outcome -// determined by the `getOutcome` function. -func (sp *ShardProcessor) finalizeItems( - items []types.QueueItemAccessor, - logger logr.Logger, - getOutcome func(item types.QueueItemAccessor) (types.QueueOutcome, error), -) { - for _, i := range items { - item, ok := i.(*FlowItem) - if !ok { - unexpectedItemErr := fmt.Errorf("internal error: item %q of type %T is not a *FlowItem", - i.OriginalRequest().ID(), i) - logger.Error(unexpectedItemErr, "Panic condition detected during finalization", "item", i) - continue - } - - outcome, err := getOutcome(i) - item.Finalize(outcome, err) - logger.V(logutil.TRACE).Info("Item finalized", "reqID", item.OriginalRequest().ID(), - "outcome", outcome, "err", err) - } -} - -// finalizeExpiredItems is a specialized version of finalizeItems for items that are known to be expired. 
-// It determines the precise reason for expiry and finalizes the item accordingly. -func (sp *ShardProcessor) finalizeExpiredItems(items []types.QueueItemAccessor, now time.Time, logger logr.Logger) { - getOutcome := func(item types.QueueItemAccessor) (types.QueueOutcome, error) { - // We don't need the `isExpired` boolean here because we know it's true, but this function conveniently returns the - // precise outcome and error. - _, outcome, expiryErr := checkItemExpiry(item, now) - return outcome, fmt.Errorf("%w: %w", types.ErrEvicted, expiryErr) - } - sp.finalizeItems(items, logger, getOutcome) -} diff --git a/pkg/epp/flowcontrol/controller/internal/processor_test.go b/pkg/epp/flowcontrol/controller/internal/processor_test.go index 4c31d7ae4..2280b82d7 100644 --- a/pkg/epp/flowcontrol/controller/internal/processor_test.go +++ b/pkg/epp/flowcontrol/controller/internal/processor_test.go @@ -100,6 +100,7 @@ func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarn queues: make(map[types.FlowKey]*mocks.MockManagedQueue), priorityFlows: make(map[int][]types.FlowKey), } + h.ctx, h.cancel = context.WithCancel(context.Background()) // Wire up the harness to provide the mock implementations for the shard's dependencies. h.ManagedQueueFunc = h.managedQueue @@ -118,7 +119,14 @@ func newTestHarness(t *testing.T, expiryCleanupInterval time.Duration) *testHarn } } - h.processor = NewShardProcessor(h, h.saturationDetector, h.clock, expiryCleanupInterval, 100, h.logger) + h.processor = NewShardProcessor( + h.ctx, + h, + h.saturationDetector, + h.clock, + expiryCleanupInterval, + 100, + h.logger) require.NotNil(t, h.processor, "NewShardProcessor should not return nil") t.Cleanup(func() { h.Stop() }) @@ -170,8 +178,7 @@ func (h *testHarness) waitForFinalization(item *FlowItem) (types.QueueOutcome, e // newTestItem creates a new FlowItem for testing purposes. func (h *testHarness) newTestItem(id string, key types.FlowKey, ttl time.Duration) *FlowItem { h.t.Helper() - ctx := log.IntoContext(context.Background(), h.logger) - req := typesmocks.NewMockFlowControlRequest(100, id, key, ctx) + req := typesmocks.NewMockFlowControlRequest(100, id, key) return NewItem(req, ttl, h.clock.Now()) } @@ -365,61 +372,11 @@ func TestShardProcessor(t *testing.T) { h.Start() h.Go() h.Stop() // Stop the processor, then immediately try to enqueue. - require.NoError(t, h.processor.Submit(item), "precondition: Submit should not fail, even on shutdown") + require.ErrorIs(t, h.processor.Submit(item), types.ErrFlowControllerNotRunning, + "Submit should return ErrFlowControllerNotRunning on shutdown") // --- ASSERT --- - outcome, err := h.waitForFinalization(item) - assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "The outcome should be RejectedOther") - require.Error(t, err, "An eviction on shutdown should produce an error") - assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, "The error should be of type ErrFlowControllerNotRunning") - }) - - t.Run("should evict item on TTL expiry via background cleanup", func(t *testing.T) { - t.Parallel() - // --- ARRANGE --- - h := newTestHarness(t, testCleanupTick) - item := h.newTestItem("req-expired-evict", testFlow, testShortTTL) - h.addQueue(testFlow) - - // --- ACT --- - h.Start() - require.NoError(t, h.processor.Submit(item), "precondition: Submit should not fail") - h.Go() - - h.clock.Step(testShortTTL * 2) // Let time pass for the item to expire. - // Manually invoke the cleanup logic to simulate a tick of the cleanup loop deterministically. 
- h.processor.cleanupExpired(h.clock.Now()) - - // --- ASSERT --- - outcome, err := h.waitForFinalization(item) - assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The final outcome should be EvictedTTL") - require.Error(t, err, "A TTL eviction should produce an error") - assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired") - }) - - t.Run("should evict item on context cancellation", func(t *testing.T) { - t.Parallel() - // --- ARRANGE --- - h := newTestHarness(t, testCleanupTick) - ctx, cancel := context.WithCancel(context.Background()) - req := typesmocks.NewMockFlowControlRequest(100, "req-ctx-cancel", testFlow, ctx) - item := NewItem(req, testTTL, h.clock.Now()) - h.addQueue(testFlow) - - // --- ACT --- - h.Start() - require.NoError(t, h.processor.Submit(item), "precondition: Submit should not fail") - h.Go() - cancel() // Cancel the context after the item is enqueued. - // Manually invoke the cleanup logic to deterministically check for the cancelled context. - h.processor.cleanupExpired(h.clock.Now()) - - // --- ASSERT --- - outcome, err := h.waitForFinalization(item) - assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, outcome, - "The outcome should be EvictedContextCancelled") - require.Error(t, err, "A context cancellation eviction should produce an error") - assert.ErrorIs(t, err, types.ErrContextCancelled, "The error should be of type ErrContextCancelled") + assert.Nil(t, item.FinalState(), "Item should not be finalized by the processor") }) t.Run("should evict a queued item on shutdown", func(t *testing.T) { @@ -444,7 +401,8 @@ func TestShardProcessor(t *testing.T) { outcome, err := h.waitForFinalization(item) assert.Equal(t, types.QueueOutcomeEvictedOther, outcome, "The outcome should be EvictedOther") require.Error(t, err, "An eviction on shutdown should produce an error") - assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, "The error should be of type ErrFlowControllerNotRunning") + assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, + "The error should be of type ErrFlowControllerNotRunning") }) t.Run("should handle concurrent enqueues and dispatch all items", func(t *testing.T) { @@ -454,7 +412,7 @@ func TestShardProcessor(t *testing.T) { const numConcurrentItems = 20 q := h.addQueue(testFlow) itemsToTest := make([]*FlowItem, 0, numConcurrentItems) - for i := 0; i < numConcurrentItems; i++ { + for i := range numConcurrentItems { item := h.newTestItem(fmt.Sprintf("req-concurrent-%d", i), testFlow, testTTL) itemsToTest = append(itemsToTest, item) } @@ -495,16 +453,26 @@ func TestShardProcessor(t *testing.T) { // Use channels to pause the dispatch cycle right before it would remove the item. policyCanProceed := make(chan struct{}) itemIsBeingDispatched := make(chan struct{}) + var signalOnce sync.Once + var removedItem types.QueueItemAccessor require.NoError(t, q.Add(item)) // Add the item directly to the queue. // Override the queue's `RemoveFunc` to pause the dispatch goroutine at a critical moment. q.RemoveFunc = func(h types.QueueItemHandle) (types.QueueItemAccessor, error) { - close(itemIsBeingDispatched) // 1. Signal that dispatch is happening. - <-policyCanProceed // 2. Wait for the test to tell us to continue. - // 4. After we unblock, the item will have already been finalized by the cleanup logic, so we simulate the - // real-world outcome of a failed remove. 
-			return nil, fmt.Errorf("item with handle %v not found", h)
+			var err error
+			signalOnce.Do(func() {
+				removedItem = item
+				close(itemIsBeingDispatched) // 1. Signal that dispatch is happening.
+				<-policyCanProceed           // 2. Wait for the test to tell us to continue.
+				// 4. After we unblock, the item will have already been finalized externally by the test.
+				//    The error below is only a fallback for any unexpected later calls.
+				err = fmt.Errorf("item with handle %v not found", h)
+			})
+			if removedItem == item {
+				return item, nil // Hand the item back; dispatch will see it is already finalized.
+			}
+			return nil, err // Fallback for unexpected calls.
 		}
 
 		// --- ACT ---
@@ -513,20 +481,23 @@ func TestShardProcessor(t *testing.T) {
 		h.Go()
 
 		// Wait for the dispatch cycle to select our item and pause inside our mock `RemoveFunc`.
-		<-itemIsBeingDispatched
+		select {
+		case <-itemIsBeingDispatched:
+		case <-time.After(testWaitTimeout):
+			t.Fatal("Timed out waiting for item to be dispatched")
+		}
 
 		// 3. The dispatch goroutine is now paused. We can now safely win the "race" by running cleanup logic.
 		h.clock.Step(testShortTTL * 2)
-		h.processor.cleanupExpired(h.clock.Now()) // This will remove and finalize the item.
+		item.Finalize(types.ErrTTLExpired) // Finalize the item externally with a TTL cause (outcome: EvictedTTL).
 
-		// 5. Un-pause the dispatch goroutine. It will now fail to remove the item and the `dispatchCycle` will
-		//    correctly conclude without finalizing the item a second time.
+		// 5. Un-pause the dispatch goroutine.
 		close(policyCanProceed)
 
 		// --- ASSERT ---
-		// The item's final state should be from the cleanup logic (EvictedTTL), not the dispatch logic.
+		// The item's final state should be from the Finalize call above.
 		outcome, err := h.waitForFinalization(item)
-		assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The outcome should be EvictedTTL from the cleanup routine")
+		assert.Equal(t, types.QueueOutcomeEvictedTTL, outcome, "The outcome should be EvictedTTL from the Finalize call")
 		require.Error(t, err, "A TTL eviction should produce an error")
 		assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired")
 	})
@@ -594,9 +565,10 @@ func TestShardProcessor(t *testing.T) {
 				h.ManagedQueueFunc = func(types.FlowKey) (contracts.ManagedQueue, error) { return nil, testErr }
 			},
 			assert: func(t *testing.T, h *testHarness, item *FlowItem) {
-				assert.Equal(t, types.QueueOutcomeRejectedOther, item.finalState.Outcome, "Outcome should be RejectedOther")
-				require.Error(t, item.finalState.Err, "An error should be returned")
-				assert.ErrorIs(t, item.finalState.Err, testErr, "The underlying error should be preserved")
+				assert.Equal(t, types.QueueOutcomeRejectedOther, item.FinalState().Outcome,
+					"Outcome should be RejectedOther")
+				require.Error(t, item.FinalState().Err, "An error should be returned")
+				assert.ErrorIs(t, item.FinalState().Err, testErr, "The underlying error should be preserved")
 			},
 		},
 		{
@@ -606,9 +578,10 @@ func TestShardProcessor(t *testing.T) {
 				h.PriorityBandAccessorFunc = func(int) (framework.PriorityBandAccessor, error) { return nil, testErr }
 			},
 			assert: func(t *testing.T, h *testHarness, item *FlowItem) {
-				assert.Equal(t, types.QueueOutcomeRejectedOther, item.finalState.Outcome, "Outcome should be RejectedOther")
-				require.Error(t, item.finalState.Err, "An error should be returned")
-				assert.ErrorIs(t, item.finalState.Err, testErr, "The underlying error should be preserved")
+				assert.Equal(t, types.QueueOutcomeRejectedOther, item.FinalState().Outcome,
+					"Outcome should be RejectedOther")
+				require.Error(t, 
item.FinalState().Err, "An error should be returned") + assert.ErrorIs(t, item.FinalState().Err, testErr, "The underlying error should be preserved") }, }, { @@ -618,9 +591,10 @@ func TestShardProcessor(t *testing.T) { mockQueue.AddFunc = func(types.QueueItemAccessor) error { return testErr } }, assert: func(t *testing.T, h *testHarness, item *FlowItem) { - assert.Equal(t, types.QueueOutcomeRejectedOther, item.finalState.Outcome, "Outcome should be RejectedOther") - require.Error(t, item.finalState.Err, "An error should be returned") - assert.ErrorIs(t, item.finalState.Err, testErr, "The underlying error should be preserved") + assert.Equal(t, types.QueueOutcomeRejectedOther, item.FinalState().Outcome, + "Outcome should be RejectedOther") + require.Error(t, item.FinalState().Err, "An error should be returned") + assert.ErrorIs(t, item.FinalState().Err, testErr, "The underlying error should be preserved") }, }, { @@ -640,13 +614,13 @@ func TestShardProcessor(t *testing.T) { item: func() *FlowItem { // Create a pre-finalized item. item := newTestHarness(t, 0).newTestItem("req-finalized", testFlow, testTTL) - item.Finalize(types.QueueOutcomeDispatched, nil) + item.FinalizeWithOutcome(types.QueueOutcomeDispatched, nil) return item }(), assert: func(t *testing.T, h *testHarness, item *FlowItem) { // The item was already finalized, so its state should not change. - assert.Equal(t, types.QueueOutcomeDispatched, item.finalState.Outcome, "Outcome should remain unchanged") - assert.NoError(t, item.finalState.Err, "Error should remain unchanged") + assert.Equal(t, types.QueueOutcomeDispatched, item.FinalState().Outcome, "Outcome should remain unchanged") + assert.NoError(t, item.FinalState().Err, "Error should remain unchanged") }, }, } @@ -905,9 +879,9 @@ func TestShardProcessor(t *testing.T) { // Verify all high-priority items are gone and low-priority items remain. for _, item := range highPrioItems { - assert.Equal(t, types.QueueOutcomeDispatched, item.finalState.Outcome, + assert.Equal(t, types.QueueOutcomeDispatched, item.FinalState().Outcome, "High-priority item should be dispatched") - assert.NoError(t, item.finalState.Err, "Dispatched high-priority item should not have an error") + assert.NoError(t, item.FinalState().Err, "Dispatched high-priority item should not have an error") } assert.Equal(t, numItems, qLow.Len(), "Low-priority queue should still be full") @@ -967,11 +941,12 @@ func TestShardProcessor(t *testing.T) { } }) - t.Run("should evict item that expires at moment of dispatch", func(t *testing.T) { + t.Run("should not dispatch already finalized item", func(t *testing.T) { t.Parallel() // --- ARRANGE --- h := newTestHarness(t, testCleanupTick) - item := h.newTestItem("req-expired-dispatch", testFlow, testShortTTL) + item := h.newTestItem("req-already-finalized", testFlow, testTTL) + item.FinalizeWithOutcome(types.QueueOutcomeRejectedOther, errors.New("already done")) h.ManagedQueueFunc = func(types.FlowKey) (contracts.ManagedQueue, error) { return &mocks.MockManagedQueue{ @@ -982,43 +957,61 @@ func TestShardProcessor(t *testing.T) { } // --- ACT --- - h.clock.Step(testShortTTL * 2) // Make the item expire. err := h.processor.dispatchItem(item, h.logger) // --- ASSERT --- - // First, check the error returned by `dispatchItem`. - require.Error(t, err, "dispatchItem should return an error for an expired item") - assert.ErrorIs(t, err, types.ErrTTLExpired, "The error should be of type ErrTTLExpired") - - // Second, check the final state of the item itself. 
- assert.Equal(t, types.QueueOutcomeEvictedTTL, item.finalState.Outcome, - "The item's final outcome should be EvictedTTL") - require.Error(t, item.finalState.Err, "The item's final state should contain an error") - assert.ErrorIs(t, item.finalState.Err, types.ErrTTLExpired, - "The item's final error should be of type ErrTTLExpired") + require.NoError(t, err, "dispatchItem should return no error for an already finalized item") + + // Check the final state of the item itself - it should not have changed. + finalState := item.FinalState() + require.NotNil(t, finalState, "Item must be finalized") + assert.Equal(t, types.QueueOutcomeRejectedOther, finalState.Outcome, + "The item's final outcome should be RejectedOther") + assert.ErrorContains(t, finalState.Err, "already done", + "The error should be the one from the first Finalize call") }) }) t.Run("cleanup and utility methods", func(t *testing.T) { t.Parallel() - t.Run("should remove and finalize expired items", func(t *testing.T) { + t.Run("should sweep externally finalized items", func(t *testing.T) { t.Parallel() // --- ARRANGE --- h := newTestHarness(t, testCleanupTick) - // Create an item that is already expired relative to the cleanup time. - item := h.newTestItem("req-expired", testFlow, 1*time.Millisecond) + item := h.newTestItem("req-external-finalized", testFlow, testTTL) q := h.addQueue(testFlow) - require.NoError(t, q.Add(item)) - cleanupTime := h.clock.Now().Add(10 * time.Millisecond) + require.NoError(t, q.Add(item), "Failed to add item to queue") + + // Externally finalize the item + item.Finalize(context.Canceled) + require.NotNil(t, item.FinalState(), "Item should be finalized") // --- ACT --- - h.processor.cleanupExpired(cleanupTime) + h.processor.sweepFinalizedItems() // --- ASSERT --- - assert.Equal(t, types.QueueOutcomeEvictedTTL, item.finalState.Outcome, "Item outcome should be EvictedTTL") - require.Error(t, item.finalState.Err, "Item should have an error") - assert.ErrorIs(t, item.finalState.Err, types.ErrTTLExpired, "Item error should be ErrTTLExpired") + assert.Equal(t, 0, q.Len(), "Queue should be empty after sweep") + finalState := item.FinalState() + assert.Equal(t, types.QueueOutcomeEvictedContextCancelled, finalState.Outcome, + "Outcome should be EvictedContextCancelled") + assert.ErrorIs(t, finalState.Err, types.ErrContextCancelled, "Error should be ErrContextCancelled") + }) + + t.Run("should not sweep items not finalized", func(t *testing.T) { + t.Parallel() + // --- ARRANGE --- + h := newTestHarness(t, testCleanupTick) + item := h.newTestItem("req-not-finalized", testFlow, testTTL) + q := h.addQueue(testFlow) + require.NoError(t, q.Add(item), "Failed to add item to queue") + + // --- ACT --- + h.processor.sweepFinalizedItems() + + // --- ASSERT --- + assert.Equal(t, 1, q.Len(), "Queue should still contain the item") + assert.Nil(t, item.FinalState(), "Item should not be finalized") }) t.Run("should evict all items on shutdown", func(t *testing.T) { @@ -1033,9 +1026,10 @@ func TestShardProcessor(t *testing.T) { h.processor.evictAll() // --- ASSERT --- - assert.Equal(t, types.QueueOutcomeEvictedOther, item.finalState.Outcome, "Item outcome should be EvictedOther") - require.Error(t, item.finalState.Err, "Item should have an error") - assert.ErrorIs(t, item.finalState.Err, types.ErrFlowControllerNotRunning, + assert.Equal(t, types.QueueOutcomeEvictedOther, item.FinalState().Outcome, + "Item outcome should be EvictedOther") + require.Error(t, item.FinalState().Err, "Item should have an error") + 
assert.ErrorIs(t, item.FinalState().Err, types.ErrFlowControllerNotRunning, "Item error should be ErrFlowControllerNotRunning") }) @@ -1055,25 +1049,6 @@ func TestShardProcessor(t *testing.T) { }, "processAllQueuesConcurrently should not panic on registry errors") }) - t.Run("should handle items of an unexpected type gracefully during finalization", func(t *testing.T) { - t.Parallel() - // --- ARRANGE --- - h := newTestHarness(t, testCleanupTick) - item := &typesmocks.MockQueueItemAccessor{ - OriginalRequestV: typesmocks.NewMockFlowControlRequest(0, "bad-item", testFlow, context.Background()), - } - items := []types.QueueItemAccessor{item} - - // --- ACT & ASSERT --- - // The test passes if this call completes without panicking. - assert.NotPanics(t, func() { - getOutcome := func(types.QueueItemAccessor) (types.QueueOutcome, error) { - return types.QueueOutcomeEvictedOther, nil - } - h.processor.finalizeItems(items, h.logger, getOutcome) - }, "finalizeItems should not panic on unexpected item types") - }) - t.Run("should process all queues with a worker pool", func(t *testing.T) { t.Parallel() // --- ARRANGE --- @@ -1122,6 +1097,26 @@ func TestShardProcessor(t *testing.T) { require.Error(t, err, "Submit must return an error when the channel is full") assert.ErrorIs(t, err, ErrProcessorBusy, "The returned error must be ErrProcessorBusy") }) + + t.Run("should return ErrFlowControllerNotRunning if lifecycleCtx is cancelled", func(t *testing.T) { + t.Parallel() + h := newTestHarness(t, testCleanupTick) + h.Start() + h.Go() // Ensure the Run loop has started + h.cancel() // Cancel the lifecycle context + h.Stop() // Wait for the processor to fully stop + + item := h.newTestItem("item-ctx-cancel", testFlow, testTTL) + err := h.processor.Submit(item) + require.ErrorIs(t, err, types.ErrFlowControllerNotRunning, + "Submit must return ErrFlowControllerNotRunning when lifecycleCtx is cancelled") + assert.Nil(t, item.FinalState(), "Item should not be finalized by Submit") + + err = h.processor.SubmitOrBlock(context.Background(), item) + require.ErrorIs(t, err, types.ErrFlowControllerNotRunning, + "SubmitOrBlock must return ErrFlowControllerNotRunning when lifecycleCtx is cancelled") + assert.Nil(t, item.FinalState(), "Item should not be finalized by SubmitOrBlock") + }) }) t.Run("SubmitOrBlock", func(t *testing.T) { @@ -1195,114 +1190,9 @@ func TestShardProcessor(t *testing.T) { require.Error(t, err, "SubmitOrBlock should return an error when shutting down") assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, "The error should be ErrFlowControllerNotRunning") - outcome, err := h.waitForFinalization(item) - assert.Equal(t, types.QueueOutcomeRejectedOther, outcome, "The outcome should be RejectedOther") - require.Error(t, err, "Finalization should include an error") - assert.ErrorIs(t, err, types.ErrFlowControllerNotRunning, - "The finalization error should be ErrFlowControllerNotRunning") + // Item should not be finalized by the processor + assert.Nil(t, item.FinalState(), "Item should not be finalized by the processor") }) }) }) } - -func TestCheckItemExpiry(t *testing.T) { - t.Parallel() - - // --- ARRANGE --- - now := time.Now() - ctxCancelled, cancel := context.WithCancel(context.Background()) - cancel() // Cancel the context immediately. 
- - testCases := []struct { - name string - item types.QueueItemAccessor - now time.Time - expectExpired bool - expectOutcome types.QueueOutcome - expectErr error - }{ - { - name: "should not be expired if TTL is not reached and context is active", - item: NewItem( - typesmocks.NewMockFlowControlRequest(100, "req-not-expired", testFlow, context.Background()), - testTTL, - now), - now: now.Add(30 * time.Second), - expectExpired: false, - expectOutcome: types.QueueOutcomeNotYetFinalized, - expectErr: nil, - }, - { - name: "should not be expired if TTL is disabled (0)", - item: NewItem( - typesmocks.NewMockFlowControlRequest(100, "req-not-expired-no-ttl", testFlow, context.Background()), - 0, - now), - now: now.Add(30 * time.Second), - expectExpired: false, - expectOutcome: types.QueueOutcomeNotYetFinalized, - expectErr: nil, - }, - { - name: "should be expired if TTL is exceeded", - item: NewItem( - typesmocks.NewMockFlowControlRequest(100, "req-ttl-expired", testFlow, context.Background()), - time.Second, - now), - now: now.Add(2 * time.Second), - expectExpired: true, - expectOutcome: types.QueueOutcomeEvictedTTL, - expectErr: types.ErrTTLExpired, - }, - { - name: "should be expired if context is cancelled", - item: NewItem( - typesmocks.NewMockFlowControlRequest(100, "req-ctx-cancelled", testFlow, ctxCancelled), - testTTL, - now), - now: now, - expectExpired: true, - expectOutcome: types.QueueOutcomeEvictedContextCancelled, - expectErr: types.ErrContextCancelled, - }, - { - name: "should be expired if already finalized", - item: func() types.QueueItemAccessor { - i := NewItem( - typesmocks.NewMockFlowControlRequest(100, "req-finalized", testFlow, context.Background()), - testTTL, - now) - i.Finalize(types.QueueOutcomeDispatched, nil) - return i - }(), - now: now, - expectExpired: true, - expectOutcome: types.QueueOutcomeDispatched, - expectErr: nil, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - // --- ACT --- - isExpired, outcome, err := checkItemExpiry(tc.item, tc.now) - - // --- ASSERT --- - assert.Equal(t, tc.expectExpired, isExpired, "Expired status should match expected value") - assert.Equal(t, tc.expectOutcome, outcome, "Outcome should match expected value") - - if tc.expectErr != nil { - require.Error(t, err, "An error was expected") - // Use ErrorIs for sentinel errors, ErrorContains for general messages. - if errors.Is(tc.expectErr, types.ErrTTLExpired) || errors.Is(tc.expectErr, types.ErrContextCancelled) { - assert.ErrorIs(t, err, tc.expectErr, "The specific error type should be correct") - } else { - assert.ErrorContains(t, err, tc.expectErr.Error(), "The error message should contain the expected text") - } - } else { - assert.NoError(t, err, "No error was expected") - } - }) - } -} diff --git a/pkg/epp/flowcontrol/types/mocks/mocks.go b/pkg/epp/flowcontrol/types/mocks/mocks.go index c52c5c2db..5fabf3683 100644 --- a/pkg/epp/flowcontrol/types/mocks/mocks.go +++ b/pkg/epp/flowcontrol/types/mocks/mocks.go @@ -19,7 +19,6 @@ limitations under the License. package mocks import ( - "context" "time" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -28,7 +27,6 @@ import ( // MockFlowControlRequest provides a mock implementation of the `types.FlowControlRequest` interface. 
type MockFlowControlRequest struct { - Ctx context.Context FlowKeyV types.FlowKey ByteSizeV uint64 InitialEffectiveTTLV time.Duration @@ -41,20 +39,14 @@ func NewMockFlowControlRequest( byteSize uint64, id string, key types.FlowKey, - ctx context.Context, ) *MockFlowControlRequest { - if ctx == nil { - ctx = context.Background() - } return &MockFlowControlRequest{ ByteSizeV: byteSize, IDV: id, FlowKeyV: key, - Ctx: ctx, } } -func (m *MockFlowControlRequest) Context() context.Context { return m.Ctx } func (m *MockFlowControlRequest) FlowKey() types.FlowKey { return m.FlowKeyV } func (m *MockFlowControlRequest) ByteSize() uint64 { return m.ByteSizeV } func (m *MockFlowControlRequest) InitialEffectiveTTL() time.Duration { return m.InitialEffectiveTTLV } @@ -114,7 +106,6 @@ func NewMockQueueItemAccessor(byteSize uint64, reqID string, key types.FlowKey) byteSize, reqID, key, - context.Background(), ), HandleV: &MockQueueItemHandle{}, } diff --git a/pkg/epp/flowcontrol/types/request.go b/pkg/epp/flowcontrol/types/request.go index 255d3bc45..e427b0aba 100644 --- a/pkg/epp/flowcontrol/types/request.go +++ b/pkg/epp/flowcontrol/types/request.go @@ -17,7 +17,6 @@ limitations under the License. package types import ( - "context" "time" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/backend/metrics" @@ -30,11 +29,6 @@ import ( // wraps this object with its own internal structures (which implement `QueueItemAccessor`) to manage the request's // lifecycle without modifying the original. type FlowControlRequest interface { - // Context returns the request's context. The `controller.FlowController` uses this for monitoring cancellation (e.g., - // if the client disconnects or a request-scoped timeout occurs), which can lead to the request being evicted from a - // queue. - Context() context.Context - // FlowKey returns the composite key that uniquely identifies the flow instance this request belongs to. // The `controller.FlowController` uses this key as the primary identifier to look up the correct // `contracts.ManagedQueue` and configured `framework.IntraFlowDispatchPolicy` from a `contracts.RegistryShard`. diff --git a/pkg/epp/requestcontrol/admission.go b/pkg/epp/requestcontrol/admission.go index 383d2844a..69fd5adf8 100644 --- a/pkg/epp/requestcontrol/admission.go +++ b/pkg/epp/requestcontrol/admission.go @@ -62,7 +62,7 @@ type saturationDetector interface { // flowController defines the minimal interface required by FlowControlAdmissionController for enqueuing requests and // waiting for an admission outcome. 
type flowController interface { - EnqueueAndWait(req types.FlowControlRequest) (types.QueueOutcome, error) + EnqueueAndWait(ctx context.Context, req types.FlowControlRequest) (types.QueueOutcome, error) } // rejectIfSheddableAndSaturated checks if a request should be immediately rejected because it's sheddable @@ -157,7 +157,6 @@ func (fcac *FlowControlAdmissionController) Admit( logger.V(logutil.TRACE).Info("Request proceeding to flow control", "requestID", reqCtx.SchedulingRequest.RequestId) fcReq := &flowControlRequest{ - ctx: ctx, requestID: reqCtx.SchedulingRequest.RequestId, fairnessID: reqCtx.FairnessID, priority: priority, @@ -165,7 +164,7 @@ func (fcac *FlowControlAdmissionController) Admit( candidatePods: candidatePods, } - outcome, err := fcac.flowController.EnqueueAndWait(fcReq) + outcome, err := fcac.flowController.EnqueueAndWait(ctx, fcReq) logger.V(logutil.DEBUG).Info("Flow control outcome", "requestID", reqCtx.SchedulingRequest.RequestId, "outcome", outcome, "error", err) return translateFlowControlOutcome(outcome, err) @@ -173,7 +172,6 @@ func (fcac *FlowControlAdmissionController) Admit( // flowControlRequest is an adapter that implements the types.FlowControlRequest interface. type flowControlRequest struct { - ctx context.Context requestID string fairnessID string priority int @@ -183,7 +181,6 @@ type flowControlRequest struct { var _ types.FlowControlRequest = &flowControlRequest{} -func (r *flowControlRequest) Context() context.Context { return r.ctx } func (r *flowControlRequest) ID() string { return r.requestID } func (r *flowControlRequest) InitialEffectiveTTL() time.Duration { return 0 } // Use controller default. func (r *flowControlRequest) ByteSize() uint64 { return r.requestByteSize } diff --git a/pkg/epp/requestcontrol/admission_test.go b/pkg/epp/requestcontrol/admission_test.go index 002c50f06..085778200 100644 --- a/pkg/epp/requestcontrol/admission_test.go +++ b/pkg/epp/requestcontrol/admission_test.go @@ -48,7 +48,10 @@ type mockFlowController struct { called bool } -func (m *mockFlowController) EnqueueAndWait(_ fctypes.FlowControlRequest) (fctypes.QueueOutcome, error) { +func (m *mockFlowController) EnqueueAndWait( + _ context.Context, + _ fctypes.FlowControlRequest, +) (fctypes.QueueOutcome, error) { m.called = true return m.outcome, m.err } @@ -115,7 +118,6 @@ func TestLegacyAdmissionController_Admit(t *testing.T) { func TestFlowControlRequestAdapter(t *testing.T) { t.Parallel() - ctx := context.Background() candidatePods := []backendmetrics.PodMetrics{&backendmetrics.FakePodMetrics{}} testCases := []struct { @@ -140,7 +142,6 @@ func TestFlowControlRequestAdapter(t *testing.T) { t.Run(tc.name, func(t *testing.T) { t.Parallel() fcReq := &flowControlRequest{ - ctx: ctx, requestID: tc.requestID, fairnessID: tc.fairnessID, priority: tc.priority, @@ -148,7 +149,6 @@ func TestFlowControlRequestAdapter(t *testing.T) { candidatePods: candidatePods, } - assert.Equal(t, ctx, fcReq.Context(), "Context() mismatch") assert.Equal(t, tc.requestID, fcReq.ID(), "ID() mismatch") assert.Equal(t, tc.requestByteSize, fcReq.ByteSize(), "ByteSize() mismatch") assert.Equal(t, candidatePods, fcReq.CandidatePodsForScheduling(), "CandidatePodsForScheduling() mismatch")
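A minimal, self-contained sketch of the finalization contract the updated tests exercise: `Finalize(err)` records a terminal state exactly once, `FinalState()` returns nil until that happens, and a sweep reclaims only items that are already finalized. All names below (`lifecycleItem`, `finalState`, `sweep`, the string outcomes) are local stand-ins for illustration, not the flowcontrol package's real `FlowItem`, `types.QueueOutcome`, or `sweepFinalizedItems` APIs.

package main

import (
	"context"
	"errors"
	"fmt"
	"sync"
)

// finalState is a stand-in for an immutable (Outcome, Err) pair recorded at finalization.
type finalState struct {
	Outcome string
	Err     error
}

// lifecycleItem finalizes at most once; FinalState reports nil while the item is still live.
type lifecycleItem struct {
	once  sync.Once
	mu    sync.Mutex
	state *finalState
}

// Finalize records the terminal state exactly once; later calls are no-ops.
func (it *lifecycleItem) Finalize(cause error) {
	it.once.Do(func() {
		outcome := "EvictedOther"
		if errors.Is(cause, context.Canceled) || errors.Is(cause, context.DeadlineExceeded) {
			outcome = "EvictedContextCancelled"
		}
		it.mu.Lock()
		it.state = &finalState{Outcome: outcome, Err: cause}
		it.mu.Unlock()
	})
}

// FinalState returns nil until Finalize has run.
func (it *lifecycleItem) FinalState() *finalState {
	it.mu.Lock()
	defer it.mu.Unlock()
	return it.state
}

// sweep drops externally finalized items so the queue reclaims their capacity,
// analogous to the processor's periodic sweep of finalized items.
func sweep(queue []*lifecycleItem) []*lifecycleItem {
	live := queue[:0]
	for _, it := range queue {
		if it.FinalState() == nil {
			live = append(live, it) // still live: keep it queued
		}
	}
	return live
}

func main() {
	a, b := &lifecycleItem{}, &lifecycleItem{}
	queue := []*lifecycleItem{a, b}

	// An external owner (the controller, in the real system) finalizes "a" when its context expires.
	a.Finalize(context.Canceled)
	a.Finalize(errors.New("ignored: already finalized")) // idempotent; the first state wins

	queue = sweep(queue)
	fmt.Println(len(queue), a.FinalState().Outcome) // 1 EvictedContextCancelled
}

The same nil-until-finalized property is what lets the updated Submit/SubmitOrBlock tests assert that `FinalState()` is still nil after a refused hand-off: declining ownership no longer implies finalization.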