diff --git a/pkg/epp/flowcontrol/contracts/errors.go b/pkg/epp/flowcontrol/contracts/errors.go
index 38af37031..beb758fe0 100644
--- a/pkg/epp/flowcontrol/contracts/errors.go
+++ b/pkg/epp/flowcontrol/contracts/errors.go
@@ -33,4 +33,8 @@ var (
 
 	// ErrInvalidShardCount indicates that an invalid shard count was provided (e.g., zero or negative).
 	ErrInvalidShardCount = errors.New("invalid shard count")
+
+	// ErrShardDraining indicates that an operation could not be completed because the target shard is in the process of
+	// being gracefully drained. The caller should retry the operation on a different, Active shard.
+	ErrShardDraining = errors.New("shard is draining")
 )
diff --git a/pkg/epp/flowcontrol/contracts/registry.go b/pkg/epp/flowcontrol/contracts/registry.go
index 1b788d91a..de1b89ae6 100644
--- a/pkg/epp/flowcontrol/contracts/registry.go
+++ b/pkg/epp/flowcontrol/contracts/registry.go
@@ -21,104 +21,84 @@ import (
 	"sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types"
 )
 
-// FlowRegistry is the complete interface for the global control plane. An implementation of this interface is the single
-// source of truth for all flow control state and configuration.
+// FlowRegistry is the complete interface for the global flow control plane.
+// It composes the client-facing data path interface and the administrative interface. A concrete implementation of this
+// interface is the single source of truth for all flow control state.
 //
-// # Conformance
+// # Conformance: Implementations MUST be goroutine-safe.
+//
+// # Flow Lifecycle
+//
+// A flow instance, identified by its immutable `types.FlowKey`, has a lease-based lifecycle managed by this interface.
+// Any implementation MUST adhere to this lifecycle:
 //
-// All methods MUST be goroutine-safe. Implementations are expected to perform complex updates (e.g.,
-// `RegisterOrUpdateFlow`) atomically.
+// 1. Lease Acquisition: A client calls `WithConnection` to acquire a lease. This signals that the flow is in use and
+//    protects it from garbage collection. If the flow does not exist, it is created Just-In-Time (JIT).
+// 2. Active State: A flow is "Active" as long as its lease count is greater than zero.
+// 3. Lease Release: The lease is released automatically when the `WithConnection` callback returns.
+//    When the lease count drops to zero, the flow becomes "Idle".
+// 4. Garbage Collection: The implementation MUST automatically garbage collect a flow after it has remained
+//    continuously Idle for a configurable duration.
 //
 // # System Invariants
 //
 // Concrete implementations MUST uphold the following invariants:
+//
 // 1. Shard Consistency: All configured priority bands and registered flow instances must exist on every Active shard.
-//    Plugin instance types must be consistent for a given flow across all shards.
-// 2. Flow Instance Uniqueness: Each unique `types.FlowKey` (`ID` + `Priority`) corresponds to exactly one managed flow
-//    instance.
-// 3. Capacity Partitioning: Global and per-band capacity limits must be uniformly partitioned across all Active
+// 2. Capacity Partitioning: Global and per-band capacity limits must be uniformly partitioned across all Active
 //    shards.
-//
-// # Flow Lifecycle
-//
-// A flow instance (identified by its immutable `FlowKey`) has a simple lifecycle:
-//
-// - Registered: Known to the `FlowRegistry` via `RegisterOrUpdateFlow`.
-// - Idle: Queues are empty across all Active and Draining shards.
-// - Garbage Collected (Unregistered): The registry automatically garbage collects flows after they have remained Idle -// for a configurable duration. -// -// # Shard Lifecycle -// -// When a shard is decommissioned, it is marked inactive (`IsActive() == false`) to prevent new enqueues. The shard -// continues to drain and is deleted only after it is empty. type FlowRegistry interface { + FlowRegistryClient FlowRegistryAdmin - ShardProvider } // FlowRegistryAdmin defines the administrative interface for the global control plane. -// -// # Dynamic Update Strategies -// -// The contract specifies behaviors for handling dynamic updates, prioritizing stability and correctness: -// -// - Immutable Flow Identity (`types.FlowKey`): The system treats the `FlowKey` (`ID` + `Priority`) as the immutable -// identifier. Changing the priority of traffic requires registering a new `FlowKey`. The old flow instance is -// automatically garbage collected when Idle. This design eliminates complex priority migration logic. -// -// - Graceful Draining (Shard Scale-Down): Decommissioned shards enter a Draining state. They stop accepting new -// requests but continue to be processed for dispatch until empty. -// -// - Self-Balancing (Shard Scale-Up): When new shards are added, the `controller.FlowController`'s distribution logic -// naturally utilizes them, funneling new requests to the less-loaded shards. Existing queued items are not -// migrated. type FlowRegistryAdmin interface { - // RegisterOrUpdateFlow handles the registration of a new flow instance or the update of an existing instance's - // specification (for the same `types.FlowKey`). The operation is atomic across all shards. - // - // Since the `FlowKey` (including `Priority`) is immutable, this method cannot change a flow's priority. - // To change priority, the caller should simply register the new `FlowKey`; the old flow instance will be - // automatically garbage collected when it becomes Idle. - // - // Returns errors wrapping `ErrFlowIDEmpty`, `ErrPriorityBandNotFound`, or internal errors if plugin instantiation - // fails. - RegisterOrUpdateFlow(spec types.FlowSpecification) error - - // UpdateShardCount dynamically adjusts the number of internal state shards. - // - // The implementation MUST atomically re-partition capacity allocations across all active shards. - // Returns an error wrapping `ErrInvalidShardCount` if `n` is not positive. - UpdateShardCount(n int) error - // Stats returns globally aggregated statistics for the entire `FlowRegistry`. Stats() AggregateStats - // ShardStats returns a slice of statistics, one for each internal shard. This provides visibility for debugging and - // monitoring per-shard behavior (e.g., identifying hot or stuck shards). + // ShardStats returns a slice of statistics, one for each internal shard. ShardStats() []ShardStats } -// ShardProvider defines the interface for discovering available shards. -// -// A "shard" is an internal, parallel execution unit that allows the `controller.FlowController`'s core dispatch logic -// to be parallelized, preventing a CPU bottleneck at high request rates. The `FlowRegistry`'s state is sharded to -// support this parallelism by reducing lock contention. +// FlowRegistryClient defines the primary, client-facing interface for the registry. +// This is the interface that the `controller.FlowController`'s data path depends upon. +type FlowRegistryClient interface { + // WithConnection manages a scoped, leased session for a given flow. 
+ // It is the primary and sole entry point for interacting with the data path. + // + // This method handles the entire lifecycle of a flow connection: + // 1. Just-In-Time (JIT) Registration: If the flow for the given `types.FlowKey` does not exist, it is created and + // registered automatically. + // 2. Lease Acquisition: It acquires a lifecycle lease, protecting the flow from garbage collection. + // 3. Callback Execution: It invokes the provided function `fn`, passing in a temporary `ActiveFlowConnection` handle. + // 4. Guaranteed Lease Release: It ensures the lease is safely released when the callback function returns. + // + // This functional, callback-based approach makes resource leaks impossible, as the caller is not responsible for + // manually closing the connection. + // + // Errors returned by the callback `fn` are propagated up. + // Returns `ErrFlowIDEmpty` if the provided key has an empty ID. + WithConnection(key types.FlowKey, fn func(conn ActiveFlowConnection) error) error +} + +// ActiveFlowConnection represents a handle to a temporary, leased session on a flow. +// It provides a safe, scoped entry point to the registry's sharded data plane. // -// Consumers MUST check `RegistryShard.IsActive()` before routing new work to a shard to avoid sending requests to a -// Draining shard. -type ShardProvider interface { - // Shards returns a slice of accessors, one for each internal state shard (Active and Draining). - // Callers should not modify the returned slice. +// An `ActiveFlowConnection` instance is only valid for the duration of the `WithConnection` callback from which it was +// received. Callers MUST NOT store a reference to this object or use it after the callback returns. +// Its purpose is to ensure that any interaction with the flow's state (e.g., accessing its shards and queues) occurs +// safely while the flow is guaranteed to be protected from garbage collection. +type ActiveFlowConnection interface { + // Shards returns a stable snapshot of accessors for all internal state shards (both Active and Draining). + // Consumers MUST check `RegistryShard.IsActive()` before routing new work to a shard from this slice. Shards() []RegistryShard } -// RegistryShard defines the interface for accessing a specific slice (shard) of the `FlowRegistry's` state. -// It provides a concurrent-safe view for `controller.FlowController` workers. -// -// # Conformance +// RegistryShard defines the interface for a single slice (shard) of the `FlowRegistry`'s state. +// A shard acts as an independent, parallel execution unit, allowing the system's dispatch logic to scale horizontally. // -// All methods MUST be goroutine-safe. +// # Conformance: Implementations MUST be goroutine-safe. type RegistryShard interface { // ID returns a unique identifier for this shard, which must remain stable for the shard's lifetime. ID() string @@ -163,14 +143,16 @@ type RegistryShard interface { Stats() ShardStats } -// ManagedQueue defines the interface for a flow's queue instance on a specific shard. -// It acts as a stateful decorator around an underlying `framework.SafeQueue`. +// ManagedQueue defines the interface for a flow's queue on a specific shard. +// It acts as a stateful decorator around an underlying `framework.SafeQueue`, augmenting it with statistics tracking. // // # Conformance // -// - All methods MUST be goroutine-safe. -// - All mutating methods (`Add()`, `Remove()`, etc.) 
MUST ensure that the underlying queue state and the statistics -// (`Len`, `ByteSize`) are updated atomically relative to each other. +// - Implementations MUST be goroutine-safe. +// - All mutating methods MUST ensure that the underlying queue state and the public statistics (`Len`, `ByteSize`) +// are updated as a single atomic transaction. +// - The `Add` method MUST return an error wrapping `ErrShardDraining` if the queue instance belongs to a parent shard +// that is no longer Active. type ManagedQueue interface { framework.SafeQueue diff --git a/pkg/epp/flowcontrol/registry/connection.go b/pkg/epp/flowcontrol/registry/connection.go new file mode 100644 index 000000000..995f23c13 --- /dev/null +++ b/pkg/epp/flowcontrol/registry/connection.go @@ -0,0 +1,44 @@ +/* +Copyright 2025 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package registry + +import ( + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" +) + +// connection is the concrete, un-exported implementation of the `contracts.ActiveFlowConnection` interface. +// It is a temporary handle created for the duration of a single `WithConnection` call. +type connection struct { + registry *FlowRegistry + key types.FlowKey +} + +var _ contracts.ActiveFlowConnection = &connection{} + +// Shards returns a stable snapshot of accessors for all internal state shards. +func (c *connection) Shards() []contracts.RegistryShard { + c.registry.mu.RLock() + defer c.registry.mu.RUnlock() + + // Return a copy to ensure the caller cannot modify the registry's internal slice. + shardsCopy := make([]contracts.RegistryShard, len(c.registry.allShards)) + for i, s := range c.registry.allShards { + shardsCopy[i] = s + } + return shardsCopy +} diff --git a/pkg/epp/flowcontrol/registry/doc.go b/pkg/epp/flowcontrol/registry/doc.go index 3355c77fc..8ab1c5d4d 100644 --- a/pkg/epp/flowcontrol/registry/doc.go +++ b/pkg/epp/flowcontrol/registry/doc.go @@ -14,116 +14,23 @@ See the License for the specific language governing permissions and limitations under the License. */ -// Package registry provides the concrete implementation of the `contracts.FlowRegistry`. +// Package registry provides the concrete implementation of the `contracts.FlowRegistry` interface. // -// As the stateful control plane, this package manages the lifecycle of all flows, queues, and policies. It provides a -// sharded, concurrent-safe view of its state to the `controller.FlowController` workers, enabling scalable, parallel -// request processing. +// # Architecture: A Sharded, Concurrent Control Plane // -// # Architecture: Composite, Sharded, and Separated Concerns +// This package implements the flow control state machine using a sharded architecture to enable scalable, parallel +// request processing. It separates the orchestration control plane from the request-processing data plane. 
// -// The registry separates the control plane (orchestration) from the data plane (request processing state). +// - `FlowRegistry`: The top-level orchestrator (Control Plane). It manages the lifecycle of all flows and shards, +// handling registration, garbage collection, and scaling operations. +// - `registryShard`: A slice of the data plane. It holds a partition of the total state and provides a +// read-optimized, concurrent-safe view for a single `controller.FlowController` worker. +// - `managedQueue`: A stateful decorator around a `framework.SafeQueue`. It is the fundamental unit of state, +// responsible for atomically tracking statistics (e.g., length and byte size) and ensuring data consistency. // -// - `FlowRegistry`: The Control Plane. The top-level orchestrator and single source of truth. It centralizes complex -// operations: flow registration, garbage collection (GC) coordination, and shard scaling. -// -// - `registryShard`: The Data Plane Slice. A concurrent-safe "slice" of the registry's total state. It provides a -// read-optimized view for `FlowController` workers. -// -// - `managedQueue`: The Stateful Decorator. A wrapper around a `framework.SafeQueue`. It augments the queue with -// atomic statistics tracking and signaling state transitions to the control plane. -// -// # Data Flow and Interaction Model -// -// The data path (Enqueue/Dispatch) is optimized for minimal latency and maximum concurrency. -// -// Enqueue Path: -// 1. The `FlowController`'s distributor selects an active `registryShard`. -// 2. The distributor calls `shard.ManagedQueue(flowKey)` (acquires `RLock`). -// 3. The distributor calls `managedQueue.Add(item)`. -// 4. `managedQueue` atomically updates the queue and its stats and signals the control plane. -// -// Dispatch Path: -// 1. A `FlowController` worker iterates over its assigned `registryShard`. -// 2. The worker uses policies and accessors to select the next item (acquires `RLock`). -// 3. The worker calls `managedQueue.Remove(handle)`. -// 4. `managedQueue` atomically updates the queue and its stats and signals the control plane. -// -// # Concurrency Strategy: Multi-Tiered and Optimized -// -// The registry maximizes performance on the hot path while ensuring strict correctness for complex state transitions: -// -// - Serialized Control Plane (Actor Model): The `FlowRegistry` uses a single background goroutine to process all -// state change events serially, eliminating race conditions in the control plane. -// -// - Sharding (Data Plane Parallelism): State is partitioned across multiple `registryShard` instances, allowing the -// data path to scale linearly. -// -// - Lock-Free Data Path (Atomics): Statistics aggregation (Shard/Registry level) uses lock-free atomics. -// -// - Strict Consistency (Hybrid Locking): `managedQueue` uses a hybrid locking model (Mutex for writes, Atomics for -// reads) to guarantee strict consistency between queue contents and statistics, which is required for GC -// correctness. +// # Concurrency Model // +// The registry uses a multi-layered strategy to maximize performance on the hot path while ensuring correctness for +// administrative tasks. // (See the `FlowRegistry` struct documentation for detailed locking rules). -// -// # Garbage Collection: The "Trust but Verify" Pattern -// -// The registry handles the race condition between asynchronous data path activity and synchronous GC. The control plane -// maintains an eventually consistent cache (`flowState`). 
-// -// The registry uses a periodic, generational "Trust but Verify" pattern. It identifies candidate flows using the cache. -// Before deletion, it performs a "Verify" step: it synchronously acquires write locks on ALL shards and queries the -// ground truth (live queue counters). This provides strong consistency when needed. -// -// (See the `garbageCollectFlowLocked` function documentation for detailed steps). -// -// # Scalability Characteristics and Trade-offs -// -// The architecture prioritizes data path throughput and correctness, introducing specific trade-offs: -// -// - Data Path Throughput (Excellent): Scales linearly with the number of shards and benefits from lock-free -// statistics updates. -// -// - GC Latency Impact (Trade-off): The GC "Verify" step requires locking all shards (O(N)). This briefly pauses the -// data path. As the shard count (N) increases, this may impact P99 latency. This trade-off guarantees correctness. -// -// - Control Plane Responsiveness during Scale-Up (Trade-off): Scaling up requires synchronizing all existing flows -// (M) onto the new shards (K). This O(M*K) operation occurs under the main control plane lock. If M is large, this -// operation may block the control plane. -// -// # Event-Driven State Machine and Lifecycle Scenarios -// -// The system relies on atomic state transitions to generate reliable, edge-triggered signals. These signals are sent -// reliably; if the event channel is full, the sender blocks, applying necessary backpressure to ensure no events are -// lost, preventing state divergence. -// -// The following scenarios detail how the registry handles lifecycle events: -// -// New Flow Registration: A new flow instance `F1` (`FlowKey{ID: "A", Priority: 10}`) is registered. -// -// 1. `managedQueue` instances are created for `F1` on all shards. -// 2. The `flowState` cache marks `F1` as Idle. If it remains Idle, it will eventually be garbage collected. -// -// Flow Activity/Inactivity: -// -// 1. When the first request for `F1` is enqueued, the queue signals `BecameNonEmpty`. The control plane marks `F1` -// as Active, protecting it from GC. -// 2. When the last request is dispatched globally, the queues signal `BecameEmpty`. The control plane updates the -// cache, and `F1` is now considered Idle by the GC scanner. -// -// "Changing" Flow Priority: Traffic for `ID: "A"` needs to shift from `Priority: 10` to `Priority: 20`. -// -// 1. The caller registers a new flow instance, `F2` (`FlowKey{ID: "A", Priority: 20}`). -// 2. The system treats `F1` and `F2` as independent entities (Immutable `FlowKey` design). -// 3. As `F1` becomes Idle, it is automatically garbage collected. This achieves the outcome gracefully without complex -// state migration logic. -// -// Shard Scaling: -// -// - Scale-Up: New shards are created and marked Active. Existing flows are synchronized onto the new shards. -// Configuration is re-partitioned. -// - Scale-Down: Targeted shards transition to Draining (stop accepting new work). Configuration is re-partitioned -// across remaining active shards. When a Draining shard is empty, it signals `BecameDrained` and is removed by the -// control plane. package registry diff --git a/pkg/epp/flowcontrol/registry/events.go b/pkg/epp/flowcontrol/registry/events.go deleted file mode 100644 index c76a38e10..000000000 --- a/pkg/epp/flowcontrol/registry/events.go +++ /dev/null @@ -1,127 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. 
- -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package registry - -import ( - "fmt" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" -) - -// ============================================================================= -// Control Plane Events (Transport) -// ============================================================================= - -// event is a marker interface for internal state machine events processed by the `FlowRegistry`'s event loop. -type event interface { - isEvent() -} - -// queueStateChangedEvent is sent when a `managedQueue`'s state changes, carrying a `queueStateSignal`. -type queueStateChangedEvent struct { - shardID string - key types.FlowKey - signal queueStateSignal -} - -func (queueStateChangedEvent) isEvent() {} - -// shardStateChangedEvent is sent when a `registryShard`'s state changes, carrying a `shardStateSignal`. -type shardStateChangedEvent struct { - shardID string - signal shardStateSignal -} - -func (shardStateChangedEvent) isEvent() {} - -// syncEvent is used exclusively for testing to synchronize test execution with the `FlowRegistry`'s event loop. -// It allows tests to wait until all preceding events have been processed, eliminating the need for polling. -// test-only -type syncEvent struct { - // doneCh is used by the event loop to signal back to the sender that the sync event has been processed. - doneCh chan struct{} -} - -func (syncEvent) isEvent() {} - -// ============================================================================= -// Control Plane Signals and Callbacks -// ============================================================================= - -// --- Queue Signals --- - -// queueStateSignal is an enum describing the edge-triggered state change events emitted by a `managedQueue`. -type queueStateSignal int - -const ( - // queueStateSignalBecameEmpty is sent when a queue transitions from non-empty to empty. - // Trigger: Len > 0 -> Len == 0. Used for inactivity GC tracking. - queueStateSignalBecameEmpty queueStateSignal = iota - - // queueStateSignalBecameNonEmpty is sent when a queue transitions from empty to non-empty. - // Trigger: Len == 0 -> Len > 0. Used for inactivity GC tracking. - queueStateSignalBecameNonEmpty -) - -func (s queueStateSignal) String() string { - switch s { - case queueStateSignalBecameEmpty: - return "QueueBecameEmpty" - case queueStateSignalBecameNonEmpty: - return "QueueBecameNonEmpty" - default: - return fmt.Sprintf("Unknown(%d)", s) - } -} - -// signalQueueStateFunc defines the callback function that a `managedQueue` uses to signal lifecycle events. -// Implementations should avoid blocking on internal locks or I/O. However, they are expected to block if the -// `FlowRegistry`'s event channel is full; this behavior is required to apply backpressure and ensure reliable event -// delivery for the GC system. 
-type signalQueueStateFunc func(key types.FlowKey, signal queueStateSignal) - -// --- Shard Signals --- - -// shardStateSignal is an enum describing the edge-triggered state change events emitted by a `registryShard`. -type shardStateSignal int - -const ( - // shardStateSignalBecameDrained is sent when a Draining shard transitions to empty. - // Trigger: Transition from `componentStatusDraining` -> `componentStatusDrained`. Used for final GC of the shard. - shardStateSignalBecameDrained shardStateSignal = iota -) - -func (s shardStateSignal) String() string { - switch s { - case shardStateSignalBecameDrained: - return "ShardBecameDrained" - default: - return fmt.Sprintf("Unknown(%d)", s) - } -} - -// signalShardStateFunc defines the callback function that a `registryShard` uses to signal its own state changes. -// Implementations should avoid blocking on internal locks or I/O. However, they are expected to block if the -// `FlowRegistry`'s event channel is full (see `signalQueueStateFunc` for rationale). -type signalShardStateFunc func(shardID string, signal shardStateSignal) - -// --- Statistics Propagation --- - -// propagateStatsDeltaFunc defines the callback function used to propagate statistics changes (deltas) up the hierarchy -// (Queue -> Shard -> Registry). -// Implementations MUST be non-blocking (relying on atomics). -type propagateStatsDeltaFunc func(priority uint, lenDelta, byteSizeDelta int64) diff --git a/pkg/epp/flowcontrol/registry/flowstate.go b/pkg/epp/flowcontrol/registry/flowstate.go deleted file mode 100644 index b515d8f9e..000000000 --- a/pkg/epp/flowcontrol/registry/flowstate.go +++ /dev/null @@ -1,105 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package registry - -import ( - "fmt" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" -) - -// flowState holds all tracking state for a single flow instance within the registry. -// -// # Role: The Eventually Consistent Cache for GC -// -// This structure is central to the GC logic. It acts as an eventually consistent, cached view of a flow's emptiness -// status across all shards. It is updated asynchronously via events from the data path. -// -// # Concurrency and Consistency Model -// -// `flowState` is a passive, non-thread-safe data structure. Access to `flowState` is serialized by `FlowRegistry.mu`. -// -// Invariant: Because this state is eventually consistent, it must not be the sole source of truth for destructive -// operations (like garbage collection). All destructive actions must first verify the live state of the system by -// consulting the atomic counters on the `managedQueue` instances directly (the "Trust but Verify" pattern). -type flowState struct { - // spec is the desired state of the flow instance, uniquely identified by its immutable `spec.Key`. - spec types.FlowSpecification - // emptyOnShards tracks the empty status of the flow's queue across all shards. - // The key is the shard ID. 
- emptyOnShards map[string]bool - // lastActiveGeneration tracks the GC generation number in which this flow was last observed to be Active. - // This is used by the periodic GC scanner to identify flows that have been Idle for at least one GC cycle. - lastActiveGeneration uint64 -} - -// newFlowState creates the initial state for a newly registered flow instance. -// It initializes the state based on the current set of all shards. -func newFlowState(spec types.FlowSpecification, allShards []*registryShard) *flowState { - s := &flowState{ - spec: spec, - emptyOnShards: make(map[string]bool, len(allShards)), - // A new flow starts at generation 0. This sentinel value indicates to the GC "Mark" phase that the flow is brand - // new and should be granted a grace period of one cycle. - lastActiveGeneration: gcGenerationNewFlow, - } - // A new flow instance starts with an empty queue on all shards. - for _, shard := range allShards { - s.emptyOnShards[shard.id] = true - } - return s -} - -// update synchronizes the flow's specification. This is a forward-looking method for when non-key fields on the spec -// become mutable (e.g., a per-flow policy override). It enforces the invariant that the `FlowKey` cannot change. -func (s *flowState) update(newSpec types.FlowSpecification) { - if s.spec.Key != newSpec.Key { - // This should be impossible if the FlowRegistry's logic is correct, as flowStates are keyed by FlowKey. - panic(fmt.Sprintf("invariant violation: attempted to update flowState for key %v with a new key %v", - s.spec.Key, newSpec.Key)) - } - s.spec = newSpec -} - -// handleQueueSignal updates the flow's internal emptiness state based on a signal from one of its queues. -func (s *flowState) handleQueueSignal(shardID string, signal queueStateSignal) { - switch signal { - case queueStateSignalBecameEmpty: - s.emptyOnShards[shardID] = true - case queueStateSignalBecameNonEmpty: - s.emptyOnShards[shardID] = false - } -} - -// isIdle checks if the flow's queues are empty across the provided set of shards (active and draining). -// -// The `shards` parameter defines the scope of the check (the current ground truth). We only verify the state against -// this list because the internal `emptyOnShards` map might contain stale entries for shards that have been garbage -// collected by the registry but not yet purged from this specific `flowState`. -func (s *flowState) isIdle(shards []*registryShard) bool { - for _, shard := range shards { - if !s.emptyOnShards[shard.id] { - return false - } - } - return true -} - -// purgeShard removes a decommissioned shard's ID from the tracking map to prevent memory leaks. -func (s *flowState) purgeShard(shardID string) { - delete(s.emptyOnShards, shardID) -} diff --git a/pkg/epp/flowcontrol/registry/flowstate_test.go b/pkg/epp/flowcontrol/registry/flowstate_test.go deleted file mode 100644 index ceedbd53d..000000000 --- a/pkg/epp/flowcontrol/registry/flowstate_test.go +++ /dev/null @@ -1,216 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -*/ - -package registry - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" -) - -// fsTestHarness holds all the common components for a `flowState` test. -type fsTestHarness struct { - t *testing.T - fs *flowState - allShards []*registryShard -} - -// newFsTestHarness creates a new test harness, initializing a `flowState` with a default spec and shard layout. -func newFsTestHarness(t *testing.T) *fsTestHarness { - t.Helper() - - spec := types.FlowSpecification{ - Key: types.FlowKey{ID: "f1", Priority: 10}, - } - // Note: The shards are simple structs for this test, as we only need their IDs. - allShards := []*registryShard{{id: "s1"}, {id: "s2"}, {id: "s3-draining"}} - - fs := newFlowState(spec, allShards) - - return &fsTestHarness{ - t: t, - fs: fs, - allShards: allShards, - } -} - -func TestFlowState_New(t *testing.T) { - t.Parallel() - h := newFsTestHarness(t) - - assert.Equal(t, uint64(0), h.fs.lastActiveGeneration, "Initial lastActiveGeneration should be 0") - require.Len(t, h.fs.emptyOnShards, 3, "Should track emptiness status for all initial shards") - - // A new flow instance should be considered empty on all shards by default. - assert.True(t, h.fs.emptyOnShards["s1"], "Queue on s1 should start as empty") - assert.True(t, h.fs.emptyOnShards["s2"], "Queue on s2 should start as empty") - assert.True(t, h.fs.emptyOnShards["s3-draining"], "Queue on s3-draining should start as empty") -} - -func TestFlowState_Update(t *testing.T) { - t.Parallel() - - t.Run("ShouldUpdateSpec_WhenKeyIsUnchanged", func(t *testing.T) { - t.Parallel() - h := newFsTestHarness(t) - - // Create a new spec with the same key but different (hypothetical) content. - updatedSpec := types.FlowSpecification{ - Key: h.fs.spec.Key, - // Imagine other fields like per-flow `framework.IntraFlowDispatchPolicy` overrides being added here in the - // future. 
- } - - h.fs.update(updatedSpec) - assert.Equal(t, updatedSpec, h.fs.spec, "Spec should be updated with new content") - }) - - t.Run("ShouldPanic_WhenKeyIsChanged", func(t *testing.T) { - t.Parallel() - h := newFsTestHarness(t) - - invalidSpec := types.FlowSpecification{ - Key: types.FlowKey{ID: "f1", Priority: 99}, // Different priority - } - - assert.PanicsWithValue(t, - fmt.Sprintf("invariant violation: attempted to update flowState for key %s with a new key %s", - h.fs.spec.Key, invalidSpec.Key), - func() { h.fs.update(invalidSpec) }, - "Should panic when attempting to change the immutable FlowKey", - ) - }) -} - -func TestFlowState_HandleQueueSignal(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - signal queueStateSignal - initialState bool - expectedState bool - message string - }{ - { - name: "BecameNonEmpty_WhenPreviouslyEmpty", - signal: queueStateSignalBecameNonEmpty, - initialState: true, - expectedState: false, - message: "Should mark queue as non-empty (false)", - }, - { - name: "BecameEmpty_WhenPreviouslyNonEmpty", - signal: queueStateSignalBecameEmpty, - initialState: false, - expectedState: true, - message: "Should mark queue as empty (true)", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - h := newFsTestHarness(t) - shardID := "s1" - - h.fs.emptyOnShards[shardID] = tc.initialState - h.fs.handleQueueSignal(shardID, tc.signal) - assert.Equal(t, tc.expectedState, h.fs.emptyOnShards[shardID], tc.message) - }) - } -} - -func TestFlowState_IsIdle(t *testing.T) { - t.Parallel() - - testCases := []struct { - name string - setup func(h *fsTestHarness) - shards []*registryShard // The set of shards to check for idleness. - expectIdle bool - message string - }{ - { - name: "WhenAllShardsAreEmpty", - setup: nil, // Default state is all empty. - expectIdle: true, - message: "Should be idle when all tracked queues are empty", - }, - { - name: "WhenOneShardIsNonEmpty", - setup: func(h *fsTestHarness) { - h.fs.handleQueueSignal("s2", queueStateSignalBecameNonEmpty) - }, - expectIdle: false, - message: "Should not be idle if any queue is non-empty", - }, - { - name: "WhenADrainingShardIsNonEmpty", - setup: func(h *fsTestHarness) { - h.fs.handleQueueSignal("s3-draining", queueStateSignalBecameNonEmpty) - }, - expectIdle: false, - message: "Should not be idle if a draining shard's queue is non-empty", - }, - { - name: "WhenNonEmptyShardIsNotInCheckedSet", - setup: func(h *fsTestHarness) { - // The flow is active on a shard that is Draining. - h.fs.handleQueueSignal("s3-draining", queueStateSignalBecameNonEmpty) - }, - // But we are only checking the set of Active shards. 
- shards: []*registryShard{{id: "s1"}, {id: "s2"}}, - expectIdle: true, - message: "Should be considered idle if the only non-empty queue is on a shard that is not being checked", - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - h := newFsTestHarness(t) - if tc.setup != nil { - tc.setup(h) - } - - shardsToCheck := tc.shards - if shardsToCheck == nil { - shardsToCheck = h.allShards - } - - assert.Equal(t, tc.expectIdle, h.fs.isIdle(shardsToCheck), tc.message) - }) - } -} - -func TestFlowState_PurgeShard(t *testing.T) { - t.Parallel() - h := newFsTestHarness(t) - shardToPurge := "s2" - - require.Contains(t, h.fs.emptyOnShards, shardToPurge, "Test setup: shard must exist before purging") - h.fs.purgeShard(shardToPurge) - - assert.NotContains(t, h.fs.emptyOnShards, shardToPurge, "Shard should be purged from the tracking map") - assert.Contains(t, h.fs.emptyOnShards, "s1", "Other shards should remain in the tracking map") - assert.Len(t, h.fs.emptyOnShards, 2, "Map length should be reduced by one") -} diff --git a/pkg/epp/flowcontrol/registry/lifecycle.go b/pkg/epp/flowcontrol/registry/lifecycle.go deleted file mode 100644 index 8b5ba6e8e..000000000 --- a/pkg/epp/flowcontrol/registry/lifecycle.go +++ /dev/null @@ -1,56 +0,0 @@ -/* -Copyright 2025 The Kubernetes Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package registry - -import ( - "fmt" -) - -// ============================================================================= -// Component Lifecycle State Machine -// ============================================================================= - -// componentStatus represents the lifecycle state of a `registryShard`. -// It is manipulated using atomic operations (e.g., `atomic.Int32`) to ensure robust, atomic state transitions. -type componentStatus int32 - -const ( - // componentStatusActive indicates the component is fully operational and accepting new work. - componentStatusActive componentStatus = iota - - // componentStatusDraining indicates the component is shutting down. It is not accepting new work, but is still - // processing existing work. - componentStatusDraining - - // componentStatusDrained indicates the component has finished draining and is empty. - // The transition into this state (from `Draining`) occurs exactly once via `CompareAndSwap` and triggers the - // `BecameDrained` signal. This acts as an atomic latch for GC. 
- componentStatusDrained -) - -func (s componentStatus) String() string { - switch s { - case componentStatusActive: - return "Active" - case componentStatusDraining: - return "Draining" - case componentStatusDrained: - return "Drained" - default: - return fmt.Sprintf("Unknown(%d)", s) - } -} diff --git a/pkg/epp/flowcontrol/registry/managedqueue.go b/pkg/epp/flowcontrol/registry/managedqueue.go index 5d3a67093..e56cdd821 100644 --- a/pkg/epp/flowcontrol/registry/managedqueue.go +++ b/pkg/epp/flowcontrol/registry/managedqueue.go @@ -29,73 +29,60 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// managedQueueCallbacks groups the callback functions that a `managedQueue` uses to communicate with its parent shard. -type managedQueueCallbacks struct { - // propagateStatsDelta is called to propagate statistics changes (e.g., queue length, byte size) up to the parent. - propagateStatsDelta propagateStatsDeltaFunc - // signalQueueState is called to signal important lifecycle events, such as becoming empty, which are used by the - // garbage collector. - signalQueueState signalQueueStateFunc -} - // managedQueue implements `contracts.ManagedQueue`. It acts as a stateful decorator around a `framework.SafeQueue`. // -// # Role: The Stateful Decorator -// -// Its primary responsibility is to augment a generic queue implementation with critical registry features: strictly -// consistent statistics tracking and exactly-once, edge-triggered signaling of state changes (e.g., becoming empty) to -// the control plane for garbage collection. +// # Role: The Stateful Statistics Decorator // -// # Concurrency Model: Hybrid Locking (Mutex + Atomics + `SafeQueue`) +// Its sole responsibility is to augment a generic queue with strictly consistent, atomic statistics tracking. +// All mutating operations (`Add`, `Remove`, etc.) are wrapped to ensure that the queue's internal state and its +// externally visible statistics (`Len`, `ByteSize`) are always perfectly synchronized. // -// The `managedQueue` employs a hybrid locking strategy to guarantee strict consistency between the queue contents and -// its statistics, while maintaining high performance. +// # Concurrency Model: Mutex for Writes, Atomics for Reads // -// 1. Mutex-Protected Writes (`sync.Mutex`): A single mutex protects all mutating operations (Add, Remove, etc.). -// This ensures that the update to the underlying queue and the update to the internal counters occur as a single, -// atomic transaction. This strict consistency is required for GC correctness. +// The `managedQueue` employs a hybrid locking strategy to guarantee strict consistency while maintaining high read +// performance. // -// 2. Synchronous Propagation and Ordering: The propagation of statistics deltas (both `Len` and `ByteSize`) occurs -// synchronously within this critical section. This strict ordering (`Add` propagation before `Remove` propagation -// for the same item) guarantees the non-negative invariant across the entire system (Shard/Registry aggregates). +// 1. Mutex-Protected Writes (`sync.Mutex`): A single mutex protects all mutating operations. +// This ensures the update to the underlying queue and the update to the internal counters occur as a single, atomic +// transaction. +// 2. Synchronous Propagation: Statistics deltas are propagated synchronously within this critical section, +// guaranteeing a non-negativity invariant across the entire system (Shard/Registry aggregates). +// 3. 
Lock-Free Reads (Atomics): The counters use `atomic.Int64`, allowing high-frequency accessors (`Len()`, +// `ByteSize()`) to read statistics without acquiring the mutex. // -// 3. Lock-Free Statistics Reads (Atomics): The counters use `atomic.Int64`. This allows high-frequency accessors -// (`Len()`, `ByteSize()`) to read the statistics without acquiring the Mutex. -// -// 4. Concurrent Reads (`SafeQueue`): Read operations (e.g., `PeekHead`/`PeekTail` used by policies) rely on the -// underlying `framework.SafeQueue`'s concurrency control and do not acquire the `managedQueue`'s write lock. -// -// # Statistical Integrity and Invariant Protection +// # Invariant Protection // +// To guarantee statistical integrity, the following invariants must be upheld: // 1. Exclusive Access: All mutations on the underlying `SafeQueue` MUST be performed exclusively through this wrapper. -// -// 2. Non-Autonomous State: The underlying queue implementation must not change state autonomously (e.g., no internal -// TTL eviction). -// -// 3. Read-Only Proxy: To prevent policy plugins from bypassing statistics tracking, the read-only view is provided via -// the `flowQueueAccessor` wrapper. This enforces the read-only contract at the type system level. +// 2. Non-Autonomous State: The underlying queue must not change state autonomously (e.g., no internal TTL eviction). type managedQueue struct { - // mu protects all mutating operations (writes) on the queue. It ensures that the underlying queue's state and the - // atomic counters are updated atomically. Read operations (like `Peek`) do not acquire this lock. - mu sync.Mutex + // --- Immutable Identity & Dependencies (set at construction) --- + key types.FlowKey + dispatchPolicy framework.IntraFlowDispatchPolicy + logger logr.Logger + + // onStatsDelta is the callback used to propagate statistics changes up to the parent shard. + onStatsDelta propagateStatsDeltaFunc + // isDraining is a callback that checks the lifecycle state of the parent shard, allowing this queue to reject new + // work when the shard is being decommissioned. + isDraining func() bool + + // --- State Protected by `mu` --- + // mu protects all mutating operations. It ensures that any changes to the underlying `queue` and the updates to the + // atomic counters occur as a single, atomic transaction. + mu sync.Mutex // queue is the underlying, concurrency-safe queue implementation that this `managedQueue` decorates. + // Its state must only be modified while holding `mu`. queue framework.SafeQueue - // dispatchPolicy is the intra-flow policy used to select items from this specific queue. - dispatchPolicy framework.IntraFlowDispatchPolicy - - // key uniquely identifies the flow instance this queue belongs to. - key types.FlowKey + // --- Concurrent-Safe State (Atomics) --- - // Queue-level statistics. Updated under the protection of the `mu` lock, but read lock-free. - // Guaranteed to be non-negative. + // Queue-level statistics. + // These are written under the protection of `mu` but can be read lock-free at any time using atomic operations. + // They are guaranteed to be non-negative. byteSize atomic.Int64 len atomic.Int64 - - // parentCallbacks provides the communication channels back to the parent shard. 
- parentCallbacks managedQueueCallbacks - logger logr.Logger } var _ contracts.ManagedQueue = &managedQueue{} @@ -106,18 +93,20 @@ func newManagedQueue( dispatchPolicy framework.IntraFlowDispatchPolicy, key types.FlowKey, logger logr.Logger, - parentCallbacks managedQueueCallbacks, + onStatsDelta propagateStatsDeltaFunc, + isDraining func() bool, ) *managedQueue { mqLogger := logger.WithName("managed-queue").WithValues( "flowKey", key, "queueType", queue.Name(), ) mq := &managedQueue{ - queue: queue, - dispatchPolicy: dispatchPolicy, - key: key, - parentCallbacks: parentCallbacks, - logger: mqLogger, + queue: queue, + dispatchPolicy: dispatchPolicy, + key: key, + onStatsDelta: onStatsDelta, + logger: mqLogger, + isDraining: isDraining, } return mq } @@ -131,13 +120,19 @@ func (mq *managedQueue) FlowQueueAccessor() framework.FlowQueueAccessor { // Add wraps the underlying `framework.SafeQueue.Add` call and atomically updates the queue's and the parent shard's // statistics. func (mq *managedQueue) Add(item types.QueueItemAccessor) error { + // Enforce the system's routing contract by rejecting new work for a Draining shard. + // This prevents a race where a caller could route a request to a shard just as it begins to drain. + if mq.isDraining() { + return contracts.ErrShardDraining + } + mq.mu.Lock() defer mq.mu.Unlock() if err := mq.queue.Add(item); err != nil { return err } - mq.propagateStatsDelta(1, int64(item.OriginalRequest().ByteSize())) + mq.propagateStatsDeltaLocked(1, int64(item.OriginalRequest().ByteSize())) mq.logger.V(logging.TRACE).Info("Request added to queue", "requestID", item.OriginalRequest().ID()) return nil } @@ -152,7 +147,7 @@ func (mq *managedQueue) Remove(handle types.QueueItemHandle) (types.QueueItemAcc if err != nil { return nil, err } - mq.propagateStatsDelta(-1, -int64(removedItem.OriginalRequest().ByteSize())) + mq.propagateStatsDeltaLocked(-1, -int64(removedItem.OriginalRequest().ByteSize())) mq.logger.V(logging.TRACE).Info("Request removed from queue", "requestID", removedItem.OriginalRequest().ID()) return removedItem, nil } @@ -170,7 +165,7 @@ func (mq *managedQueue) Cleanup(predicate framework.PredicateFunc) (cleanedItems if len(cleanedItems) == 0 { return cleanedItems, nil } - mq.propagateStatsDeltaForRemovedItems(cleanedItems) + mq.propagateStatsDeltaForRemovedItemsLocked(cleanedItems) mq.logger.V(logging.DEBUG).Info("Cleaned up queue", "removedItemCount", len(cleanedItems)) return cleanedItems, nil } @@ -187,7 +182,7 @@ func (mq *managedQueue) Drain() ([]types.QueueItemAccessor, error) { if len(drainedItems) == 0 { return drainedItems, nil } - mq.propagateStatsDeltaForRemovedItems(drainedItems) + mq.propagateStatsDeltaForRemovedItemsLocked(drainedItems) mq.logger.V(logging.DEBUG).Info("Drained queue", "itemCount", len(drainedItems)) return drainedItems, nil } @@ -204,48 +199,33 @@ func (mq *managedQueue) Comparator() framework.ItemComparator { return mq. // --- Internal Methods --- -// propagateStatsDelta updates the queue's statistics, signals emptiness events, and propagates the delta. +// propagateStatsDeltaLocked updates the queue's statistics and propagates the delta to the parent shard. +// It must be called while holding the `managedQueue.mu` lock. // -// Invariant Check: This function panics if a statistic becomes negative. Because all updates are protected by the -// `managedQueue.mu` lock, a negative value indicates a critical logic error (e.g., double-counting a removal). 
-// This check enforces the non-negative invariant locally, which mathematically guarantees that the aggregated -// statistics (Shard/Registry level) also remain non-negative. -func (mq *managedQueue) propagateStatsDelta(lenDelta, byteSizeDelta int64) { - // Note: We rely on the caller (Add/Remove/etc.) to hold the `managedQueue.mu` lock. - +// Invariant Check: This function panics if a statistic becomes negative. This enforces the non-negative invariant +// locally, which mathematically guarantees that the aggregated statistics (Shard/Registry level) also remain +// non-negative. +func (mq *managedQueue) propagateStatsDeltaLocked(lenDelta, byteSizeDelta int64) { newLen := mq.len.Add(lenDelta) if newLen < 0 { panic(fmt.Sprintf("invariant violation: managedQueue length for flow %s became negative (%d)", mq.key, newLen)) } mq.byteSize.Add(byteSizeDelta) - // Evaluate and signal our own state change *before* propagating the statistics to the parent shard. This ensures a - // strict, bottom-up event ordering, preventing race conditions where the parent might process its own state change - // before the child's signal is handled. - // 1. Evaluate GC signals based on the strictly consistent local state. - // 2. Propagate the delta up to the Shard/Registry. This propagation is lock-free and eventually consistent. - mq.evaluateEmptinessState(newLen-lenDelta, newLen) - mq.parentCallbacks.propagateStatsDelta(mq.key.Priority, lenDelta, byteSizeDelta) -} - -// evaluateEmptinessState checks if the queue has transitioned between non-empty <-> empty and signals the parent if so. -func (mq *managedQueue) evaluateEmptinessState(oldLen, newLen int64) { - if oldLen > 0 && newLen == 0 { - mq.parentCallbacks.signalQueueState(mq.key, queueStateSignalBecameEmpty) - } else if oldLen == 0 && newLen > 0 { - mq.parentCallbacks.signalQueueState(mq.key, queueStateSignalBecameNonEmpty) - } + // Propagate the delta up to the parent shard. This propagation is lock-free and eventually consistent. + mq.onStatsDelta(mq.key.Priority, lenDelta, byteSizeDelta) } -// propagateStatsDeltaForRemovedItems calculates the total stat changes for a slice of removed items and applies them. -func (mq *managedQueue) propagateStatsDeltaForRemovedItems(items []types.QueueItemAccessor) { +// propagateStatsDeltaForRemovedItemsLocked calculates the total stat changes for a slice of removed items and applies +// them. It must be called while holding the `managedQueue.mu` lock. +func (mq *managedQueue) propagateStatsDeltaForRemovedItemsLocked(items []types.QueueItemAccessor) { var lenDelta int64 var byteSizeDelta int64 for _, item := range items { lenDelta-- byteSizeDelta -= int64(item.OriginalRequest().ByteSize()) } - mq.propagateStatsDelta(lenDelta, byteSizeDelta) + mq.propagateStatsDeltaLocked(lenDelta, byteSizeDelta) } // --- `flowQueueAccessor` --- @@ -256,7 +236,7 @@ func (mq *managedQueue) propagateStatsDeltaForRemovedItems(items []types.QueueIt // // This wrapper protects system invariants. It acts as a proxy that exposes only the read-only methods. // This prevents policy plugins from using type assertions to access the concrete `*managedQueue` and calling mutation -// methods, which would bypass statistics tracking and signaling. +// methods, which would bypass statistics tracking. 
type flowQueueAccessor struct { mq *managedQueue } diff --git a/pkg/epp/flowcontrol/registry/managedqueue_test.go b/pkg/epp/flowcontrol/registry/managedqueue_test.go index 0b3d3daee..f5e3d7fa7 100644 --- a/pkg/epp/flowcontrol/registry/managedqueue_test.go +++ b/pkg/epp/flowcontrol/registry/managedqueue_test.go @@ -18,7 +18,6 @@ package registry import ( "errors" - "fmt" "sync" "sync/atomic" "testing" @@ -27,6 +26,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" @@ -39,18 +39,17 @@ import ( // mqTestHarness holds all components for testing a `managedQueue`. type mqTestHarness struct { - t *testing.T - mq *managedQueue - propagator *mockStatsPropagator - signalRecorder *mockQueueStateSignalRecorder - mockPolicy *frameworkmocks.MockIntraFlowDispatchPolicy + t *testing.T + mq *managedQueue + propagator *mockStatsPropagator + mockPolicy *frameworkmocks.MockIntraFlowDispatchPolicy } // newMockedMqHarness creates a harness that uses a mocked underlying queue. // This is ideal for isolating and unit testing the decorator logic of `managedQueue`. func newMockedMqHarness(t *testing.T, queue *frameworkmocks.MockSafeQueue, key types.FlowKey) *mqTestHarness { t.Helper() - return newMqHarness(t, queue, key) + return newMqHarness(t, queue, key, false) } // newRealMqHarness creates a harness that uses a real "ListQueue" implementation. @@ -58,42 +57,38 @@ func newMockedMqHarness(t *testing.T, queue *frameworkmocks.MockSafeQueue, key t func newRealMqHarness(t *testing.T, key types.FlowKey) *mqTestHarness { t.Helper() q, err := queue.NewQueueFromName(listqueue.ListQueueName, nil) - require.NoError(t, err, "Test setup: creating a real ListQueue should not fail") - return newMqHarness(t, q, key) + require.NoError(t, err, "Test setup: creating a real ListQueue implementation should not fail") + return newMqHarness(t, q, key, false) } -func newMqHarness(t *testing.T, queue framework.SafeQueue, key types.FlowKey) *mqTestHarness { +// newMqHarness is the base constructor for the test harness. +func newMqHarness(t *testing.T, queue framework.SafeQueue, key types.FlowKey, isDraining bool) *mqTestHarness { t.Helper() propagator := &mockStatsPropagator{} - signalRec := newMockQueueStateSignalRecorder() mockPolicy := &frameworkmocks.MockIntraFlowDispatchPolicy{ ComparatorV: &frameworkmocks.MockItemComparator{}, } - callbacks := managedQueueCallbacks{ - propagateStatsDelta: propagator.propagate, - signalQueueState: signalRec.signal, - } - mq := newManagedQueue(queue, mockPolicy, key, logr.Discard(), callbacks) - require.NotNil(t, mq, "Test setup: newManagedQueue should not return nil") + isDrainingFunc := func() bool { return isDraining } + mq := newManagedQueue(queue, mockPolicy, key, logr.Discard(), propagator.propagate, isDrainingFunc) + require.NotNil(t, mq, "Test setup: newManagedQueue must return a valid instance") return &mqTestHarness{ - t: t, - mq: mq, - propagator: propagator, - signalRecorder: signalRec, - mockPolicy: mockPolicy, + t: t, + mq: mq, + propagator: propagator, + mockPolicy: mockPolicy, } } +// setupWithItems pre-populates the queue and resets the mock propagator for focused testing. 
func (h *mqTestHarness) setupWithItems(items ...types.QueueItemAccessor) { h.t.Helper() for _, item := range items { err := h.mq.Add(item) - require.NoError(h.t, err, "Harness setup: failed to add item") + require.NoError(h.t, err, "Harness setup: failed to add initial item to the queue") } - // Reset counters after the setup phase is complete. h.propagator.reset() } @@ -113,37 +108,13 @@ func (p *mockStatsPropagator) reset() { p.byteSizeDelta.Store(0) } -// mockQueueStateSignalRecorder captures queue state signals from the system under test. -type mockQueueStateSignalRecorder struct { - mu sync.Mutex - signals []queueStateSignal -} - -func newMockQueueStateSignalRecorder() *mockQueueStateSignalRecorder { - return &mockQueueStateSignalRecorder{signals: make([]queueStateSignal, 0)} -} - -func (r *mockQueueStateSignalRecorder) signal(_ types.FlowKey, signal queueStateSignal) { - r.mu.Lock() - defer r.mu.Unlock() - r.signals = append(r.signals, signal) -} - -func (r *mockQueueStateSignalRecorder) getSignals() []queueStateSignal { - r.mu.Lock() - defer r.mu.Unlock() - signalsCopy := make([]queueStateSignal, len(r.signals)) - copy(signalsCopy, r.signals) - return signalsCopy -} - // --- Unit Tests --- func TestManagedQueue_InitialState(t *testing.T) { t.Parallel() h := newMockedMqHarness(t, &frameworkmocks.MockSafeQueue{}, types.FlowKey{ID: "flow", Priority: 1}) - assert.Zero(t, h.mq.Len(), "A new queue should have a length of 0") - assert.Zero(t, h.mq.ByteSize(), "A new queue should have a byte size of 0") + assert.Zero(t, h.mq.Len(), "A newly initialized queue must have a length of 0") + assert.Zero(t, h.mq.ByteSize(), "A newly initialized queue must have a byte size of 0") } func TestManagedQueue_Add(t *testing.T) { @@ -153,7 +124,9 @@ func TestManagedQueue_Add(t *testing.T) { testCases := []struct { name string setupMock func(q *frameworkmocks.MockSafeQueue) + isDraining bool expectErr bool + expectErrIs error // Optional expectedLenDelta int64 expectedByteSizeDelta int64 }{ @@ -162,6 +135,7 @@ func TestManagedQueue_Add(t *testing.T) { setupMock: func(q *frameworkmocks.MockSafeQueue) { q.AddFunc = func(types.QueueItemAccessor) error { return nil } }, + isDraining: false, expectErr: false, expectedLenDelta: 1, expectedByteSizeDelta: 100, @@ -171,31 +145,46 @@ func TestManagedQueue_Add(t *testing.T) { setupMock: func(q *frameworkmocks.MockSafeQueue) { q.AddFunc = func(types.QueueItemAccessor) error { return errors.New("add failed") } }, + isDraining: false, expectErr: true, expectedLenDelta: 0, expectedByteSizeDelta: 0, }, + { + name: "ShouldFail_AndNotChangeStats_WhenQueueIsDraining", + setupMock: func(q *frameworkmocks.MockSafeQueue) {}, + isDraining: true, + expectErr: true, + expectErrIs: contracts.ErrShardDraining, + expectedLenDelta: 0, + expectedByteSizeDelta: 0, + }, } for _, tc := range testCases { t.Run(tc.name, func(t *testing.T) { t.Parallel() - item := typesmocks.NewMockQueueItemAccessor(100, "req", flowKey) q := &frameworkmocks.MockSafeQueue{} - tc.setupMock(q) - h := newMockedMqHarness(t, q, flowKey) + h := newMqHarness(t, q, flowKey, tc.isDraining) + item := typesmocks.NewMockQueueItemAccessor(100, "req", flowKey) + if tc.setupMock != nil { + tc.setupMock(q) + } err := h.mq.Add(item) if tc.expectErr { - require.Error(t, err, "Add should have returned an error") + require.Error(t, err, "Add operation must fail when the underlying queue returns an error") + if tc.expectErrIs != nil { + assert.ErrorIs(t, err, tc.expectErrIs, "The returned error was not of the expected type") + } 
} else { - require.NoError(t, err, "Add should not have returned an error") + require.NoError(t, err, "Add operation must succeed when the underlying queue accepts the item") } assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), - "Propagated length delta should be correct") + "The propagated length delta must exactly match the change in queue size") assert.Equal(t, tc.expectedByteSizeDelta, h.propagator.byteSizeDelta.Load(), - "Propagated byte size delta should be correct") + "The propagated byte size delta must exactly match the change in queue size") }) } } @@ -247,13 +236,15 @@ func TestManagedQueue_Remove(t *testing.T) { _, err := h.mq.Remove(item.Handle()) if tc.expectErr { - require.Error(t, err, "Remove should have returned an error") + require.Error(t, err, + "Remove operation must fail when the underlying queue returns an error (e.g., item not found)") } else { - require.NoError(t, err, "Remove should not have returned an error") + require.NoError(t, err, "Remove operation must succeed when the underlying queue successfully removes the item") } - assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), "Propagated length delta should be correct") + assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), + "The propagated length delta must exactly match the change in queue size") assert.Equal(t, tc.expectedByteSizeDelta, h.propagator.byteSizeDelta.Load(), - "Propagated byte size delta should be correct") + "The propagated byte size delta must exactly match the change in queue size") }) } } @@ -270,7 +261,7 @@ func TestManagedQueue_Cleanup(t *testing.T) { expectedByteSizeDelta int64 }{ { - name: "ShouldSucceed_AndDecrementStats", + name: "ShouldSucceed_AndDecrementStats_WhenItemsRemoved", setupMock: func(q *frameworkmocks.MockSafeQueue, items []types.QueueItemAccessor) { q.CleanupFunc = func(_ framework.PredicateFunc) ([]types.QueueItemAccessor, error) { return items, nil @@ -319,13 +310,14 @@ func TestManagedQueue_Cleanup(t *testing.T) { _, err := h.mq.Cleanup(func(_ types.QueueItemAccessor) bool { return true }) if tc.expectErr { - require.Error(t, err, "Cleanup should have returned an error") + require.Error(t, err, "Cleanup operation must fail if the underlying queue implementation encounters an error") } else { - require.NoError(t, err, "Cleanup should not have returned an error") + require.NoError(t, err, "Cleanup operation must succeed if the underlying queue implementation succeeds") } - assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), "Propagated length delta should be correct") + assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), + "The propagated length delta must exactly match the total number of items removed during cleanup") assert.Equal(t, tc.expectedByteSizeDelta, h.propagator.byteSizeDelta.Load(), - "Propagated byte size delta should be correct") + "The propagated byte size delta must exactly match the total size of items removed during cleanup") }) } } @@ -380,206 +372,140 @@ func TestManagedQueue_Drain(t *testing.T) { _, err := h.mq.Drain() if tc.expectErr { - require.Error(t, err, "Drain should have returned an error") + require.Error(t, err, "Drain operation must fail if the underlying queue implementation encounters an error") } else { - require.NoError(t, err, "Drain should not have returned an error") + require.NoError(t, err, "Drain operation must succeed if the underlying queue implementation succeeds") } - assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), "Propagated 
length delta should be correct") + assert.Equal(t, tc.expectedLenDelta, h.propagator.lenDelta.Load(), + "The propagated length delta must exactly match the total number of items drained") assert.Equal(t, tc.expectedByteSizeDelta, h.propagator.byteSizeDelta.Load(), - "Propagated byte size delta should be correct") + "The propagated byte size delta must exactly match the total size of items drained") }) } } -func TestManagedQueue_PanicOnUnderflow(t *testing.T) { - t.Parallel() - flowKey := types.FlowKey{ID: "flow", Priority: 1} - item := typesmocks.NewMockQueueItemAccessor(100, "req", flowKey) - q := &frameworkmocks.MockSafeQueue{} - q.AddFunc = func(types.QueueItemAccessor) error { return nil } - q.RemoveFunc = func(types.QueueItemHandle) (types.QueueItemAccessor, error) { - return item, nil - } - h := newMockedMqHarness(t, q, flowKey) - - // Add and then successfully remove the item. - require.NoError(t, h.mq.Add(item), "Test setup: Add should succeed") - _, err := h.mq.Remove(item.Handle()) - require.NoError(t, err, "Test setup: First Remove should succeed") - require.Zero(t, h.mq.Len(), "Test setup: Queue should be empty") - - // Attempting to remove the same item again should cause a panic. - assert.PanicsWithValue(t, - fmt.Sprintf("invariant violation: managedQueue length for flow %s became negative (-1)", flowKey), - func() { _, _ = h.mq.Remove(item.Handle()) }, - "A second removal of the same item should trigger a panic on length underflow", - ) -} - -func TestManagedQueue_Signaling(t *testing.T) { +func TestManagedQueue_FlowQueueAccessor(t *testing.T) { t.Parallel() - flowKey := types.FlowKey{ID: "flow", Priority: 1} - h := newRealMqHarness(t, flowKey) - item1 := typesmocks.NewMockQueueItemAccessor(100, "r1", flowKey) - item2 := typesmocks.NewMockQueueItemAccessor(50, "r2", flowKey) - - // 1. Initial state: Empty - assert.Empty(t, h.signalRecorder.getSignals(), "No signals should be present on a new queue") - - // 2. Transition: Empty -> NonEmpty - require.NoError(t, h.mq.Add(item1), "Adding an item should not fail") - assert.Equal(t, []queueStateSignal{queueStateSignalBecameNonEmpty}, h.signalRecorder.getSignals(), - "Should signal BecameNonEmpty on first add") - - // 3. Steady state: NonEmpty -> NonEmpty - require.NoError(t, h.mq.Add(item2), "Adding a second item should not fail") - assert.Equal(t, []queueStateSignal{queueStateSignalBecameNonEmpty}, h.signalRecorder.getSignals(), - "No new signal should be sent when adding to a non-empty queue") - - // 4. Steady state: NonEmpty -> NonEmpty - _, err := h.mq.Remove(item1.Handle()) - require.NoError(t, err, "Removing an item should not fail") - assert.Equal(t, []queueStateSignal{queueStateSignalBecameNonEmpty}, h.signalRecorder.getSignals(), - "No new signal should be sent when removing from a multi-item queue") - - // 5. 
Transition: NonEmpty -> Empty - _, err = h.mq.Remove(item2.Handle()) - require.NoError(t, err, "Removing the last item should not fail") - expectedSignalSequence := []queueStateSignal{queueStateSignalBecameNonEmpty, queueStateSignalBecameEmpty} - assert.Equal(t, expectedSignalSequence, h.signalRecorder.getSignals(), - "Should signal BecameEmpty on removal of the last item") -} -func TestManagedQueue_FlowQueueAccessor_ProxiesCalls(t *testing.T) { - t.Parallel() - flowKey := types.FlowKey{ID: "flow", Priority: 1} - q := &frameworkmocks.MockSafeQueue{} - harness := newMockedMqHarness(t, q, flowKey) - item := typesmocks.NewMockQueueItemAccessor(100, "req-1", flowKey) - q.PeekHeadV = item - q.PeekTailV = item - q.NameV = "MockQueue" - q.CapabilitiesV = []framework.QueueCapability{framework.CapabilityFIFO} - require.NoError(t, harness.mq.Add(item), "Test setup: Adding an item should not fail") - - accessor := harness.mq.FlowQueueAccessor() - require.NotNil(t, accessor, "FlowQueueAccessor should not be nil") - - assert.Equal(t, harness.mq.Name(), accessor.Name(), "Accessor Name() should match managed queue") - assert.Equal(t, harness.mq.Capabilities(), accessor.Capabilities(), - "Accessor Capabilities() should match managed queue") - assert.Equal(t, harness.mq.Len(), accessor.Len(), "Accessor Len() should match managed queue") - assert.Equal(t, harness.mq.ByteSize(), accessor.ByteSize(), "Accessor ByteSize() should match managed queue") - assert.Equal(t, flowKey, accessor.FlowKey(), "Accessor FlowKey() should match managed queue") - assert.Equal(t, harness.mockPolicy.Comparator(), accessor.Comparator(), - "Accessor Comparator() should match the one from the policy") - assert.Equal(t, harness.mockPolicy.Comparator(), harness.mq.Comparator(), - "ManagedQueue Comparator() should also match the one from the policy") - - peekedHead, err := accessor.PeekHead() - require.NoError(t, err, "Accessor PeekHead() should not return an error") - assert.Same(t, item, peekedHead, "Accessor PeekHead() should return the correct item instance") - - peekedTail, err := accessor.PeekTail() - require.NoError(t, err, "Accessor PeekTail() should not return an error") - assert.Same(t, item, peekedTail, "Accessor PeekTail() should return the correct item instance") + t.Run("ProxiesCalls", func(t *testing.T) { + t.Parallel() + flowKey := types.FlowKey{ID: "flow", Priority: 1} + q := &frameworkmocks.MockSafeQueue{} + harness := newMockedMqHarness(t, q, flowKey) + item := typesmocks.NewMockQueueItemAccessor(100, "req-1", flowKey) + q.PeekHeadV = item + q.PeekTailV = item + q.NameV = "MockQueue" + q.CapabilitiesV = []framework.QueueCapability{framework.CapabilityFIFO} + require.NoError(t, harness.mq.Add(item), "Test setup: Adding an item must succeed") + + accessor := harness.mq.FlowQueueAccessor() + require.NotNil(t, accessor, "FlowQueueAccessor must return a non-nil instance (guaranteed by contract)") + + assert.Equal(t, harness.mq.Name(), accessor.Name(), "Accessor Name() must proxy the underlying queue's name") + assert.Equal(t, harness.mq.Capabilities(), accessor.Capabilities(), + "Accessor Capabilities() must proxy the underlying queue's capabilities") + assert.Equal(t, harness.mq.Len(), accessor.Len(), "Accessor Len() must reflect the managed queue's current length") + assert.Equal(t, harness.mq.ByteSize(), accessor.ByteSize(), + "Accessor ByteSize() must reflect the managed queue's current byte size") + assert.Equal(t, flowKey, accessor.FlowKey(), "Accessor FlowKey() must return the correct identifier for the flow") + 
assert.Equal(t, harness.mockPolicy.Comparator(), accessor.Comparator(), + "Accessor Comparator() must return the comparator provided by the configured intra-flow policy") + assert.Equal(t, harness.mockPolicy.Comparator(), harness.mq.Comparator(), + "ManagedQueue Comparator() must also return the comparator provided by the configured intra-flow policy") + + peekedHead, err := accessor.PeekHead() + require.NoError(t, err, "Accessor PeekHead() must succeed when the underlying queue succeeds") + assert.Same(t, item, peekedHead, "Accessor PeekHead() must return the exact item instance at the head") + + peekedTail, err := accessor.PeekTail() + require.NoError(t, err, "Accessor PeekTail() must succeed when the underlying queue succeeds") + assert.Same(t, item, peekedTail, "Accessor PeekTail() must return the exact item instance at the tail") + }) + + t.Run("EmptyQueue", func(t *testing.T) { + t.Parallel() + flowKey := types.FlowKey{ID: "flow", Priority: 1} + q := &frameworkmocks.MockSafeQueue{} + expectedErr := errors.New("queue is empty") + q.PeekHeadErrV = expectedErr + harness := newMockedMqHarness(t, q, flowKey) + accessor := harness.mq.FlowQueueAccessor() + + _, err := accessor.PeekHead() + require.Error(t, err, "Accessor PeekHead() should return an error on an empty queue") + assert.ErrorIs(t, err, expectedErr, "Accessor should proxy the specific error from the underlying queue") + }) } -// --- Concurrency Tests --- +// --- Concurrency Test --- -// TestManagedQueue_Concurrency_SignalingRace targets the race condition of a queue flapping between empty and non-empty -// states. -// It ensures the `BecameEmpty` and `BecameNonEmpty` signals are sent correctly in strict alternation, without -// duplicates or missed signals. -func TestManagedQueue_Concurrency_SignalingRace(t *testing.T) { +// TestManagedQueue_Concurrency_StatsIntegrity validates that under high contention, the final propagated statistics +// are consistent. It spins up multiple goroutines that concurrently and rapidly add and remove items to stress-test +// the mutex protecting write operations and the atomic propagation of statistics. +func TestManagedQueue_Concurrency_StatsIntegrity(t *testing.T) { t.Parallel() - - const ops = 1000 - flowKey := types.FlowKey{ID: "flow", Priority: 1} - h := newRealMqHarness(t, flowKey) - - var wg sync.WaitGroup - wg.Add(2) - - // Goroutine 1: Adder - continuously adds items. - go func() { - defer wg.Done() - for range ops { - _ = h.mq.Add(typesmocks.NewMockQueueItemAccessor(10, "req", flowKey)) - } - }() - - // Goroutine 2: Remover - continuously drains the queue. - go func() { - defer wg.Done() - for range ops { - // Drain is used to remove all items present in a single atomic operation. - _, _ = h.mq.Drain() - } - }() - - wg.Wait() - - // Verification: The critical part of this test is to analyze the sequence of signals. - signals := h.signalRecorder.getSignals() - require.NotEmpty(t, signals, "At least some signals should have been generated") - - // The sequence must be a strict alternation of `BecameNonEmpty` and `BecameEmpty` signals. - // There should never be two of the same signal in a row. - for i := 0; i < len(signals)-1; i++ { - assert.NotEqual(t, signals[i], signals[i+1], "Signals at index %d and %d must not be duplicates", i, i+1) - } - - // The first signal must be `BecameNonEmpty`. 
- assert.Equal(t, queueStateSignalBecameNonEmpty, signals[0], "The first signal must be BecameNonEmpty") -} - -// TestManagedQueue_Concurrency_ItemIntegrity validates that under high concurrency, the queue does not lose or -// duplicate items and that the final propagated statistics are consistent with the operations performed. -func TestManagedQueue_Concurrency_ItemIntegrity(t *testing.T) { - t.Parallel() - - const numGoRoutines = 10 - const opsPerGoRoutine = 500 - const itemByteSize = 10 + const ( + numWorkers = 10 + opsPerWorker = 500 + itemByteSize = 10 + ) flowKey := types.FlowKey{ID: "flow", Priority: 1} h := newRealMqHarness(t, flowKey) var wg sync.WaitGroup - wg.Add(numGoRoutines) + wg.Add(numWorkers) - // Each goroutine will attempt to perform a mix of `Add` and `Remove` operations. - for range numGoRoutines { + for range numWorkers { go func() { defer wg.Done() - for range opsPerGoRoutine { - // Add an item. + for range opsPerWorker { item := typesmocks.NewMockQueueItemAccessor(uint64(itemByteSize), "req", flowKey) - require.NoError(t, h.mq.Add(item), "Concurrent Add should not fail") - - // Immediately try to remove it. This creates high contention on the queue's internal state. + require.NoError(t, h.mq.Add(item), "Concurrent Add operation must succeed without errors or races") + // In this chaos test, `Remove` may fail if another goroutine removes the item first. This is expected. _, _ = h.mq.Remove(item.Handle()) } }() } - wg.Wait() - // After all operations, the queue should ideally be empty, but some removals might have failed if another goroutine - // got to it first. We drain any remaining items to get a final count. + // After all operations, the queue should ideally be empty, but we drain any remaining items to get a definitive final + // state. _, err := h.mq.Drain() - require.NoError(t, err, "Final drain should not fail") + require.NoError(t, err, "Final drain operation must succeed to finalize the test state") + + assert.Zero(t, h.mq.Len(), "Final queue length must be zero after draining all remaining items") + assert.Zero(t, h.mq.ByteSize(), "Final queue byte size must be zero after draining all remaining items") + assert.Equal(t, int64(0), h.propagator.lenDelta.Load(), + "The net length delta propagated across all operations must be exactly zero") + assert.Equal(t, int64(0), h.propagator.byteSizeDelta.Load(), + "The net byte size delta propagated across all operations must be exactly zero") +} - // Final State Verification - assert.Zero(t, h.mq.Len(), "Queue length must be zero after final drain") - assert.Zero(t, h.mq.ByteSize(), "Queue byte size must be zero after final drain") +// --- Invariant Test --- - // Statistical Integrity Verification - // The total number of propagated additions must equal the number of propagated removals. 
- lenDelta := h.propagator.lenDelta.Load() - byteSizeDelta := h.propagator.byteSizeDelta.Load() +func TestManagedQueue_InvariantPanics_OnUnderflow(t *testing.T) { + t.Parallel() + flowKey := types.FlowKey{ID: "flow", Priority: 1} + item := typesmocks.NewMockQueueItemAccessor(100, "req", flowKey) + q := &frameworkmocks.MockSafeQueue{} + q.AddFunc = func(types.QueueItemAccessor) error { return nil } + q.RemoveFunc = func(types.QueueItemHandle) (types.QueueItemAccessor, error) { + return item, nil + } + h := newMockedMqHarness(t, q, flowKey) - assert.Equal(t, int64(0), lenDelta, - "Net length delta propagated must be zero, proving every add was matched by a remove delta") - assert.Equal(t, int64(0), byteSizeDelta, "Net byte size delta propagated must be zero") + require.NoError(t, h.mq.Add(item), "Test setup: Initial Add must succeed") + _, err := h.mq.Remove(item.Handle()) + require.NoError(t, err, "Test setup: First Remove must succeed") + + // This remove call should cause the stats to go negative. + assert.Panics(t, + func() { + // Mock the underlying queue to succeed on the second remove, even though it's logically inconsistent. + // This isolates the panic to the `managedQueue`'s decorator logic. + _, _ = h.mq.Remove(item.Handle()) + }, + "Attempting to remove an item that results in negative statistics must trigger an invariant violation panic", + ) } diff --git a/pkg/epp/flowcontrol/registry/registry.go b/pkg/epp/flowcontrol/registry/registry.go index 34ed8fc0f..95d604ede 100644 --- a/pkg/epp/flowcontrol/registry/registry.go +++ b/pkg/epp/flowcontrol/registry/registry.go @@ -17,131 +17,128 @@ limitations under the License. package registry import ( + "cmp" "context" "fmt" + "slices" "sync" "sync/atomic" + "time" "github.com/go-logr/logr" "k8s.io/utils/clock" + "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) +// propagateStatsDeltaFunc defines the callback function used to propagate statistics changes (deltas) up the hierarchy +// (Queue -> Shard -> Registry). +// Implementations MUST be non-blocking (relying on atomics). +type propagateStatsDeltaFunc func(priority uint, lenDelta, byteSizeDelta int64) + // bandStats holds the aggregated atomic statistics for a single priority band across all shards. type bandStats struct { byteSize atomic.Int64 len atomic.Int64 } +// flowState holds all tracking state for a single flow instance within the registry. +type flowState struct { + key types.FlowKey + + // gcLock protects the flow's lifecycle state. + // - The Garbage Collector takes an exclusive write lock to safely delete the flow. + // - Active connections take a shared read lock for the duration of their operation, preventing the GC from running + // while allowing other connections to proceed concurrently. + gcLock sync.RWMutex + + // leaseCount is an atomic reference counter for all concurrent, in-flight connections. + // It is the sole source of truth for determining if a flow is Idle. + leaseCount atomic.Int64 + + // becameIdleAt tracks the time at which the lease count last dropped to zero. + // A zero value indicates the flow is currently Active. + // This field is always protected by the gcLock's exclusive write lock during modifications. 
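+	// The GC scan reads this field under the shared read lock to decide whether the idleness timeout has expired.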
+	becameIdleAt time.Time
+}
+
 // FlowRegistry is the concrete implementation of the `contracts.FlowRegistry` interface.
 //
 // # Role: The Central Orchestrator
 //
-// The `FlowRegistry` is the single source of truth for all configuration and the lifecycle manager for all shards and
-// flow instances (identified by `types.FlowKey`). It is responsible for complex, multi-step operations such as flow
-// registration, dynamic shard scaling, and coordinating garbage collection across all shards.
+// The `FlowRegistry` is the single source of truth for flow control configuration and the lifecycle manager for all
+// shards and flow instances. It provides a highly concurrent data path for request processing while ensuring
+// correctness for administrative tasks like scaling and garbage collection.
 //
-// # Concurrency Model: The Serialized Control Plane (Actor Model)
+// # Concurrency Model: A Multi-Layered Strategy
 //
-// To ensure correctness during complex state transitions, the `FlowRegistry` employs an Actor-like pattern. All
-// administrative operations and internal state change events are serialized.
+// The registry is designed for high throughput by separating the concurrency domains of the request hot path, garbage
+// collection, and administrative tasks.
 //
-// A single background goroutine (the `Run` loop) processes events from the `events` channel and a periodic `gcTicker`.
-// This event loop acquires the main `mu` lock before processing any event. External administrative methods (like
-// `RegisterOrUpdateFlow`) also acquire this lock. This strict serialization eliminates race conditions in the control
-// plane, simplifying the complex logic of the distributed state machine.
+// 1. `sync.Map` for `flowStates` (Hot Path): The `WithConnection` method uses a `sync.Map` for highly concurrent,
+// often lock-free, lookups and Just-In-Time registration of different flows.
+// 2. `flowState.gcLock` (`sync.RWMutex`) for Per-Flow Lifecycle: Each flow has its own `RWMutex` to arbitrate between
+// active connections and the garbage collector. This surgical locking prevents GC on one flow from impacting any
+// other. The interaction is detailed in the "Flow Lifecycle" section below.
+// 3. `mu` (`sync.RWMutex`) for Global Topology: A single registry-wide mutex protects the overall shard topology
+// during infrequent administrative operations like scaling.
 //
-// # Detailed Concurrency Strategy and Locking Rules
+// # Flow Lifecycle: Lease-Based with Surgical GC
 //
-// The registry employs a multi-tiered concurrency strategy:
+// A flow's lifecycle is managed by a lease-based reference count.
 //
-// 1. Serialized Control Plane (Actor Model): Implemented by the `Run` loop and the `mu` lock.
+// 1. Lease Acquisition: A client calls `WithConnection` to begin a managed session. This acquires a lease by
+// incrementing an atomic counter.
+// 2. Lease Release: Release is automatic and guaranteed. When the callback function provided to `WithConnection`
+// returns, the lease is released by decrementing the atomic counter. When the count reaches zero, the flow is
+// marked Idle with a timestamp.
+// 3. Garbage Collection: A background task scans for Idle flows. To prevent a TOCTOU race, the GC acquires an
+// exclusive lock on the specific `flowState.gcLock` before re-verifying the lease count is still zero and
+// proceeding with deletion.
 //
-// 2. Coarse-Grained Admin Lock (`FlowRegistry.mu`): Protects the core control plane state.
+// # Locking Order
 //
-// 3. 
Shard-Level R/W Lock (`registryShard.mu`): Protects a single shard's metadata. +// To prevent deadlocks, locks MUST be acquired in the following order: // -// 4. Queue-Level Write Lock (`managedQueue.mu`): Each `managedQueue` uses a `sync.Mutex` to protect writes, ensuring -// strict consistency between queue contents and statistics (required for GC correctness). -// -// 5. Lock-Free Data Path (Atomics): Statistics aggregation (Shard/Registry level) uses lock-free atomics. Statistics -// reads at the queue level are also lock-free. -// -// 6. Strict Lock Hierarchy: To prevent deadlocks, a strict acquisition hierarchy is enforced: -// `FlowRegistry.mu` -> `registryShard.mu` -> `managedQueue.mu` -// If multiple locks are required, they MUST be acquired in this order. -// -// 7. Blocking Rules: Goroutines MUST NOT hold the `FlowRegistry.mu` lock while attempting to send to the `events` -// channel. Sends block if the channel is full. Blocking while holding the lock would deadlock the event loop. -// (e.g., use the `deferredActions` pattern). +// 1. `FlowRegistry.mu` (Registry-level write lock) +// 2. `registryShard.mu` (Shard-level write lock) +// 3. `flowState.gcLock` (Per-flow GC lock) type FlowRegistry struct { + // --- Immutable dependencies (set at construction) --- + config *Config logger logr.Logger + clock clock.WithTicker - // config holds the master configuration for the entire system. It is validated and defaulted at startup and used as - // the template for creating partitioned `ShardConfig`s. - config *Config + // --- Lock-free / Concurrent state (hot path) --- - // clock provides the time abstraction for the registry and its components (like `gcTracker`). - clock clock.WithTickerAndDelayedExecution + // flowStates tracks all flow instances, keyed by `types.FlowKey`. + flowStates sync.Map // stores `types.FlowKey` -> *flowState - // mu protects administrative operations and the internal state (shard lists, `flowState`s, etc.). - // Acquired by both external administrative methods and the internal event loop (`Run`), ensuring serialization. - mu sync.Mutex + // Globally aggregated statistics, updated atomically via lock-free propagation. + totalByteSize atomic.Int64 + totalLen atomic.Int64 + perPriorityBandStats map[uint]*bandStats // Keyed by priority. - // activeShards contains shards that are operational (Active). - // A slice is used to maintain a deterministic order, which is crucial for consistent configuration partitioning - // during scaling events. - activeShards []*registryShard + // --- Administrative state (protected by `mu`) --- - // drainingShards contains shards that are being gracefully shut down. - // A map is used for efficient O(1) removal of a shard by its ID when its draining process completes. + mu sync.RWMutex + activeShards []*registryShard drainingShards map[string]*registryShard - - // allShards is a cached, combined slice of active and draining shards. - // This cache is updated only when the set of shards changes, optimizing read-heavy operations that need to iterate - // over all shards. - allShards []*registryShard - - // nextShardID is a monotonically increasing counter used to generate unique, stable IDs for shards throughout the - // lifetime of the process. - nextShardID uint64 - - // flowStates tracks the desired state and GC state of all flow instances, keyed by the immutable `types.FlowKey`. - flowStates map[types.FlowKey]*flowState - - // gcTicker drives the periodic garbage collection cycle. 
- gcTicker clock.Ticker - - // gcGeneration is a monotonically increasing counter for GC cycles, used for the "mark-and-sweep" algorithm. - gcGeneration uint64 - - // events is a channel for all internal state change events from shards and queues. - // - // Note: The GC system relies on exactly-once delivery of edge-triggered events; therefore, sends to this channel must - // not be dropped. If the buffer fills, the data path will block, applying necessary backpressure to the control - // plane. - events chan event - - // Globally aggregated statistics. Updated atomically via lock-free propagation. - totalByteSize atomic.Int64 - totalLen atomic.Int64 - - // perPriorityBandStats stores *bandStats, keyed by priority (`uint`). - // The map structure is immutable after initialization; values are updated atomically. - perPriorityBandStats map[uint]*bandStats + allShards []*registryShard // Cached, sorted combination of Active and Draining shards + nextShardID uint64 } var _ contracts.FlowRegistry = &FlowRegistry{} -// RegistryOption allows configuring the `FlowRegistry` during initialization using functional options. +// RegistryOption allows configuring the `FlowRegistry` during initialization. type RegistryOption func(*FlowRegistry) -// WithClock sets the clock abstraction used by the registry (primarily for GC timers). -// This is essential for deterministic testing. If `clk` is nil, the option is ignored. -func WithClock(clk clock.WithTickerAndDelayedExecution) RegistryOption { +// withClock sets the clock abstraction for deterministic testing. +// test-only +func withClock(clk clock.WithTickerAndDelayedExecution) RegistryOption { return func(fr *FlowRegistry) { if clk != nil { fr.clock = clk @@ -156,140 +153,137 @@ func NewFlowRegistry(config Config, logger logr.Logger, opts ...RegistryOption) return nil, fmt.Errorf("master configuration is invalid: %w", err) } - // Buffered channel to absorb bursts of events. See comment on the struct field for concurrency notes. - events := make(chan event, config.EventChannelBufferSize) - fr := &FlowRegistry{ config: validatedConfig, logger: logger.WithName("flow-registry"), - flowStates: make(map[types.FlowKey]*flowState), - events: events, activeShards: []*registryShard{}, drainingShards: make(map[string]*registryShard), perPriorityBandStats: make(map[uint]*bandStats, len(validatedConfig.PriorityBands)), - // Initialize `generation` to 1. A generation value of 0 in `flowState` is reserved as a sentinel value to indicate a - // brand new flow that has not yet been processed by the GC loop. - gcGeneration: 1, } for _, opt := range opts { opt(fr) } - if fr.clock == nil { fr.clock = &clock.RealClock{} } - fr.gcTicker = fr.clock.NewTicker(fr.config.FlowGCTimeout) - for i := range config.PriorityBands { band := &config.PriorityBands[i] fr.perPriorityBandStats[band.Priority] = &bandStats{} } - // `UpdateShardCount` handles the initial creation and populates `activeShards`. - if err := fr.UpdateShardCount(validatedConfig.InitialShardCount); err != nil { - fr.gcTicker.Stop() + if err := fr.updateShardCount(validatedConfig.InitialShardCount); err != nil { return nil, fmt.Errorf("failed to initialize shards: %w", err) } - fr.logger.V(logging.DEFAULT).Info("FlowRegistry initialized successfully") return fr, nil } -// Run starts the registry's background event processing loop. It blocks until the provided context is cancelled. 
-// This loop implements the serialized control plane (Actor model), handling both asynchronous signals from data plane
-// components and periodic ticks for garbage collection.
+// Run starts the registry's background garbage collection loop.
+// It blocks until the provided context is cancelled.
 func (fr *FlowRegistry) Run(ctx context.Context) {
-	fr.logger.Info("Starting FlowRegistry event loop")
-	defer fr.logger.Info("FlowRegistry event loop stopped")
-	defer fr.gcTicker.Stop()
+	fr.logger.Info("Starting FlowRegistry background garbage collection loop")
+	defer fr.logger.Info("FlowRegistry background garbage collection loop stopped")
+	gcTicker := fr.clock.NewTicker(fr.config.FlowGCTimeout)
+	defer gcTicker.Stop()

 	for {
 		select {
 		case <-ctx.Done():
 			return
-		case <-fr.gcTicker.C():
-			fr.mu.Lock()
-			fr.onGCTick()
-			fr.mu.Unlock()
-		case evt := <-fr.events:
-			fr.mu.Lock()
-			switch e := evt.(type) {
-			case *queueStateChangedEvent:
-				fr.onQueueStateChanged(e)
-			case *shardStateChangedEvent:
-				fr.onShardStateChanged(e)
-			case *syncEvent:
-				close(e.doneCh) // Signal synchronization point reached.
-			}
-			fr.mu.Unlock()
+		case <-gcTicker.C():
+			fr.executeGCCycle()
 		}
 	}
 }

-// RegisterOrUpdateFlow handles the registration of a new flow instance or the update of an existing instance's
-// specification (for the same `types.FlowKey`). It orchestrates the creation or update atomically across all managed
-// shards.
-func (fr *FlowRegistry) RegisterOrUpdateFlow(spec types.FlowSpecification) error {
-	if spec.Key.ID == "" {
-		return fmt.Errorf("invalid flow specification: %w", contracts.ErrFlowIDEmpty)
+// --- `contracts.FlowRegistryClient` Implementation ---
+
+// WithConnection establishes a session for a given flow, acquiring a lifecycle lease that is automatically released
+// when the provided callback returns.
+// This is the primary entry point for the data path.
+// If the flow does not exist, it is registered Just-In-Time (JIT).
+func (fr *FlowRegistry) WithConnection(key types.FlowKey, fn func(conn contracts.ActiveFlowConnection) error) error {
+	if key.ID == "" {
+		return contracts.ErrFlowIDEmpty
 	}

-	fr.mu.Lock()
-	defer fr.mu.Unlock()
+	// --- JIT Registration ---
+	val, ok := fr.flowStates.Load(key)
+	if !ok {
+		newFlowState, err := fr.prepareNewFlow(key)
+		if err != nil {
+			return fmt.Errorf("failed to prepare JIT registration for flow %s: %w", key, err)
+		}

-	totalShardCount := len(fr.activeShards) + len(fr.drainingShards)
-	components, err := fr.buildFlowComponents(spec, totalShardCount)
-	if err != nil {
-		return err
+		actual, loaded := fr.flowStates.LoadOrStore(key, newFlowState)
+		val = actual
+		if loaded {
+			// Another goroutine won the race. Use its state and discard ours.
+			// If future changes make the `managedQueue` or its components more stateful (e.g., by adding background
+			// goroutines, registering with a metrics system, or using `sync.Pool`), a deterministic cleanup function MUST be
+			// called here to release those resources promptly and prevent leaks.
+			fr.logger.V(logging.DEBUG).Info("Concurrent JIT registration detected for flow",
+				"flowKey", key, "flowID", key.ID, "priority", key.Priority)
+		}
 	}

-	fr.applyFlowSynchronizationLocked(spec, components)
-	return nil
+	// --- Lease Acquisition & Guaranteed Release ---
+	state := val.(*flowState)
+	state.gcLock.Lock()
+	state.leaseCount.Add(1)
+	state.becameIdleAt = time.Time{} // Mark the flow as Active.
+	state.gcLock.Unlock()
+	defer func() {
+		if state.leaseCount.Add(-1) == 0 {
+			// This was the last active lease; mark the flow as Idle.
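+			// The write lock is taken so the Idle timestamp cannot be written while the GC holds its exclusive lock
+			// for re-verification of the lease count.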
+ state.gcLock.Lock() + state.becameIdleAt = fr.clock.Now() + state.gcLock.Unlock() + } + }() + + // --- Callback Execution --- + // We acquire a read lock. This has two effects: + // 1. It allows many connections to execute this section concurrently. + // 2. It prevents the GC from acquiring a write lock, thus guaranteeing the flow state cannot be deleted while `fn()` + // is running. + state.gcLock.RLock() + defer state.gcLock.RUnlock() + return fn(&connection{registry: fr, key: key}) } -// UpdateShardCount dynamically adjusts the number of internal state shards. -func (fr *FlowRegistry) UpdateShardCount(n int) error { - if n <= 0 { - return fmt.Errorf("%w: shard count must be a positive integer, but got %d", contracts.ErrInvalidShardCount, n) - } +// prepareNewFlow creates a new `flowState` and synchronizes its queues and policies onto all existing shards. +func (fr *FlowRegistry) prepareNewFlow(key types.FlowKey) (*flowState, error) { + // Get a stable snapshot of the shard topology. + // An RLock is sufficient because while the list of shards must be stable, the internal state of each shard is + // protected by its own lock. + fr.mu.RLock() + defer fr.mu.RUnlock() - fr.mu.Lock() - currentActiveShards := len(fr.activeShards) - if n == currentActiveShards { - fr.mu.Unlock() - return nil + components, err := fr.buildFlowComponents(key, len(fr.allShards)) + if err != nil { + return nil, err } - // deferredActions holds functions to be executed after the lock is released. This pattern is used to cleanly separate - // state mutations (under lock) from side effects that might block (like sending to the `events` channel). - var deferredActions []func() - - if n > currentActiveShards { - if err := fr.executeScaleUpLocked(n); err != nil { - fr.mu.Unlock() - return err - } - } else { - fr.executeScaleDownLocked(n, &deferredActions) + for i, shard := range fr.allShards { + shard.synchronizeFlow(types.FlowSpecification{Key: key}, components[i].policy, components[i].queue) } - fr.mu.Unlock() - // Execute all deferred side effects outside the lock. - for _, action := range deferredActions { - action() - } - return nil + fr.logger.Info("Successfully prepared and synchronized new flow instance", + "flowKey", key, "flowID", key.ID, "priority", key.Priority) + return &flowState{key: key}, nil } +// --- `contracts.FlowRegistryAdmin` Implementation --- + // Stats returns globally aggregated statistics for the entire `FlowRegistry`. // -// Note on Concurrency and Consistency: Statistics are aggregated using high-performance, lock-free atomic updates. -// The returned stats represent a near-consistent snapshot of the system's state. It is not perfectly atomic because the -// various counters are loaded independently without a global lock. +// Statistics are aggregated using high-performance, lock-free atomic updates. +// The returned stats represent a near-consistent snapshot of the system's state. +// It is not perfectly atomic because the various counters are loaded independently without a global lock. func (fr *FlowRegistry) Stats() contracts.AggregateStats { - // Casts from `int64` to `uint64` are safe because the non-negative invariant is strictly enforced at the + // Casts from `int64` to `uint64` are safe because the non-negativity invariant is strictly enforced at the // `managedQueue` level. 
stats := contracts.AggregateStats{ TotalCapacityBytes: fr.config.MaxBytes, @@ -301,8 +295,6 @@ func (fr *FlowRegistry) Stats() contracts.AggregateStats { for p, s := range fr.perPriorityBandStats { bandCfg, err := fr.config.getBandConfig(p) if err != nil { - // The stats map was populated from the config, so the config must exist for this priority. - // This indicates severe state corruption. panic(fmt.Sprintf("invariant violation: priority band config (%d) missing during stats aggregation: %v", p, err)) } stats.PerPriorityBandStats[p] = contracts.PriorityBandStats{ @@ -316,29 +308,11 @@ func (fr *FlowRegistry) Stats() contracts.AggregateStats { return stats } -// updateAllShardsCacheLocked recalculates and updates the cached `allShards` slice. -// It must be called any time the `activeShards` or `drainingShards` collections are modified. -// It expects the registry's lock to be held. -func (fr *FlowRegistry) updateAllShardsCacheLocked() { - allShards := make([]*registryShard, 0, len(fr.activeShards)+len(fr.drainingShards)) - allShards = append(allShards, fr.activeShards...) - - // Note: Iteration over a map is non-deterministic. However, the order of Draining shards in the combined slice does - // not impact correctness. Active shards are always first and in a stable order. - for _, shard := range fr.drainingShards { - allShards = append(allShards, shard) - } - fr.allShards = allShards -} - -// ShardStats returns a slice of statistics, one for each internal shard (Active and Draining). +// ShardStats returns a slice of statistics, one for each internal shard. func (fr *FlowRegistry) ShardStats() []contracts.ShardStats { - // To minimize lock contention, we acquire the lock just long enough to get the combined list of shards. - // Iterating and gathering stats (which involves reading shard-level atomics/locks) is done outside the critical - // section. - fr.mu.Lock() + fr.mu.RLock() allShards := fr.allShards - fr.mu.Unlock() + fr.mu.RUnlock() shardStats := make([]contracts.ShardStats, len(allShards)) for i, s := range allShards { @@ -347,149 +321,214 @@ func (fr *FlowRegistry) ShardStats() []contracts.ShardStats { return shardStats } -// Shards returns a slice of accessors for all internal shards (active and draining). -// Active shards always precede draining shards in the returned slice. -func (fr *FlowRegistry) Shards() []contracts.RegistryShard { - // Similar to `ShardStats`, minimize lock contention by getting the list under lock. +// --- Garbage Collection --- + +// executeGCCycle orchestrates the periodic GC of Idle flows and Drained shards. +func (fr *FlowRegistry) executeGCCycle() { + fr.logger.V(logging.DEBUG).Info("Starting periodic GC scan") + var flowCandidates []types.FlowKey + fr.flowStates.Range(func(key, value interface{}) bool { + state := value.(*flowState) + state.gcLock.RLock() + // A flow is a candidate if its lease count is zero and its idleness timeout has expired. + if state.leaseCount.Load() == 0 && !state.becameIdleAt.IsZero() { + if fr.clock.Since(state.becameIdleAt) > fr.config.FlowGCTimeout { + flowCandidates = append(flowCandidates, key.(types.FlowKey)) + } + } + state.gcLock.RUnlock() + return true + }) + if len(flowCandidates) > 0 { + fr.verifyAndSweepFlows(flowCandidates) + } + fr.sweepDrainingShards() +} + +// verifyAndSweepFlows performs the "verify" and "sweep" phases of GC for Idle flows. +// For each candidate, it acquires an exclusive lock on that specific flow's state, re-verifies it is still Idle, and +// then safely performs the deletion. 
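+// This re-check under the exclusive lock closes the race window in which a concurrent `WithConnection` call could
+// acquire a new lease after the initial scan observed the flow as Idle.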
+func (fr *FlowRegistry) verifyAndSweepFlows(candidates []types.FlowKey) { + fr.logger.V(logging.DEBUG).Info("Starting GC Verify and Sweep phase for flows", "candidateCount", len(candidates)) + + // Get a stable snapshot of the shard topology, so the list of shards does not change while we are preparing to delete + // queues from them. + fr.mu.RLock() + shardsSnapshot := fr.allShards + fr.mu.RUnlock() + + var collectedCount int + for _, key := range candidates { + val, ok := fr.flowStates.Load(key) + if !ok { + // Benign race: the flow was already deleted by a previous GC cycle or another process. We can safely ignore it. + continue + } + state := val.(*flowState) + + // Acquire the exclusive write lock for this specific flow, blocking any new `Connect/Close` operations for this + // flow only and ensuring the state is stable for our check. All other flows are unaffected. + state.gcLock.Lock() + + // Verify Phase: + if state.leaseCount.Load() > 0 { + // Verification failed. A new lease was acquired between our initial scan and acquiring the lock. + // The flow is Active again, so we leave it alone. + fr.logger.V(logging.DEBUG).Info("GC of flow aborted: re-verification failed (flow is Active)", + "flowKey", key, "flowID", key.ID, "priority", key.Priority, + "leaseCount", state.leaseCount.Load(), "becameIdleAt", state.becameIdleAt) + state.gcLock.Unlock() + continue + } + + // Sweep Phase: + for _, shard := range shardsSnapshot { + shard.deleteFlow(key) + } + fr.flowStates.Delete(key) + fr.logger.V(logging.VERBOSE).Info("Successfully verified and swept flow", + "flowKey", key, "flowID", key.ID, "priority", key.Priority, "becameIdleAt", state.becameIdleAt) + collectedCount++ + state.gcLock.Unlock() + } + + fr.logger.V(logging.DEBUG).Info("GC Verify and Sweep phase completed", "flowsCollected", collectedCount) +} + +// sweepDrainingShards finalizes the removal of drained shards. +func (fr *FlowRegistry) sweepDrainingShards() { + // Acquire a full write lock on the registry as we may be modifying the shard topology. fr.mu.Lock() - allShards := fr.allShards - fr.mu.Unlock() + defer fr.mu.Unlock() - shardContracts := make([]contracts.RegistryShard, len(allShards)) - for i, s := range allShards { - shardContracts[i] = s + var shardsToDelete []string + for id, shard := range fr.drainingShards { + // A Draining shard is ready for GC once it is completely empty. + // Draining shards do not accept new work (enforced at `managedQueue.Add`), so `shard.totalLen.Load()` can only + // monotonically decrease. + if shard.totalLen.Load() == 0 { + shardsToDelete = append(shardsToDelete, id) + } + } + + if len(shardsToDelete) > 0 { + fr.logger.V(logging.DEBUG).Info("Garbage collecting drained shards", "shardIDs", shardsToDelete) + for _, id := range shardsToDelete { + delete(fr.drainingShards, id) + } + fr.updateAllShardsCacheLocked() } - return shardContracts } -// --- Internal Methods --- +// --- Shard Management (Scaling) --- -// executeScaleUpLocked handles adding new shards using a "prepare-commit" pattern. -// -// First, in a "prepare" phase, all new shards are fully created and initialized in a temporary slice. This includes -// all fallible work. If any part of this phase fails, the operation is aborted without modifying the `FlowRegistry`'s -// state. If preparation succeeds, a "commit" phase atomically applies the changes to the registry. -// -// # Scalability Considerations +// updateShardCount dynamically adjusts the number of internal state shards. 
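+// Scaling up synchronizes every existing flow onto the newly created shards before committing them; scaling down
+// marks the surplus shards as Draining so they are garbage collected once empty.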
+func (fr *FlowRegistry) updateShardCount(n int) error { + if n <= 0 { + return fmt.Errorf("%w: shard count must be a positive integer, but got %d", contracts.ErrInvalidShardCount, n) + } + + // Use a full write lock as this is a major structural change to the shard topology. + fr.mu.Lock() + defer fr.mu.Unlock() + + currentActiveShards := len(fr.activeShards) + if n == currentActiveShards { + return nil + } + + if n > currentActiveShards { + return fr.executeScaleUpLocked(n) + } + fr.executeScaleDownLocked(n) + return nil +} + +// executeScaleUpLocked handles adding new shards. +// It uses a "prepare-then-commit" pattern to ensure that the entire scale-up operation is transactional and never +// leaves the system in a partially-synchronized, inconsistent state. // -// The preparation phase synchronizes all existing flows onto the new shards. This requires O(M*K) operations -// (M=flows, K=new shards) performed while holding the main control plane lock. If M is large, this operation will block -// the control plane (including event processing) for a significant duration. +// The preparation phase iterates over all existing flows once, pre-building all necessary components for every new +// shard. This requires O(M*K) operations (M=flows, K=new shards) and is performed while holding the main control plane +// lock. If M is large, this operation may block the control plane for a significant duration. // // Expects the registry's write lock to be held. func (fr *FlowRegistry) executeScaleUpLocked(newTotalActive int) error { currentActive := len(fr.activeShards) numToAdd := newTotalActive - currentActive - fr.logger.Info("Scaling up shards", "currentActive", currentActive, "newTotalActive", newTotalActive) - // --- Prepare Phase --- - // Create all new shards in a temporary slice. This phase is fallible and performs no mutations on the - // `FlowRegistry`'s state. - preparedShards := make([]*registryShard, numToAdd) + // Prepare All New Shard Objects (Fallible): + newShards := make([]*registryShard, numToAdd) for i := range numToAdd { - shardID := fmt.Sprintf("shard-%d", fr.nextShardID+uint64(i)) - shardIndex := currentActive + i - partitionedConfig := fr.config.partition(shardIndex, newTotalActive) - - callbacks := shardCallbacks{ - propagateStatsDelta: fr.propagateStatsDelta, - signalQueueState: fr.handleQueueStateSignal, - signalShardState: fr.handleShardStateSignal, - } - shard, err := newShard(shardID, partitionedConfig, fr.logger, callbacks, fr.config.interFlowDispatchPolicyFactory) + // Using a padding of 4 allows for up to 9999 shards, which is a very safe upper bound. + shardID := fmt.Sprintf("shard-%04d", fr.nextShardID+uint64(i)) + partitionedConfig := fr.config.partition(currentActive+i, newTotalActive) + shard, err := newShard( + shardID, + partitionedConfig, + fr.logger, + fr.propagateStatsDelta, + fr.config.interFlowDispatchPolicyFactory, + ) if err != nil { - return fmt.Errorf("failed to create new shard %s: %w", shardID, err) + return fmt.Errorf("failed to create new shard object %s: %w", shardID, err) } - - // Synchronize all existing flows onto this newly created shard. - for _, state := range fr.flowStates { - components, err := fr.buildFlowComponents(state.spec, 1) - if err != nil { - // This is unlikely as the flow was already validated, but we handle it defensively. 
- return fmt.Errorf("failed to prepare synchronization for flow %s on new shard %s: %w", - state.spec.Key, shardID, err) - } - shard.synchronizeFlow(state.spec, components[0].policy, components[0].queue) + newShards[i] = shard + } + + // Prepare All Components for All New Shards (Fallible): + // Pre-build every component for every existing flow on every new shard. + // If any single component fails to build, the entire scale-up operation is aborted, and all prepared data is + // discarded, leaving the system state clean. + allComponents := make(map[types.FlowKey][]flowComponents) + var rangeErr error + fr.flowStates.Range(func(key, _ interface{}) bool { + flowKey := key.(types.FlowKey) + components, err := fr.buildFlowComponents(flowKey, len(newShards)) + if err != nil { + rangeErr = fmt.Errorf("failed to prepare components for flow %s on new shards: %w", flowKey, err) + return false } - preparedShards[i] = shard + allComponents[flowKey] = components + return true + }) + if rangeErr != nil { + return rangeErr } - // --- Commit Phase --- - // Preparation succeeded. Atomically apply all changes to the registry's state. - // This phase must be infallible. - fr.activeShards = append(fr.activeShards, preparedShards...) - - for _, shard := range preparedShards { - for _, state := range fr.flowStates { - state.emptyOnShards[shard.id] = true + // Commit (Infallible): + for i, shard := range newShards { + for key, components := range allComponents { + shard.synchronizeFlow(types.FlowSpecification{Key: key}, components[i].policy, components[i].queue) } } - + fr.activeShards = append(fr.activeShards, newShards...) fr.nextShardID += uint64(numToAdd) fr.repartitionShardConfigsLocked() fr.updateAllShardsCacheLocked() return nil } -// executeScaleDownLocked handles marking shards for graceful draining and re-partitioning. -// It appends the necessary draining actions to the `deferredActions` slice. These actions MUST be executed by the -// caller after the registry lock is released to prevent deadlocks. -// It expects the registry's write lock to be held. -func (fr *FlowRegistry) executeScaleDownLocked(newTotalActive int, deferredActions *[]func()) { +// executeScaleDownLocked handles marking shards for graceful draining. +// Expects the registry's write lock to be held. +func (fr *FlowRegistry) executeScaleDownLocked(newTotalActive int) { currentActive := len(fr.activeShards) fr.logger.Info("Scaling down shards", "currentActive", currentActive, "newTotalActive", newTotalActive) - // Identify the shards to drain. These are the ones at the end of the Active list. shardsToDrain := fr.activeShards[newTotalActive:] fr.activeShards = fr.activeShards[:newTotalActive] for _, shard := range shardsToDrain { fr.drainingShards[shard.id] = shard - } - - // Defer the `markAsDraining` calls, which may block if the `events` channel is full. - for _, shard := range shardsToDrain { - s := shard - *deferredActions = append(*deferredActions, func() { - s.markAsDraining() - }) + shard.markAsDraining() } fr.repartitionShardConfigsLocked() fr.updateAllShardsCacheLocked() } -// applyFlowSynchronizationLocked is the "commit" step of `RegisterOrUpdateFlow`. -// It updates the central `flowState` and propagates the changes to all shards. -// It expects the registry's write lock to be held. 
-func (fr *FlowRegistry) applyFlowSynchronizationLocked(spec types.FlowSpecification, components []flowComponents) { - key := spec.Key - state, exists := fr.flowStates[key] - allShards := fr.allShards - if !exists { - // This is a new flow instance. - state = newFlowState(spec, allShards) - fr.flowStates[key] = state - } else { - // This is an update to an existing flow instance (e.g., policy change, when supported). - state.update(spec) - } - - if len(allShards) != len(components) { - // This indicates a severe logic error during the prepare/commit phase synchronization (a race in - // `RegisterOrUpdateFlow`). - panic(fmt.Sprintf("invariant violation: shard/queue/policy count mismatch during commit for flow %s", spec.Key)) - } - - // Propagate the update to all shards (Active and Draining), giving each its own dedicated policy and queue instance. - for i, shard := range allShards { - shard.synchronizeFlow(spec, components[i].policy, components[i].queue) - } - fr.logger.Info("Successfully registered or updated flow instance", "flowKey", key) -} - -// repartitionShardConfigsLocked updates the partitioned configuration for all active shards. -// It expects the registry's write lock to be held. +// repartitionShardConfigsLocked updates the configuration for all active shards. +// Expects the registry's write lock to be held. func (fr *FlowRegistry) repartitionShardConfigsLocked() { numActive := len(fr.activeShards) for i, shard := range fr.activeShards { @@ -498,7 +537,9 @@ func (fr *FlowRegistry) repartitionShardConfigsLocked() { } } -// flowComponents holds the set of plugin instances created for a single shard. +// --- Internal Helpers --- + +// flowComponents holds the plugin instances created for a single flow on a single shard. type flowComponents struct { policy framework.IntraFlowDispatchPolicy queue framework.SafeQueue @@ -506,236 +547,54 @@ type flowComponents struct { // buildFlowComponents instantiates the necessary plugin components for a new flow instance. // It creates a distinct instance of each component for each shard to ensure state isolation. -func (fr *FlowRegistry) buildFlowComponents(spec types.FlowSpecification, numInstances int) ([]flowComponents, error) { - priority := spec.Key.Priority - bandConfig, err := fr.config.getBandConfig(priority) +func (fr *FlowRegistry) buildFlowComponents(key types.FlowKey, numInstances int) ([]flowComponents, error) { + bandConfig, err := fr.config.getBandConfig(key.Priority) if err != nil { - return nil, fmt.Errorf("failed to get configuration for priority %d: %w", priority, err) + return nil, err } - // TODO: When flow-level queue/policy overrides are implemented, check `spec` first. 
- policyName := bandConfig.IntraFlowDispatchPolicy - queueName := bandConfig.Queue allComponents := make([]flowComponents, numInstances) - for i := range numInstances { - policy, err := fr.config.intraFlowDispatchPolicyFactory(policyName) + policy, err := fr.config.intraFlowDispatchPolicyFactory(bandConfig.IntraFlowDispatchPolicy) if err != nil { - return nil, fmt.Errorf("failed to instantiate intra-flow policy %s for flow %s: %w", policyName, spec.Key, err) + return nil, fmt.Errorf("failed to instantiate intra-flow policy %q for flow %s: %w", + bandConfig.IntraFlowDispatchPolicy, key, err) } - - q, err := fr.config.queueFactory(queueName, policy.Comparator()) + q, err := fr.config.queueFactory(bandConfig.Queue, policy.Comparator()) if err != nil { - return nil, fmt.Errorf("failed to instantiate queue %s for flow %s: %w", queueName, spec.Key, err) + return nil, fmt.Errorf("failed to instantiate queue %q for flow %s: %w", + bandConfig.Queue, key, err) } allComponents[i] = flowComponents{policy: policy, queue: q} } - return allComponents, nil } -// garbageCollectFlowLocked orchestrates the "Trust but Verify" garbage collection of a -// single flow instance. It implements the three steps of the pattern: -// 1. Trust the eventually consistent cache. -// 2. Verify the ground truth with a "stop-the-world" pause. -// 3. Act to delete the flow if confirmed Idle. -// -// It acquires no locks itself but expects the caller to hold the main registry lock. -func (fr *FlowRegistry) garbageCollectFlowLocked(key types.FlowKey) bool { - state, exists := fr.flowStates[key] - if !exists { - return false // Already deleted. Benign race. - } - - if !state.isIdle(fr.allShards) { - return false - } - - logger := fr.logger.WithValues("flowKey", key, "flowID", key.ID, "priority", key.Priority) - if !fr.verifyFlowIsTrulyIdleLocked(key, logger) { - return false - } - - fr.deleteFlowLocked(key, logger) - return true -} - -// verifyFlowIsTrulyIdleLocked performs the "stop-the-world" verification step of GC. -// It acquires a write lock on ALL shards, briefly pausing the data path to get a strongly consistent view of all queue -// lengths for a given flow. -// -// # Scalability Considerations -// -// This "Verify" step requires acquiring write locks on all shards (O(N)). As the shard count (N) increases, this pause -// duration may grow, potentially impacting P99 latency. This trade-off is made explicitly to guarantee correctness and -// prevent data loss. -// -// Returns true if the flow is confirmed to be empty everywhere, false otherwise. -func (fr *FlowRegistry) verifyFlowIsTrulyIdleLocked(key types.FlowKey, logger logr.Logger) bool { - allShards := fr.allShards - for _, shard := range allShards { - shard.mu.Lock() - } - defer func() { - for _, shard := range allShards { - shard.mu.Unlock() - } - }() - - for _, shard := range allShards { - mq, err := shard.managedQueueLocked(key) - if err != nil { - panic(fmt.Sprintf("invariant violation: managed queue for flow %s not found on shard %s during GC: %v", - key, shard.ID(), err)) - } - if mq.Len() > 0 { - logger.V(logging.DEBUG).Info("GC aborted: Live check revealed flow instance is active", "shardID", shard.ID) - return false - } - } - return true -} - -// deleteFlowLocked performs the destructive step of GC. It removes the flow's queues from all shards and deletes its -// state from the central registry map. 
-func (fr *FlowRegistry) deleteFlowLocked(key types.FlowKey, logger logr.Logger) { - logger.V(logging.VERBOSE).Info("Garbage collecting inactive flow instance") - for _, shard := range fr.allShards { - shard.garbageCollectLocked(key) - } - delete(fr.flowStates, key) - logger.V(logging.VERBOSE).Info("Successfully garbage collected flow instance") -} - -// --- Event Handling (The Control Plane Loop) --- - -const gcGenerationNewFlow uint64 = 0 - -// onGCTick implements the periodic "mark and sweep" garbage collection cycle. -// -// # Algorithm: Generational Mark-and-Sweep -// -// The GC uses a generational algorithm to ensure any idle flow survives for at least one full GC cycle before being -// collected. -// -// 1. Mark: Any flow that is either actively serving requests OR is brand new (identified by a sentinel -// `lastActiveGeneration` value of `gcGenerationNewFlow`) is "marked" by updating its generation timestamp. This -// design choice keeps the `flowState` object decoupled from the GC engine's internal clock. -// -// 2. Sweep: Any flow that is Idle AND was last marked before the *previous* GC cycle is eligible for collection. -// -// It expects the registry's write lock to be held. -func (fr *FlowRegistry) onGCTick() { - previousGeneration := fr.gcGeneration - currentGeneration := previousGeneration + 1 - logger := fr.logger.WithValues("gcGeneration", currentGeneration) - logger.V(logging.DEBUG).Info("Starting periodic GC cycle") - - // --- Mark Phase --- - // A flow is "marked" for survival by updating its `lastActiveGeneration`. - for _, state := range fr.flowStates { - if !state.isIdle(fr.allShards) { - // If a flow is currently Active, mark it with the current generation. - // This protects it from collection in the next cycle. - state.lastActiveGeneration = currentGeneration - } else if state.lastActiveGeneration == gcGenerationNewFlow { - // If a flow is new and Idle, mark it with the *previous* generation. - // This grants it a grace period for the *current* sweep, but correctly identifies it as a candidate for the - // *next* sweep if it remains Idle. - state.lastActiveGeneration = previousGeneration - } - } - - // --- Sweep Phase --- - // A flow is a candidate for collection if it is Idle AND its `lastActiveGeneration` is from before the previous - // cycle. - var candidates []types.FlowKey - for key, state := range fr.flowStates { - if state.isIdle(fr.allShards) && state.lastActiveGeneration < previousGeneration { - candidates = append(candidates, key) - } - } - - if len(candidates) > 0 { - logger.V(logging.DEBUG).Info("Found Idle flow instances for GC verification", "count", len(candidates)) - for _, key := range candidates { - fr.garbageCollectFlowLocked(key) - } - } - - // Finalize the cycle by updating the registry's generation. - fr.gcGeneration = currentGeneration -} - -// onQueueStateChanged handles a state change signal from a `managedQueue`. It updates the control plane's cached view -// of the flow's Idle state. -func (fr *FlowRegistry) onQueueStateChanged(e *queueStateChangedEvent) { - fr.logger.V(logging.DEBUG).Info("Processing queue state signal", - "shardID", e.shardID, - "flowKey", e.key, - "flowID", e.key.ID, - "priority", e.key.Priority, - "signal", e.signal, - ) - - state, ok := fr.flowStates[e.key] - if !ok { - return // Flow was likely already garbage collected. - } - state.handleQueueSignal(e.shardID, e.signal) -} - -// onShardStateChanged handles a state change signal from a `registryShard`. 
-func (fr *FlowRegistry) onShardStateChanged(e *shardStateChangedEvent) { - if e.signal == shardStateSignalBecameDrained { - logger := fr.logger.WithValues("shardID", e.shardID, "signal", e.signal) - logger.Info("Draining shard is now empty, finalizing garbage collection") - - if _, ok := fr.drainingShards[e.shardID]; !ok { - // A shard signaled Drained but wasn't in the Draining list (e.g., it was Active, or the signal was somehow - // processed twice despite the atomic latch). - // The system state is potentially corrupted. - panic(fmt.Sprintf("invariant violation: shard %s not found in draining map during GC", e.shardID)) - } - delete(fr.drainingShards, e.shardID) - - for _, flowState := range fr.flowStates { - flowState.purgeShard(e.shardID) - } - fr.updateAllShardsCacheLocked() - } -} - -// --- Callbacks (Data Plane -> Control Plane Communication) --- - -// handleQueueStateSignal is the callback passed to shards to allow them to signal queue state changes. -// It sends an event to the event channel for serialized processing by the control plane. -func (fr *FlowRegistry) handleQueueStateSignal(shardID string, key types.FlowKey, signal queueStateSignal) { - // This must block if the channel is full. Dropping events would cause state divergence and memory leaks, as the GC - // system is edge-triggered. Blocking provides necessary backpressure from the data plane to the control plane. - fr.events <- &queueStateChangedEvent{ - shardID: shardID, - key: key, - signal: signal, +// updateAllShardsCacheLocked recalculates the cached `allShards` slice. +// It ensures the slice is sorted by shard ID to maintain a deterministic order. +// Expects the registry's write lock to be held. +func (fr *FlowRegistry) updateAllShardsCacheLocked() { + allShards := make([]*registryShard, 0, len(fr.activeShards)+len(fr.drainingShards)) + allShards = append(allShards, fr.activeShards...) + for _, shard := range fr.drainingShards { + allShards = append(allShards, shard) } -} -// handleShardStateSignal is the callback passed to shards to allow them to signal their own state changes. -func (fr *FlowRegistry) handleShardStateSignal(shardID string, signal shardStateSignal) { - // This must also block (see `handleQueueStateSignal`). - fr.events <- &shardStateChangedEvent{ - shardID: shardID, - signal: signal, - } + // Sort the combined slice by shard ID. + // This provides a stable, deterministic order for all consumers of the shard list, which is critical because map + // iteration for `drainingShards` is non-deterministic. + // While this is a lexicographical sort, our shard ID format is padded with leading zeros (e.g., "shard-0001"), + // ensuring that the string sort produces the same result as a natural numerical sort. + slices.SortFunc(allShards, func(a, b *registryShard) int { + return cmp.Compare(a.id, b.id) + }) + fr.allShards = allShards } -// propagateStatsDelta is the top-level, lock-free aggregator for all statistics changes from all shards. -// It uses atomic operations to maintain high performance under concurrent updates from multiple shards. -// As a result, its counters are eventually consistent and may be transiently inaccurate during high-contention races. +// propagateStatsDelta is the top-level, lock-free aggregator for all statistics. func (fr *FlowRegistry) propagateStatsDelta(priority uint, lenDelta, byteSizeDelta int64) { stats, ok := fr.perPriorityBandStats[priority] if !ok { - // Stats are being propagated for a priority that wasn't initialized. 
panic(fmt.Sprintf("invariant violation: priority band (%d) stats missing during propagation", priority)) } diff --git a/pkg/epp/flowcontrol/registry/registry_test.go b/pkg/epp/flowcontrol/registry/registry_test.go index fcefd8613..5ffa600d0 100644 --- a/pkg/epp/flowcontrol/registry/registry_test.go +++ b/pkg/epp/flowcontrol/registry/registry_test.go @@ -27,36 +27,29 @@ import ( "github.com/go-logr/logr" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "k8s.io/utils/clock" testclock "k8s.io/utils/clock/testing" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" inter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch" intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" - "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" ) -const syncTimeout = 2 * time.Second - -// --- Test Harness and Mocks --- +// --- Test Harness --- // registryTestHarness provides a fully initialized test harness for the `FlowRegistry`. type registryTestHarness struct { - t *testing.T - ctx context.Context - cancel context.CancelFunc - fr *FlowRegistry - config Config - fakeClock *testclock.FakeClock - activeItems map[types.FlowKey]types.QueueItemAccessor + t *testing.T + fr *FlowRegistry + config Config + fakeClock *testclock.FakeClock } +// harnessOptions configures the test harness. type harnessOptions struct { config *Config - useFakeClock bool initialShardCount int } @@ -69,10 +62,10 @@ func newRegistryTestHarness(t *testing.T, opts harnessOptions) *registryTestHarn config = *opts.config } else { config = Config{ - FlowGCTimeout: 1 * time.Minute, + FlowGCTimeout: 5 * time.Minute, // Use a realistic but controllable GC time. PriorityBands: []PriorityBandConfig{ - {Priority: 10, PriorityName: "High"}, - {Priority: 20, PriorityName: "Low"}, + {Priority: highPriority, PriorityName: "High"}, + {Priority: lowPriority, PriorityName: "Low"}, }, } } @@ -80,16 +73,12 @@ func newRegistryTestHarness(t *testing.T, opts harnessOptions) *registryTestHarn config.InitialShardCount = opts.initialShardCount } - var clk clock.WithTickerAndDelayedExecution - var fakeClock *testclock.FakeClock - if opts.useFakeClock { - fakeClock = testclock.NewFakeClock(time.Now()) - clk = fakeClock - } - - fr, err := NewFlowRegistry(config, logr.Discard(), WithClock(clk)) + fakeClock := testclock.NewFakeClock(time.Now()) + registryOpts := []RegistryOption{withClock(fakeClock)} + fr, err := NewFlowRegistry(config, logr.Discard(), registryOpts...) require.NoError(t, err, "Test setup: NewFlowRegistry should not fail") + // Start the GC loop in the background. ctx, cancel := context.WithCancel(context.Background()) var wg sync.WaitGroup wg.Add(1) @@ -97,713 +86,627 @@ func newRegistryTestHarness(t *testing.T, opts harnessOptions) *registryTestHarn defer wg.Done() fr.Run(ctx) }() - t.Cleanup(func() { cancel() wg.Wait() }) return ®istryTestHarness{ - t: t, - ctx: ctx, - cancel: cancel, - fr: fr, - config: config, - fakeClock: fakeClock, - activeItems: make(map[types.FlowKey]types.QueueItemAccessor), - } -} - -// synchronize blocks until the `FlowRegistry`'s event loop has processed all preceding events. 
-func (h *registryTestHarness) synchronize() { - h.t.Helper() - doneCh := make(chan struct{}) - select { - case h.fr.events <- &syncEvent{doneCh: doneCh}: - select { - case <-doneCh: - case <-time.After(syncTimeout): - h.t.Fatalf("Timeout waiting for FlowRegistry synchronization ack") - } - case <-time.After(syncTimeout): - h.t.Fatalf("Timeout sending sync event to FlowRegistry (channel may be full)") - } -} - -// waitForGCTick advances the fake clock to trigger a GC cycle and then blocks until that cycle is complete. -// This works by sandwiching the clock step between two synchronization barriers. -func (h *registryTestHarness) waitForGCTick() { - h.t.Helper() - require.NotNil(h.t, h.fakeClock, "waitForGCTick requires the fake clock to be enabled") - h.synchronize() - h.fakeClock.Step(h.config.FlowGCTimeout + time.Millisecond) - h.synchronize() -} - -// setFlowActive makes a flow Active (by adding an item) or Idle (by removing it) and synchronizes the state change. -func (h *registryTestHarness) setFlowActive(key types.FlowKey, active bool) { - h.t.Helper() - shard := h.getFirstActiveShard() - mq, err := shard.ManagedQueue(key) - require.NoError(h.t, err, "Failed to get managed queue to set flow activity for flow %s", key) - - if active { - require.NotContains(h.t, h.activeItems, key, "Flow %s is already marked as Active in the test harness", key) - item := mocks.NewMockQueueItemAccessor(100, "req", key) - require.NoError(h.t, mq.Add(item), "Failed to add item to make flow %s Active", key) - h.activeItems[key] = item - } else { - item, ok := h.activeItems[key] - require.True(h.t, ok, "Flow %s was not Active in the test harness, cannot make it Idle", key) - _, err := mq.Remove(item.Handle()) - require.NoError(h.t, err, "Failed to remove item to make flow %s Idle", key) - delete(h.activeItems, key) - } - h.synchronize() -} - -func (h *registryTestHarness) getFirstActiveShard() contracts.RegistryShard { - h.t.Helper() - allShards := h.fr.Shards() - for _, s := range allShards { - if s.IsActive() { - return s - } + t: t, + fr: fr, + config: *fr.config, + fakeClock: fakeClock, } - h.t.Fatalf("Failed to find any active shard in list of %d shards", len(allShards)) - return nil } +// assertFlowExists synchronously checks if a flow's queue exists on the first shard. func (h *registryTestHarness) assertFlowExists(key types.FlowKey, msgAndArgs ...any) { h.t.Helper() - allShards := h.fr.Shards() - require.NotEmpty(h.t, allShards, "Cannot assert flow existence without shards") - - // It's sufficient to check one shard, as the registry guarantees consistency. - // A more robust check could iterate all, but this is a good balance. - _, err := allShards[0].ManagedQueue(key) + require.NotEmpty(h.t, h.fr.allShards, "Cannot check for flow existence when no shards are present") + _, err := h.fr.allShards[0].ManagedQueue(key) assert.NoError(h.t, err, msgAndArgs...) } +// assertFlowDoesNotExist synchronously checks if a flow's queue does not exist. func (h *registryTestHarness) assertFlowDoesNotExist(key types.FlowKey, msgAndArgs ...any) { h.t.Helper() - allShards := h.fr.Shards() - // If there are no shards, the flow cannot exist. 
- if len(allShards) == 0 { + if len(h.fr.allShards) == 0 { + assert.True(h.t, true, "Flow correctly does not exist because no shards exist") return } - _, err := allShards[0].ManagedQueue(key) + _, err := h.fr.allShards[0].ManagedQueue(key) + require.Error(h.t, err, "Expected an error when getting a non-existent flow, but got none") assert.ErrorIs(h.t, err, contracts.ErrFlowInstanceNotFound, msgAndArgs...) } -// --- Basic Tests --- +// openConnectionOnFlow ensures a flow is registered for the provided `key`. +func (h *registryTestHarness) openConnectionOnFlow(key types.FlowKey) { + h.t.Helper() + err := h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error { return nil }) + require.NoError(h.t, err, "Registering flow %s should not fail", key) + h.assertFlowExists(key, "Flow %s should exist after registration", key) +} + +// --- Constructor and Lifecycle Tests --- -func TestFlowRegistry_New_ErrorPaths(t *testing.T) { +func TestFlowRegistry_New(t *testing.T) { t.Parallel() - t.Run("ShouldFail_WhenConfigIsInvalid", func(t *testing.T) { + t.Run("ShouldApplyDefaults_WhenInitialized", func(t *testing.T) { t.Parallel() - _, err := NewFlowRegistry(Config{}, logr.Discard()) // No priority bands is invalid. - require.Error(t, err, "NewFlowRegistry should fail with an invalid config") + config := Config{PriorityBands: []PriorityBandConfig{{Priority: highPriority, PriorityName: "DefaultedBand"}}} + fr, err := NewFlowRegistry(config, logr.Discard()) + require.NoError(t, err, "Creating a valid registry with defaults should not fail") + assert.Equal(t, defaultInitialShardCount, fr.config.InitialShardCount, "InitialShardCount should be defaulted") + assert.Equal(t, defaultFlowGCTimeout, fr.config.FlowGCTimeout, "FlowGCTimeout should be defaulted") + assert.Equal(t, defaultEventChannelBufferSize, fr.config.EventChannelBufferSize, + "EventChannelBufferSize should be defaulted") + assert.Len(t, fr.allShards, defaultInitialShardCount, + "Registry should be initialized with the default number of shards") + bandConf, err := fr.config.getBandConfig(highPriority) + require.NoError(t, err, "Getting the defaulted band config should not fail") + assert.Equal(t, defaultPriorityBandMaxBytes, bandConf.MaxBytes, "Priority band MaxBytes should be defaulted") }) - t.Run("ShouldFail_WhenInitialShardCreationFails", func(t *testing.T) { + t.Run("ShouldFail_OnInvalidConfiguration", func(t *testing.T) { t.Parallel() - failingInterFlowFactory := func(_ inter.RegisteredPolicyName) (framework.InterFlowDispatchPolicy, error) { - return nil, errors.New("injected shard creation failure") - } - configWithFailingFactory := Config{ - PriorityBands: []PriorityBandConfig{{Priority: 10, PriorityName: "High"}}, + testCases := []struct { + name string + config Config + expectErrSubStr string + }{ + { + name: "WhenNoPriorityBandsAreDefined", + config: Config{}, + expectErrSubStr: "at least one priority band must be defined", + }, + { + name: "WhenPriorityLevelsAreDuplicated", + config: Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: highPriority, PriorityName: "A"}, + {Priority: highPriority, PriorityName: "B"}, + }, + }, + expectErrSubStr: fmt.Sprintf("duplicate priority level %d", highPriority), + }, + { + name: "WhenPriorityNamesAreDuplicated", + config: Config{ + PriorityBands: []PriorityBandConfig{ + {Priority: highPriority, PriorityName: "A"}, + {Priority: lowPriority, PriorityName: "A"}, + }, + }, + expectErrSubStr: `duplicate priority name "A"`, + }, } - cfg, err := NewConfig(configWithFailingFactory, 
withInterFlowDispatchPolicyFactory(failingInterFlowFactory)) - require.NoError(t, err, "Test setup: creating the config object itself should not fail") - _, err = NewFlowRegistry(*cfg, logr.Discard()) + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + _, err := NewFlowRegistry(tc.config, logr.Discard()) + require.Error(t, err, "NewFlowRegistry should fail with an invalid config") + assert.Contains(t, err.Error(), tc.expectErrSubStr, "Error message should contain the expected reason") + }) + } + }) + t.Run("ShouldFail_WhenInitialShardCreationFails", func(t *testing.T) { + t.Parallel() + config, err := NewConfig( + Config{PriorityBands: []PriorityBandConfig{{Priority: highPriority, PriorityName: "A"}}}, + withInterFlowDispatchPolicyFactory(func(inter.RegisteredPolicyName) (framework.InterFlowDispatchPolicy, error) { + return nil, errors.New("injected factory failure") + }), + ) + require.NoError(t, err, "Test setup: creating the config object itself should not fail") + _, err = NewFlowRegistry(*config, logr.Discard()) require.Error(t, err, "NewFlowRegistry should fail when initial shard setup fails") - assert.Contains(t, err.Error(), "injected shard creation failure", "Error message should reflect the root cause") + assert.Contains(t, err.Error(), "injected factory failure", + "Error message should reflect the root cause from the failing plugin factory") }) } -// --- Administrative Method Tests --- +// --- `FlowRegistryClient` API Tests --- -func TestFlowRegistry_RegisterOrUpdateFlow(t *testing.T) { +func TestFlowRegistry_WithConnection_AndHandle(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{}) - key := types.FlowKey{ID: "test-flow", Priority: 10} - spec := types.FlowSpecification{Key: key} - err := h.fr.RegisterOrUpdateFlow(spec) - require.NoError(t, err, "Registering a new flow should succeed") - h.assertFlowExists(key, "Flow should exist in registry after initial registration") + t.Run("ShouldJITRegisterFlow_OnFirstConnection", func(t *testing.T) { + t.Parallel() + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "jit-flow", Priority: highPriority} - err = h.fr.RegisterOrUpdateFlow(spec) - require.NoError(t, err, "Re-registering the same flow should be idempotent") - h.assertFlowExists(key, "Flow should still exist after idempotent re-registration") -} + h.assertFlowDoesNotExist(key, "Flow should not exist before the first connection") -func TestFlowRegistry_RegisterOrUpdateFlow_ErrorPaths(t *testing.T) { - t.Parallel() + err := h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error { + h.assertFlowExists(key, "Flow should exist immediately after JIT registration within the connection") + require.NotNil(t, conn, "Connection handle provided to callback must not be nil") + return nil + }) - // Define mock factories that can be injected to force specific failures. 
- failingPolicyFactory := func(name intra.RegisteredPolicyName) (framework.IntraFlowDispatchPolicy, error) { - return nil, errors.New("injected policy factory failure") - } - failingQueueFactory := func(name queue.RegisteredQueueName, c framework.ItemComparator) (framework.SafeQueue, error) { - return nil, errors.New("injected queue factory failure") - } + require.NoError(t, err, "WithConnection should succeed for a new flow") + h.assertFlowExists(key, "Flow should remain in the registry after the connection is closed") + }) - baseConfig := Config{ - PriorityBands: []PriorityBandConfig{ - {Priority: 10, PriorityName: "High"}, - }, - } + t.Run("ShouldFail_WhenFlowIDIsEmpty", func(t *testing.T) { + t.Parallel() + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "", Priority: highPriority} // Invalid key - testCases := []struct { - name string - spec types.FlowSpecification - setup func(h *registryTestHarness) // Surgical setup to induce failure. - errStr string // Expected substring in the error message. - errIs error // Optional: Expected wrapped error. - }{ - { - name: "ShouldFail_WhenSpecIDIsEmpty", - spec: types.FlowSpecification{Key: types.FlowKey{ID: "", Priority: 10}}, - errStr: "flow ID cannot be empty", - errIs: contracts.ErrFlowIDEmpty, - }, - { - name: "ShouldFail_WhenPriorityBandIsNotFound", - spec: types.FlowSpecification{Key: types.FlowKey{ID: "flow", Priority: 99}}, - errStr: "failed to get configuration for priority 99", - errIs: contracts.ErrPriorityBandNotFound, - }, - { - name: "ShouldFail_WhenPolicyFactoryFails", - spec: types.FlowSpecification{Key: types.FlowKey{ID: "flow", Priority: 10}}, - setup: func(h *registryTestHarness) { - // After the registry is created with a valid config, we surgically replace the policy factory with one that is - // guaranteed to fail. - h.fr.config.intraFlowDispatchPolicyFactory = failingPolicyFactory - }, - errStr: "failed to instantiate intra-flow policy", - }, - { - name: "ShouldFail_WhenQueueFactoryFails", - spec: types.FlowSpecification{Key: types.FlowKey{ID: "flow", Priority: 10}}, - setup: func(h *registryTestHarness) { - // Similarly, we inject a failing queue factory. 
- h.fr.config.queueFactory = failingQueueFactory - }, - errStr: "failed to instantiate queue", - }, - } + err := h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error { + t.Fatal("Callback must not be executed when the flow fails to register JIT") + return nil + }) - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{config: &baseConfig}) - if tc.setup != nil { - tc.setup(h) - } + require.Error(t, err, "WithConnection must return an error for a failed flow JIT registration") + assert.ErrorContains(t, err, "injected factory failure", "The returned error must propagate the reason") + }) - err := h.fr.RegisterOrUpdateFlow(tc.spec) + t.Run("Handle_Shards_ShouldReturnAllShardsAndBeACopy", func(t *testing.T) { + t.Parallel() + // Create a registry with a known mixed topology of Active and Draining shards. + h := newRegistryTestHarness(t, harnessOptions{initialShardCount: 3}) + err := h.fr.updateShardCount(2) // This leaves one shard in the Draining state.
+ require.NoError(t, err, "Test setup: scaling down to create a draining shard should not fail") + require.Len(t, h.fr.allShards, 3, "Test setup: should have 2 active and 1 draining shard") - testCases := []struct { - name string - initialShardCount int - targetShardCount int - setup func(h *registryTestHarness) any // Returns optional state for assertions - expectErrIs error - expectErrStr string - assertions func(t *testing.T, h *registryTestHarness, state any) - }{ - { - name: "ShouldSucceed_OnScaleUp", - initialShardCount: 2, - targetShardCount: 4, - assertions: func(t *testing.T, h *registryTestHarness, _ any) { - assert.Len(t, h.fr.Shards(), 4, "Registry should have 4 shards after scale-up") - assert.Len(t, h.fr.activeShards, 4, "Registry should have 4 active shards") - assert.Empty(t, h.fr.drainingShards, "Registry should have 0 draining shards") - }, - }, - { - name: "ShouldBeNoOp_WhenCountIsUnchanged", - initialShardCount: 2, - targetShardCount: 2, - assertions: func(t *testing.T, h *registryTestHarness, _ any) { - assert.Len(t, h.fr.Shards(), 2, "Shard count should remain unchanged") - }, - }, - { - name: "ShouldSucceed_OnScaleUp_WithExistingFlows", - initialShardCount: 1, - targetShardCount: 2, - setup: func(h *registryTestHarness) any { - key := types.FlowKey{ID: "flow", Priority: 10} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: key}), - "Test setup: failed to register flow") - return key // Pass the key to assertions - }, - assertions: func(t *testing.T, h *registryTestHarness, state any) { - key := state.(types.FlowKey) - assert.Len(t, h.fr.Shards(), 2, "Registry should now have 2 shards") - h.assertFlowExists(key, "Flow must exist on all shards after scaling up") - }, - }, - { - name: "ShouldFail_WhenShardCountIsInvalid_Zero", - initialShardCount: 1, - targetShardCount: 0, - expectErrIs: contracts.ErrInvalidShardCount, - }, - { - name: "ShouldFail_WhenShardCountIsInvalid_Negative", - initialShardCount: 1, - targetShardCount: -1, - expectErrIs: contracts.ErrInvalidShardCount, - }, - { - name: "ShouldFailAndRollback_WhenFlowSyncFailsDuringScaleUp", - initialShardCount: 1, - targetShardCount: 2, - setup: func(h *registryTestHarness) any { - // Create a valid, existing flow in the registry. - key := types.FlowKey{ID: "flow", Priority: 10} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: key}), - "Test setup: failed to register existing flow") - - // Sabotage the system: Inject a policy factory that is guaranteed to fail. - // This will be called when the registry tries to create components for the existing flow on the *new* shard. - failingPolicyFactory := func(_ intra.RegisteredPolicyName) (framework.IntraFlowDispatchPolicy, error) { - return nil, errors.New("injected flow sync failure") - } - h.fr.mu.Lock() - h.fr.config.intraFlowDispatchPolicyFactory = failingPolicyFactory - h.fr.mu.Unlock() + key := types.FlowKey{ID: "test-flow", Priority: highPriority} - return nil - }, - expectErrStr: "injected flow sync failure", - assertions: func(t *testing.T, h *registryTestHarness, _ any) { - // The scale-up must have been aborted, leaving the registry in its original state. 
- assert.Len(t, h.fr.Shards(), 1, "Shard count should not have changed after a failed scale-up") - }, - }, - { - name: "ShouldGracefullyDrain_OnScaleDown", - initialShardCount: 3, - targetShardCount: 2, - setup: func(h *registryTestHarness) any { - key := types.FlowKey{ID: "test-flow", Priority: 10} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: key}), - "Test setup: failed to register flow") - - allShards := h.fr.Shards() - require.Len(t, allShards, 3, "Test setup: Should start with 3 shards") - drainingShard := allShards[2] // The last shard will be chosen for draining. - - mq, err := drainingShard.ManagedQueue(key) - require.NoError(t, err, "Test setup: failed to get queue on Draining shard") - item := mocks.NewMockQueueItemAccessor(100, "req", key) - require.NoError(t, mq.Add(item), "Test setup: failed to add item") - h.synchronize() - - return &stateForDrainingTest{ - drainingShardID: drainingShard.ID(), - item: item, - flowKey: key, - } - }, - assertions: func(t *testing.T, h *registryTestHarness, state any) { - testState := state.(*stateForDrainingTest) - h.synchronize() // Ensure the Draining state has been processed. - - // Assert the intermediate Draining state is correct. - shards := h.fr.Shards() - require.Len(t, shards, 3, "Registry should still track the Draining shard until it is empty") - var drainingShard contracts.RegistryShard - for _, s := range shards { - if s.ID() == testState.drainingShardID { - drainingShard = s - } - } - require.NotNil(t, drainingShard, "Draining shard should still be present in the list") - assert.False(t, drainingShard.IsActive(), "Target shard should be marked as not active") - - // Assert that the item on the draining shard is still accessible. - mq, err := drainingShard.ManagedQueue(testState.flowKey) - require.NoError(t, err, "Should still be able to get queue from draining shard") - removedItem, err := mq.Remove(testState.item.Handle()) - require.NoError(t, err, "Should be able to remove the item from the draining shard") - assert.Equal(t, testState.item.OriginalRequest().ID(), removedItem.OriginalRequest().ID(), - "Correct item was removed") - - // After the item is removed, the shard becomes empty and should be fully garbage collected. - h.synchronize() // Process the `BecameDrained` event. - assert.Len(t, h.fr.Shards(), 2, - "Drained shard was not garbage collected from the registry after becoming empty") - }, - }, - } + err = h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error { + shards := conn.Shards() - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{initialShardCount: tc.initialShardCount}) - var state any - if tc.setup != nil { - state = tc.setup(h) - } - - err := h.fr.UpdateShardCount(tc.targetShardCount) + assert.Len(t, shards, 3, "Shards() must return all configured shards, including Draining ones") - if tc.expectErrIs != nil || tc.expectErrStr != "" { - require.Error(t, err, "UpdateShardCount should have returned an error") - if tc.expectErrIs != nil { - assert.ErrorIs(t, err, tc.expectErrIs, "Error should be the expected type") - } - if tc.expectErrStr != "" { - assert.Contains(t, err.Error(), tc.expectErrStr, "Error message should contain expected substring") - } - } else { - require.NoError(t, err, "UpdateShardCount should not have returned an error") - } + // Assert it's a copy by maliciously modifying it. 
+ require.NotEmpty(t, shards, "Test setup assumes shards are present") + shards[0] = nil // Modify the local copy. - if tc.assertions != nil { - tc.assertions(t, h, state) - } + return nil }) - } + require.NoError(t, err) + + // Prove the registry's internal state was not mutated by the modification. + assert.NotNil(t, h.fr.allShards[0], + "Modifying the slice returned by Shards() must not affect the registry's internal state") + }) } -// --- Stats and Observability Tests --- +// --- `FlowRegistryAdmin` API Tests --- -func TestFlowRegistry_Stats_And_ShardStats(t *testing.T) { +func TestFlowRegistry_Stats(t *testing.T) { t.Parallel() + h := newRegistryTestHarness(t, harnessOptions{initialShardCount: 2}) - keyHigh := types.FlowKey{ID: "flow-high", Priority: 10} - keyLow := types.FlowKey{ID: "flow-low", Priority: 20} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: keyHigh})) - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: keyLow})) - - // Add items to the queues. Note that these are distributed across 2 shards. - h.setFlowActive(keyHigh, true) // 100 bytes - h.setFlowActive(keyLow, true) // 100 bytes - h.synchronize() - - // Test global stats. - stats := h.fr.Stats() - assert.Equal(t, uint64(2), stats.TotalLen, "Global TotalLen should be correct") - assert.Equal(t, uint64(200), stats.TotalByteSize, "Global TotalByteSize should be correct") - assert.Equal(t, uint64(1), stats.PerPriorityBandStats[10].Len, "Global stats for priority 10 Len is incorrect") - assert.Equal(t, uint64(100), stats.PerPriorityBandStats[10].ByteSize, - "Global stats for priority 10 ByteSize is incorrect") - - // Test shard stats. - shardStats := h.fr.ShardStats() - require.Len(t, shardStats, 2, "Should be stats for 2 shards") + keyHigh := types.FlowKey{ID: "high-pri-flow", Priority: highPriority} + keyLow := types.FlowKey{ID: "low-pri-flow", Priority: lowPriority} + h.openConnectionOnFlow(keyHigh) + h.openConnectionOnFlow(keyLow) + + shards := h.fr.allShards + require.Len(t, shards, 2, "Test setup assumes 2 shards") + mqHigh0, _ := shards[0].ManagedQueue(keyHigh) + mqHigh1, _ := shards[1].ManagedQueue(keyHigh) + mqLow1, _ := shards[1].ManagedQueue(keyLow) + require.NoError(t, mqHigh0.Add(mocks.NewMockQueueItemAccessor(10, "req1", keyHigh)), + "Adding item to queue should not fail") + require.NoError(t, mqHigh1.Add(mocks.NewMockQueueItemAccessor(20, "req2", keyHigh)), + "Adding item to queue should not fail") + require.NoError(t, mqLow1.Add(mocks.NewMockQueueItemAccessor(30, "req3", keyLow)), + "Adding item to queue should not fail") + + // Although the production `Stats()` method provides a 'fuzzy snapshot' under high contention, our test validates it + // in a quiescent state, so these assertions can and must be exact. 
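+ // (Stats deltas are propagated from the shards via lock-free atomics, so a snapshot can only be transiently inaccurate while writers are racing; no enqueues or dequeues are in flight at this point.)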
+ globalStats := h.fr.Stats() + assert.Equal(t, uint64(3), globalStats.TotalLen, "Global TotalLen should be the sum of all items") + assert.Equal(t, uint64(60), globalStats.TotalByteSize, "Global TotalByteSize should be the sum of all item sizes") + shardStats := h.fr.ShardStats() + require.Len(t, shardStats, 2, "Should return stats for 2 shards") var totalShardLen, totalShardBytes uint64 for _, ss := range shardStats { totalShardLen += ss.TotalLen totalShardBytes += ss.TotalByteSize } - assert.Equal(t, stats.TotalLen, totalShardLen, "Sum of shard lengths should equal global length") - assert.Equal(t, stats.TotalByteSize, totalShardBytes, "Sum of shard byte sizes should equal global byte size") + assert.Equal(t, globalStats.TotalLen, totalShardLen, "Sum of shard lengths must equal global length") + assert.Equal(t, globalStats.TotalByteSize, totalShardBytes, "Sum of shard byte sizes must equal global byte size") } // --- Garbage Collection Tests --- -func TestFlowRegistry_GarbageCollection_IdleFlows(t *testing.T) { - t.Parallel() - key := types.FlowKey{ID: "gc-flow", Priority: 10} - spec := types.FlowSpecification{Key: key} +func TestFlowRegistry_GarbageCollection(t *testing.T) { + t.Run("ShouldCollectIdleFlow", func(t *testing.T) { + t.Parallel() + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "idle-flow", Priority: highPriority} + + h.openConnectionOnFlow(key) // Create a flow, which is born Idle. + h.fakeClock.Step(h.config.FlowGCTimeout + time.Second) // Advance the clock just past the GC timeout. + h.fr.executeGCCycle() // Manually and deterministically trigger a GC cycle. + + h.assertFlowDoesNotExist(key, "Idle flow should be collected by the GC") + }) - t.Run("ShouldCollectIdleFlow_AfterGracePeriod", func(t *testing.T) { + t.Run("ShouldNotCollectActiveFlow", func(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{useFakeClock: true}) - require.NoError(t, h.fr.RegisterOrUpdateFlow(spec), "Test setup: failed to register flow") - h.synchronize() - - // A new flow is granted a one-cycle grace period. The Mark phase of the first GC cycle will see - // `lastActiveGeneration==gcGenerationNewFlow` and "timestamp" it for survival. - h.waitForGCTick() - h.assertFlowExists(key, "A new Idle flow must survive its first GC cycle") - - // In the second cycle, the flow is Idle and was not marked in the preceding cycle, so it is collected. - h.waitForGCTick() - h.assertFlowDoesNotExist(key, "An Idle flow must be collected after its grace period expires") + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "active-flow", Priority: highPriority} + + var wg sync.WaitGroup + leaseAcquired := make(chan struct{}) + releaseLease := make(chan struct{}) + wg.Add(1) + + go func() { + // This goroutine holds the lease. It will not exit until the main test goroutine closes `releaseLease` in its cleanup. + defer wg.Done() + err := h.fr.WithConnection(key, func(contracts.ActiveFlowConnection) error { + close(leaseAcquired) // Signal to the main test that the lease is now active. + <-releaseLease // Block here, holding the lease, until signaled. + + return nil + }) + require.NoError(t, err, "WithConnection in the background goroutine should not fail") + }() + t.Cleanup(func() { + close(releaseLease) // Unblock the goroutine. + wg.Wait() // Wait for the goroutine to fully exit. + }) + + <-leaseAcquired // Wait until the goroutine confirms that it has acquired the lease.
+ h.fakeClock.Step(h.config.FlowGCTimeout * 2) // Advance the clock well past the GC timeout. + h.fr.executeGCCycle() // Manually and deterministically trigger a GC cycle. + + h.assertFlowExists(key, "An active flow must not be garbage collected, even after a forced GC cycle") }) - t.Run("ShouldNotCollectFlow_ThatWasRecentlyActive", func(t *testing.T) { + t.Run("ShouldResetGCTimer_WhenFlowBecomesActive", func(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{useFakeClock: true}) - require.NoError(t, h.fr.RegisterOrUpdateFlow(spec), "Test setup: failed to register flow") - h.synchronize() - - h.setFlowActive(key, true) - h.waitForGCTick() - h.assertFlowExists(key, "An Active flow should not be collected") - h.setFlowActive(key, false) - - // The flow is now Idle, but it was marked as Active in the immediately preceding cycle, so it must survive this - // sweep. - h.waitForGCTick() - h.assertFlowExists(key, "A flow must not be collected in the cycle immediately after it becomes Idle") - - // Having been Idle for a full cycle, it will now be collected. - h.waitForGCTick() - h.assertFlowDoesNotExist(key, "A flow must be collected after being Idle for a full cycle") + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "reactivated-flow", Priority: highPriority} + h.openConnectionOnFlow(key) // Create a flow with a new idleness timer. + h.fakeClock.Step(h.config.FlowGCTimeout - time.Second) // Advance the clock to just before the GC timeout. + h.openConnectionOnFlow(key) // Open a new connection, resetting its idleness timer. + h.fakeClock.Step(2 * time.Second) // Advance the clock again. + h.fr.executeGCCycle() // Manually and deterministically trigger a GC cycle. + + h.assertFlowExists(key, "Flow should survive GC because its idleness timer was reset") }) - t.Run("ShouldNotCollectFlow_WhenActivityRacesWithGC", func(t *testing.T) { + t.Run("ShouldAbortSweep_WhenFlowBecomesActiveAfterScan", func(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{useFakeClock: true}) - require.NoError(t, h.fr.RegisterOrUpdateFlow(spec), "Test setup: failed to register flow") - h.synchronize() + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "re-activated-flow", Priority: highPriority} + h.openConnectionOnFlow(key) + + // Make the flow a candidate for GC. + h.fakeClock.Step(h.config.FlowGCTimeout + time.Second) + candidates := []types.FlowKey{key} - // Establish a state where the flow is a candidate for GC on the *next* tick. - // This requires only getting it past its initial grace period. - h.waitForGCTick() - h.assertFlowExists(key, "Test setup failed: flow should exist after its grace period") + // Get the flow's state so we can manipulate its lock. + val, ok := h.fr.flowStates.Load(key) + require.True(t, ok, "Test setup: flow state must exist") + state := val.(*flowState) - // Trigger the race: send a GC tick event but DO NOT wait for it to be processed. - h.fakeClock.Step(h.config.FlowGCTimeout + time.Millisecond) + // Lock the flow's `gcLock` to pause the GC. + state.gcLock.Lock() + defer state.gcLock.Unlock() - // Immediately send a competing activity signal. - shard := h.getFirstActiveShard() - mq, err := shard.ManagedQueue(key) - require.NoError(t, err, "Test setup: failed to get managed queue for race") - require.NoError(t, mq.Add(mocks.NewMockQueueItemAccessor(100, "req", key)), "Test setup: failed to add item") + // Start the sweep in the background; it will block.
+ sweepDone := make(chan struct{}) + go func() { + defer close(sweepDone) + h.fr.verifyAndSweepFlows(candidates) + }() - // Synchronize. The `onGCTick` runs first, but its "Verify" step will see the live item and abort. - h.synchronize() + // While the sweep is blocked, make the flow Active again. + // This simulates a new connection arriving just in time. + state.leaseCount.Add(1) - // The flow must survive due to the "Trust but Verify" live check. - h.assertFlowExists(key, "Flow should survive GC due to the 'Trust but Verify' live check") + // Unblock the sweep. + state.gcLock.Unlock() + + // Wait for the sweep to complete. + select { + case <-sweepDone: + // Continue to assertion. + case <-time.After(time.Second): + t.Fatal("verifyAndSweepFlows deadlocked or timed out") + } + + h.assertFlowExists(key, "Flow should not be collected because it became active before the sweep") + state.gcLock.Lock() // Re-lock for the deferred unlock. }) - t.Run("ShouldBeBenign_WhenGCingAlreadyDeletedFlow", func(t *testing.T) { + t.Run("ShouldCollectDrainingShard_OnlyWhenEmpty", func(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{useFakeClock: true}) - require.NoError(t, h.fr.RegisterOrUpdateFlow(spec), "Test setup: failed to register flow") - - h.waitForGCTick() - h.waitForGCTick() // The second tick collects the flow. - h.assertFlowDoesNotExist(key, "Test setup failed: flow was not collected") - - // Manually (white-box) call the GC function again for the same key. - // It should return false and not panic. - var collected bool - assert.NotPanics(t, func() { - h.fr.mu.Lock() - collected = h.fr.garbageCollectFlowLocked(key) - h.fr.mu.Unlock() - }, "GCing a deleted flow should not panic") - assert.False(t, collected, "GCing a deleted flow should return false") + h := newRegistryTestHarness(t, harnessOptions{initialShardCount: 2}) + key := types.FlowKey{ID: "flow-on-draining-shard", Priority: highPriority} + h.openConnectionOnFlow(key) + + // Add an item to a queue on the soon-to-be Draining shard to keep it busy. + drainingShard := h.fr.activeShards[1] // The shard that will become draining. + mq, err := drainingShard.ManagedQueue(key) + require.NoError(t, err, "Test setup: getting queue on draining shard failed") + item := mocks.NewMockQueueItemAccessor(100, "req1", key) + require.NoError(t, mq.Add(item), "Adding item to non-active shard should be allowed for in-flight requests") + + // Scale down to mark one shard as Draining. + require.NoError(t, h.fr.updateShardCount(1), "Test setup: scale down should succeed") + require.Len(t, h.fr.drainingShards, 1, "Test setup: one shard should be draining") + + // Trigger a GC cycle while the shard is not empty. + h.fr.sweepDrainingShards() + require.Len(t, h.fr.drainingShards, 1, "Draining shard should not be collected while it is not empty") + + // Empty the shard and trigger GC again. + _, err = mq.Remove(item.Handle()) + require.NoError(t, err, "Test setup: removing item from draining shard failed") + h.fr.sweepDrainingShards() + assert.Empty(t, h.fr.drainingShards, "Draining shard should be collected after it becomes empty") }) - t.Run("ShouldAbortGC_WhenCacheIsActive", func(t *testing.T) { + t.Run("ShouldHandleBenignRace_WhenSweepingAlreadyDeletedFlow", func(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{useFakeClock: true}) - require.NoError(t, h.fr.RegisterOrUpdateFlow(spec), "Test setup: failed to register flow") - h.setFlowActive(key, true) // Make the flow Active. 
- - // Manually call the GC function. It should check the cache (`isIdle`) first and abort immediately without - // performing the "Verify" step. - var collected bool - assert.NotPanics(t, func() { - h.fr.mu.Lock() - collected = h.fr.garbageCollectFlowLocked(key) - h.fr.mu.Unlock() - }, "GCing an active flow should not panic") - - assert.False(t, collected, "GCing an active flow should return false") - h.assertFlowExists(key, "Flow should still exist after aborted GC attempt") + h := newRegistryTestHarness(t, harnessOptions{}) + key := types.FlowKey{ID: "benign-race-flow", Priority: highPriority} + h.openConnectionOnFlow(key) + + // Get the flow state so we can lock it. + val, ok := h.fr.flowStates.Load(key) + require.True(t, ok, "Test setup: flow state must exist") + state := val.(*flowState) + + // Make the flow a candidate for GC. + h.fakeClock.Step(h.config.FlowGCTimeout + time.Second) + candidates := []types.FlowKey{key} + + // Manually lock the flow's `gcLock`. This simulates the GC being stuck just before its "Verify" phase. + state.gcLock.Lock() + defer state.gcLock.Unlock() + + // In a background goroutine, run the sweep. It will block on the lock. + sweepDone := make(chan struct{}) + go func() { + defer close(sweepDone) + h.fr.verifyAndSweepFlows(candidates) + }() + + // While the sweep is blocked, delete the flow from underneath it. + // This creates the benign race condition. + h.fr.flowStates.Delete(key) + + // Unblock the sweep logic. + state.gcLock.Unlock() // Temporarily unlock to let the sweep proceed. + + // The sweep must complete without panicking. + select { + case <-sweepDone: + // Success! The test completed gracefully. + case <-time.After(time.Second): + t.Fatal("verifyAndSweepFlows deadlocked or timed out") + } + state.gcLock.Lock() // Re-lock for the deferred unlock. }) } -func TestFlowRegistry_GarbageCollection_DrainedShards(t *testing.T) { +// --- Shard Management Tests --- + +func TestFlowRegistry_UpdateShardCount(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{initialShardCount: 2}) - key := types.FlowKey{ID: "test-flow", Priority: 10} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: key}), - "Test setup: failed to register flow") - - allShards := h.fr.Shards() - require.Len(t, allShards, 2, "Test setup: Should start with 2 shards") - drainingShard := allShards[1] - - // Add an item to the shard that will be drained to prevent its immediate GC upon scale-down. - mq, err := drainingShard.ManagedQueue(key) - require.NoError(t, err, "Test setup: failed to get queue on Draining shard") - item := mocks.NewMockQueueItemAccessor(100, "req", key) - require.NoError(t, mq.Add(item), "Test setup: failed to add item") - h.synchronize() - - // Scale down, which marks the shard as Draining. - require.NoError(t, h.fr.UpdateShardCount(1), "Scaling down should succeed") - h.synchronize() - - require.Len(t, h.fr.Shards(), 2, "Registry should still track the Draining shard") - assert.False(t, drainingShard.IsActive(), "Target shard should be marked as Draining") - - // Remove the last item, which triggers the `ShardBecameDrained` signal. - _, err = mq.Remove(item.Handle()) - require.NoError(t, err, "Failed to remove last item from Draining shard") - h.synchronize() // Process the `BecameDrained` signal. - - assert.Len(t, h.fr.Shards(), 1, "Drained shard was not garbage collected from the registry") - - // Crucial memory leak check: ensure the shard's ID was purged from all flow states. 
- h.fr.mu.Lock() - defer h.fr.mu.Unlock() - flowState, exists := h.fr.flowStates[key] - require.True(t, exists, "Flow state should still exist for the Active flow") - assert.NotContains(t, flowState.emptyOnShards, drainingShard.ID(), - "Drained shard ID must be purged from flow state map to prevent memory leaks") -} + const ( + globalCapacity = 100 + bandCapacity = 50 + ) + testCases := []struct { + name string + initialShardCount int + targetShardCount int + expectedActiveCount int + expectedPartitionedGlobalCapacities map[uint64]int + expectedPartitionedBandCapacities map[uint64]int + expectErrIs error // Optional + }{ + { + name: "NoOp_ScaleToSameCount", + initialShardCount: 2, + targetShardCount: 2, + expectedActiveCount: 2, + expectedPartitionedGlobalCapacities: map[uint64]int{50: 2}, + expectedPartitionedBandCapacities: map[uint64]int{25: 2}, + }, + { + name: "Succeeds_ScaleUp_FromOne", + initialShardCount: 1, + targetShardCount: 4, + expectedActiveCount: 4, + expectedPartitionedGlobalCapacities: map[uint64]int{25: 4}, + expectedPartitionedBandCapacities: map[uint64]int{12: 2, 13: 2}, + }, + { + name: "Succeeds_ScaleUp_FromZero", + initialShardCount: 0, + targetShardCount: 4, + expectedActiveCount: 4, + expectedPartitionedGlobalCapacities: map[uint64]int{25: 4}, + expectedPartitionedBandCapacities: map[uint64]int{12: 2, 13: 2}, + }, + { + name: "Succeeds_ScaleDown_ToOne", + initialShardCount: 3, + targetShardCount: 1, + expectedActiveCount: 1, + expectedPartitionedGlobalCapacities: map[uint64]int{100: 1}, + expectedPartitionedBandCapacities: map[uint64]int{50: 1}, + }, + { + name: "Error_ScaleDown_ToZero", + initialShardCount: 2, + targetShardCount: 0, + expectedActiveCount: 2, + expectErrIs: contracts.ErrInvalidShardCount, + expectedPartitionedGlobalCapacities: map[uint64]int{50: 2}, + expectedPartitionedBandCapacities: map[uint64]int{25: 2}, + }, + { + name: "Error_ScaleDown_ToNegative", + initialShardCount: 1, + targetShardCount: -1, + expectedActiveCount: 1, + expectErrIs: contracts.ErrInvalidShardCount, + expectedPartitionedGlobalCapacities: map[uint64]int{100: 1}, + expectedPartitionedBandCapacities: map[uint64]int{50: 1}, + }, + } -// --- Event Handling and State Machine Edge Cases --- + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + config := Config{ + MaxBytes: globalCapacity, + PriorityBands: []PriorityBandConfig{ + {Priority: highPriority, PriorityName: "A", MaxBytes: bandCapacity}, + }, + InitialShardCount: tc.initialShardCount, + } -func TestFlowRegistry_EventHandling_StaleSignals(t *testing.T) { - t.Parallel() + h := newRegistryTestHarness(t, harnessOptions{config: &config}) + key := types.FlowKey{ID: "flow", Priority: 10} + h.openConnectionOnFlow(key) - t.Run("ShouldIgnoreQueueSignal_ForGarbageCollectedFlow", func(t *testing.T) { - t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{useFakeClock: true}) - key := types.FlowKey{ID: "flow", Priority: 10} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: key}), - "Test setup: failed to register flow") - - h.setFlowActive(key, true) - h.synchronize() - - // Manually (white-box) delete the flow from the registry's state map. - // This simulates a race where a signal is in-flight while the flow is being GC'd. - h.fr.mu.Lock() - delete(h.fr.flowStates, key) - h.fr.mu.Unlock() - - // Now, trigger a `BecameEmpty` signal by making the flow Idle. - // The `onQueueStateChanged` handler should receive this, find no `flowState`, and return gracefully. 
- assert.NotPanics(t, func() { - h.setFlowActive(key, false) - }, "Registry should not panic when receiving a signal for a deleted flow") - }) + err := h.fr.updateShardCount(tc.targetShardCount) + if tc.expectErrIs != nil { + require.Error(t, err, "UpdateShardCount should have returned an error") + assert.ErrorIs(t, err, tc.expectErrIs, "Error should be the expected type") + } else { + require.NoError(t, err, "UpdateShardCount should not have returned an error") + } + + var finalActiveCount, finalDrainingCount int + globalCapacities := make(map[uint64]int) + bandCapacities := make(map[uint64]int) + err = h.fr.WithConnection(key, func(conn contracts.ActiveFlowConnection) error { + for _, shard := range conn.Shards() { + if shard.IsActive() { + finalActiveCount++ + stats := shard.Stats() + globalCapacities[stats.TotalCapacityBytes]++ + bandCapacities[stats.PerPriorityBandStats[highPriority].CapacityBytes]++ + h.assertFlowExists(key, "Shard %s should contain the existing flow", shard.ID()) + } else { + finalDrainingCount++ + } + } + return nil + }) + require.NoError(t, err, "WithConnection should not fail") + + expectedDrainingCount := 0 + if tc.initialShardCount > tc.expectedActiveCount { + expectedDrainingCount = tc.initialShardCount - tc.expectedActiveCount + } + assert.Equal(t, tc.expectedActiveCount, finalActiveCount, "Final active shard count is incorrect") + assert.Equal(t, expectedDrainingCount, finalDrainingCount, "Final draining shard count in registry is incorrect") + assert.Equal(t, tc.expectedPartitionedGlobalCapacities, globalCapacities, + "Global capacity re-partitioning incorrect") + assert.Equal(t, tc.expectedPartitionedBandCapacities, bandCapacities, "Band capacity re-partitioning incorrect") + }) + } } -// --- Invariant Panic Tests --- +// --- Concurrency Tests --- -func TestFlowRegistry_InvariantPanics(t *testing.T) { +func TestFlowRegistry_Concurrency(t *testing.T) { t.Parallel() - t.Run("ShouldPanic_WhenQueueIsMissingDuringGC", func(t *testing.T) { + t.Run("ConcurrentJITRegistrations_ShouldBeSafe", func(t *testing.T) { t.Parallel() h := newRegistryTestHarness(t, harnessOptions{}) - key := types.FlowKey{ID: "flow", Priority: 10} - require.NoError(t, h.fr.RegisterOrUpdateFlow(types.FlowSpecification{Key: key}), - "Test setup: failed to register flow") - - // Manually corrupt the state by removing the queue from a shard, creating an invariant violation. - // This is white-box testing to validate a critical failure path. - h.fr.mu.Lock() - shard := h.fr.activeShards[0] - shard.mu.Lock() - delete(shard.priorityBands[key.Priority].queues, key.ID) - shard.mu.Unlock() - h.fr.mu.Unlock() - - // Manually prepare the flow for GC by setting its generation. This bypasses the need for the async `Run` loop, - // allowing us to catch the panic in this goroutine. - h.fr.mu.Lock() - h.fr.flowStates[key].lastActiveGeneration = h.fr.gcGeneration - 1 // Make it a candidate - h.fr.mu.Unlock() - - assert.Panics(t, func() { // Value assertion is too brittle as it wraps an error. 
- h.fr.mu.Lock() - defer h.fr.mu.Unlock() - _ = h.fr.garbageCollectFlowLocked(key) - }, "GC must panic when a queue is missing for a tracked flow state") - }) - - t.Run("ShouldPanic_OnSignalFromUntrackedDrainingShard", func(t *testing.T) { - t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{}) - fakeEvent := &shardStateChangedEvent{ - shardID: "shard-does-not-exist", - signal: shardStateSignalBecameDrained, + key := types.FlowKey{ID: "concurrent-flow", Priority: highPriority} + numGoroutines := 50 + var wg sync.WaitGroup + wg.Add(numGoroutines) + + // Hammer the `WithConnection` method for the same key from many goroutines. + for range numGoroutines { + go func() { + defer wg.Done() + err := h.fr.WithConnection(key, func(contracts.ActiveFlowConnection) error { + // Do a small amount of work inside the connection. + time.Sleep(1 * time.Millisecond) + return nil + }) + require.NoError(t, err, "Concurrent WithConnection calls must not fail") + }() } + wg.Wait() - assert.PanicsWithValue(t, - "invariant violation: shard shard-does-not-exist not found in draining map during GC", - func() { - h.fr.mu.Lock() - defer h.fr.mu.Unlock() - h.fr.onShardStateChanged(fakeEvent) - }, "Panic should occur when a non-existent shard signals it has drained") + // The primary assertion is that this completes without the race detector firing. + // We can also check that the flow state is consistent. + h.assertFlowExists(key, "Flow must exist after concurrent JIT registration") }) - t.Run("ShouldPanic_OnStatsForUnknownPriority", func(t *testing.T) { + t.Run("MixedAdminAndDataPlaneWorkload", func(t *testing.T) { t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{}) - assert.PanicsWithValue(t, - "invariant violation: priority band (999) stats missing during propagation", - func() { - h.fr.propagateStatsDelta(999, 1, 1) // 999 is not a configured priority. - }, - "Panic should occur when stats are propagated for an unknown priority") - }) - - t.Run("ShouldPanic_OnComponentMismatchDuringSync", func(t *testing.T) { - t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{}) - spec := types.FlowSpecification{Key: types.FlowKey{ID: "flow", Priority: 10}} - - assert.PanicsWithValue(t, - fmt.Sprintf("invariant violation: shard/queue/policy count mismatch during commit for flow %s", spec.Key), - func() { - h.fr.mu.Lock() - defer h.fr.mu.Unlock() - // Call with a mismatched number of components (0) vs. shards (1). - h.fr.applyFlowSynchronizationLocked(spec, []flowComponents{}) - }, - "Panic should occur when component count mismatches shard count") - }) + h := newRegistryTestHarness(t, harnessOptions{initialShardCount: 1}) + const ( + numWorkers = 10 + opsPerWorker = 50 + maxShardCount = 4 + ) + + var wg sync.WaitGroup + wg.Add(numWorkers + 1) // +1 for the scaling goroutine + + // Data Plane Workers: Constantly creating new flows. + for i := range numWorkers { + workerID := i + go func() { + defer wg.Done() + for j := range opsPerWorker { + key := types.FlowKey{ID: fmt.Sprintf("flow-%d-%d", workerID, j), Priority: highPriority} + _ = h.fr.WithConnection(key, func(contracts.ActiveFlowConnection) error { return nil }) + } + }() + } - t.Run("ShouldPanic_OnStatsAggregation_WithStateMismatch", func(t *testing.T) { - t.Parallel() - h := newRegistryTestHarness(t, harnessOptions{}) + // Admin Worker: Constantly scaling the number of shards up and down. 
+ go func() { + defer wg.Done() + for i := 1; i < maxShardCount; i++ { + _ = h.fr.updateShardCount(i + 1) + _ = h.fr.updateShardCount(i) + } + }() - // Manually corrupt the state by adding a stats entry for a priority that does not exist in the configuration. - h.fr.mu.Lock() - h.fr.perPriorityBandStats[999] = &bandStats{} - h.fr.mu.Unlock() + wg.Wait() - assert.Panics(t, // Value assertion is too brittle as it wraps an error. - func() { h.fr.Stats() }, - "Stats() must panic when perPriorityBandStats contains a key not in the config") + // The test completing without a race condition is the primary assertion. + // We can also assert a consistent final state. + assert.Len(t, h.fr.activeShards, maxShardCount-1, "Final active shard count should be consistent") + flowCount := 0 + h.fr.flowStates.Range(func(_, _ any) bool { + flowCount++ + return true + }) + assert.Equal(t, numWorkers*opsPerWorker, flowCount, "All concurrently registered flows must be present") }) } diff --git a/pkg/epp/flowcontrol/registry/shard.go b/pkg/epp/flowcontrol/registry/shard.go index fd9846294..36032e42c 100644 --- a/pkg/epp/flowcontrol/registry/shard.go +++ b/pkg/epp/flowcontrol/registry/shard.go @@ -30,74 +30,72 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/util/logging" ) -// shardCallbacks groups the callback functions that a `registryShard` uses to communicate with its parent registry. -type shardCallbacks struct { - propagateStatsDelta propagateStatsDeltaFunc - signalQueueState func(shardID string, key types.FlowKey, signal queueStateSignal) - signalShardState signalShardStateFunc -} - -// priorityBand holds all the `managedQueues` and configuration for a single priority level within a shard. +// priorityBand holds all `managedQueues` and configuration for a single priority level within a shard. type priorityBand struct { - // config holds the partitioned config for this specific band within this shard. - config ShardPriorityBandConfig + // --- Immutable (set at construction) --- + config ShardPriorityBandConfig + interFlowDispatchPolicy framework.InterFlowDispatchPolicy + + // --- State Protected by the parent shard's `mu` --- // queues holds all `managedQueue` instances within this band, keyed by their logical `ID` string. // The priority is implicit from the parent `priorityBand`. queues map[string]*managedQueue - // Band-level statistics. Updated atomically via lock-free propagation. + // --- Concurrent-Safe State (Atomics) --- + + // Band-level statistics, updated via lock-free propagation from child queues. byteSize atomic.Int64 len atomic.Int64 - - // Cached policy instance for this band, created at initialization. - interFlowDispatchPolicy framework.InterFlowDispatchPolicy } // registryShard implements the `contracts.RegistryShard` interface. // // # Role: The Data Plane Slice // -// It represents a single, concurrent-safe slice of the registry's total state. It provides a read-optimized view for a -// `controller.FlowController` worker. +// It represents a single, concurrent-safe slice of the registry's total state, acting as an independent, parallel +// execution unit. It provides a read-optimized view for a `controller.FlowController` worker, partitioning the overall +// system state to enable horizontal scalability. 
// -// # Concurrency: `RWMutex` and Atomics +// # Concurrency Model: `RWMutex` for Topology, Atomics for Stats // -// The `registryShard` balances read performance with write safety: +// The `registryShard` balances read performance with write safety using a hybrid model: // -// - `sync.RWMutex` (mu): Protects the shard's internal maps (`priorityBands`) during administrative operations. -// -// - Atomics (Stats): Aggregated statistics (`totalByteSize`, `totalLen`) use atomics for lock-free updates during -// delta propagation. -// -// - Atomic Lifecycle (Status): The lifecycle state is managed via an atomic `status` enum. +// - `mu (sync.RWMutex)`: Protects the shard's internal topology (the maps of priority bands and queues) during +// administrative operations like flow registration, garbage collection, and configuration updates. +// Read locks are used on the hot path to look up queues, while write locks are used for infrequent structural +// changes. +// - Atomics: Aggregated statistics (`totalByteSize`, `totalLen`, etc.) and the `isDraining` flag use atomic +// operations, allowing for high-frequency, lock-free updates and reads of the shard's status and load, which is +// critical for the performance of the data path and statistics propagation. type registryShard struct { + // --- Immutable Identity & Dependencies (set at construction) --- id string logger logr.Logger - // config holds the partitioned configuration for this shard, derived from the `FlowRegistry`'s global `Config`. - // It contains only the settings and capacity limits relevant to this specific shard. - config *ShardConfig - - // status tracks the lifecycle state of the shard (Active, Draining, Drained). - // It is stored as an `int32` for atomic operations. - status atomic.Int32 // `componentStatus` + // onStatsDelta is the callback used to propagate statistics changes up to the parent registry. + onStatsDelta propagateStatsDeltaFunc + // orderedPriorityLevels is a cached, sorted list of priority levels. + orderedPriorityLevels []uint - // parentCallbacks provides the communication channels back to the parent registry. - parentCallbacks shardCallbacks + // --- State Protected by `mu` --- - // mu protects the shard's internal maps (`priorityBands`). + // mu protects the shard's internal topology (`priorityBands`) and `config`. + // TODO: This is a priority inversion issue. Administrative operations (e.g., GC) for a low-priority flow block all + // data path operations for high priority flows on this shard. We should replace `s.mu` with granular per-band locks. + // This is safe since the priority band map structure is immutable at initialization. mu sync.RWMutex - - // priorityBands is the primary lookup table for all managed queues on this shard, organized by `priority`. + // config holds the partitioned configuration for this shard, derived from the `FlowRegistry`'s global `Config`. + config *ShardConfig + // priorityBands is the primary lookup table for all managed queues on this shard. priorityBands map[uint]*priorityBand - // orderedPriorityLevels is a cached, sorted list of `priority` levels. - // It is populated at initialization to avoid repeated map key iteration and sorting during the dispatch loop, - // ensuring a deterministic, ordered traversal from highest to lowest priority. - orderedPriorityLevels []uint + // --- Concurrent-Safe State (Atomics) --- + + // isDraining indicates if the shard is gracefully shutting down. + isDraining atomic.Bool - // Shard-level statistics. 
Updated atomically via lock-free propagation.
+	// Shard-level statistics, updated via lock-free propagation from child queues.
 	totalByteSize atomic.Int64
 	totalLen      atomic.Int64
 }
@@ -109,18 +107,17 @@ func newShard(
 	id string,
 	config *ShardConfig,
 	logger logr.Logger,
-	parentCallbacks shardCallbacks,
+	onStatsDelta propagateStatsDeltaFunc,
 	interFlowFactory interFlowDispatchPolicyFactory,
 ) (*registryShard, error) {
 	shardLogger := logger.WithName("registry-shard").WithValues("shardID", id)
 	s := &registryShard{
-		id:              id,
-		logger:          shardLogger,
-		config:          config,
-		parentCallbacks: parentCallbacks,
-		priorityBands:   make(map[uint]*priorityBand, len(config.PriorityBands)),
+		id:            id,
+		logger:        shardLogger,
+		config:        config,
+		onStatsDelta:  onStatsDelta,
+		priorityBands: make(map[uint]*priorityBand, len(config.PriorityBands)),
 	}
-	s.status.Store(int32(componentStatusActive))
 
 	for _, bandConfig := range config.PriorityBands {
 		interPolicy, err := interFlowFactory(bandConfig.InterFlowDispatchPolicy)
@@ -147,16 +144,25 @@ func newShard(
 func (s *registryShard) ID() string { return s.id }
 
 // IsActive returns true if the shard is active and accepting new requests.
-// This is used by the `controller.FlowController` to determine if it should use this shard for new enqueue operations.
+// This is a lock-free read, making it efficient for the hot path.
 func (s *registryShard) IsActive() bool {
-	return componentStatus(s.status.Load()) == componentStatusActive
+	return !s.isDraining.Load()
}
 
 // ManagedQueue retrieves a specific `contracts.ManagedQueue` instance from this shard.
 func (s *registryShard) ManagedQueue(key types.FlowKey) (contracts.ManagedQueue, error) {
 	s.mu.RLock()
 	defer s.mu.RUnlock()
-	return s.managedQueueLocked(key)
+
+	band, ok := s.priorityBands[key.Priority]
+	if !ok {
+		return nil, fmt.Errorf("failed to get managed queue for flow %q: %w", key, contracts.ErrPriorityBandNotFound)
+	}
+	mq, ok := band.queues[key.ID]
+	if !ok {
+		return nil, fmt.Errorf("failed to get managed queue for flow %q: %w", key, contracts.ErrFlowInstanceNotFound)
+	}
+	return mq, nil
 }
 
 // IntraFlowDispatchPolicy retrieves a flow's configured `framework.IntraFlowDispatchPolicy`.
@@ -172,13 +178,14 @@ func (s *registryShard) IntraFlowDispatchPolicy(key types.FlowKey) (framework.In
 	if !ok {
 		return nil, fmt.Errorf("failed to get intra-flow policy for flow %q: %w", key, contracts.ErrFlowInstanceNotFound)
 	}
-	// The policy is stored on the `managedQueue`.
+	// The policy is stored on the `managedQueue` and is immutable after creation.
 	return mq.dispatchPolicy, nil
 }
 
 // InterFlowDispatchPolicy retrieves a priority band's configured `framework.InterFlowDispatchPolicy`.
 // This read is lock-free as the policy instance is immutable after the shard is initialized.
 func (s *registryShard) InterFlowDispatchPolicy(priority uint) (framework.InterFlowDispatchPolicy, error) {
+	// This read is safe because the `priorityBands` map structure is immutable after initialization.
 	band, ok := s.priorityBands[priority]
 	if !ok {
 		return nil, fmt.Errorf("failed to get inter-flow policy for priority %d: %w",
@@ -201,7 +208,7 @@ func (s *registryShard) PriorityBandAccessor(priority uint) (framework.PriorityB
 	}
 }
 
 // AllOrderedPriorityLevels returns a cached, sorted slice of all configured priority levels for this shard.
-// The slice is sorted from highest to lowest priority (ascending numerical order).
+// The slice is sorted from highest to lowest priority, and the read is lock-free.
func (s *registryShard) AllOrderedPriorityLevels() []uint {
 	return s.orderedPriorityLevels
 }
@@ -241,8 +248,7 @@ func (s *registryShard) Stats() contracts.ShardStats {
 
 // --- Internal Administrative/Lifecycle Methods (called by `FlowRegistry`) ---
 
 // synchronizeFlow is the internal administrative method for creating a flow instance on this shard.
-// Since a flow instance (identified by its immutable `FlowKey`) cannot be updated yet, this function is a simple
-// "create if not exists" operation. It is idempotent.
+// It is an idempotent "create if not exists" operation.
 func (s *registryShard) synchronizeFlow(
 	spec types.FlowSpecification,
 	policy framework.IntraFlowDispatchPolicy,
@@ -267,20 +273,22 @@ func (s *registryShard) synchronizeFlow(
 
 	s.logger.V(logging.TRACE).Info("Creating new queue for flow instance.", "flowKey", key, "queueType", q.Name())
 
-	callbacks := managedQueueCallbacks{
-		propagateStatsDelta: s.propagateStatsDelta,
-		signalQueueState: func(key types.FlowKey, signal queueStateSignal) {
-			s.parentCallbacks.signalQueueState(s.id, key, signal)
-		},
+	// Create a closure that captures the shard's `isDraining` atomic field.
+	// This provides the queue with a way to check the shard's status without creating a tight coupling or circular
+	// dependency.
+	isDrainingFunc := func() bool {
+		return s.isDraining.Load()
 	}
-	mq := newManagedQueue(q, policy, spec.Key, s.logger, callbacks)
+
+	mq := newManagedQueue(q, policy, spec.Key, s.logger, s.propagateStatsDelta, isDrainingFunc)
 	band.queues[key.ID] = mq
 }
 
-// garbageCollectLocked removes a queue instance from the shard.
-// This must be called under the shard's write lock.
-func (s *registryShard) garbageCollectLocked(key types.FlowKey) {
-	s.logger.Info("Garbage collecting queue instance.", "flowKey", key, "flowID", key.ID, "priority", key.Priority)
+// deleteFlow removes a queue instance from the shard, acquiring the shard's write lock.
+func (s *registryShard) deleteFlow(key types.FlowKey) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	s.logger.Info("Deleting queue instance.", "flowKey", key, "flowID", key.ID, "priority", key.Priority)
 	if band, ok := s.priorityBands[key.Priority]; ok {
 		delete(band.queues, key.ID)
 	}
@@ -288,37 +296,8 @@ func (s *registryShard) garbageCollectLocked(key types.FlowKey) {
 
 // markAsDraining transitions the shard to a Draining state. This method is lock-free.
 func (s *registryShard) markAsDraining() {
-	// Attempt to transition from Active to Draining atomically.
-	if s.status.CompareAndSwap(int32(componentStatusActive), int32(componentStatusDraining)) {
-		s.logger.V(logging.DEBUG).Info("Shard status changed",
-			"from", componentStatusActive,
-			"to", componentStatusDraining,
-		)
-	}
-
-	// Check if the shard is *already* empty when marked as draining. If so, immediately attempt the transition to
-	// Drained to ensure timely GC. This handles the race where the shard becomes empty just before or during being
-	// marked Draining.
-	if s.totalLen.Load() == 0 {
-		// Attempt to transition from Draining to Drained atomically.
-		if s.status.CompareAndSwap(int32(componentStatusDraining), int32(componentStatusDrained)) {
-			s.parentCallbacks.signalShardState(s.id, shardStateSignalBecameDrained)
-		}
-	}
-}
-
-// managedQueueLocked retrieves a specific `contracts.ManagedQueue` instance from this shard.
-// This must be called under the shard's read lock.
-func (s *registryShard) managedQueueLocked(key types.FlowKey) (*managedQueue, error) { - band, ok := s.priorityBands[key.Priority] - if !ok { - return nil, fmt.Errorf("failed to get managed queue for flow %q: %w", key, contracts.ErrPriorityBandNotFound) - } - mq, ok := band.queues[key.ID] - if !ok { - return nil, fmt.Errorf("failed to get managed queue for flow %q: %w", key, contracts.ErrFlowInstanceNotFound) - } - return mq, nil + s.isDraining.Store(true) + s.logger.V(logging.DEBUG).Info("Shard marked as Draining") } // updateConfig atomically replaces the shard's configuration. This is used during scaling events to re-partition @@ -333,7 +312,6 @@ func (s *registryShard) updateConfig(newConfig *ShardConfig) { newBandConfig, err := newConfig.getBandConfig(priority) if err != nil { // An invariant was violated: a priority exists in the shard but not in the new config. - // This should be impossible if the registry's logic is correct. panic(fmt.Errorf("invariant violation: priority band (%d) missing in new shard configuration during update: %w", priority, err)) } @@ -342,14 +320,13 @@ func (s *registryShard) updateConfig(newConfig *ShardConfig) { s.logger.Info("Shard configuration updated") } -// --- Internal Callback Methods --- +// --- Internal Callback --- // propagateStatsDelta is the single point of entry for all statistics changes within the shard. -// It updates the relevant band's stats, the shard's total stats, and handles the shard's lifecycle signaling before -// propagating the delta to the parent registry. -// It uses atomic operations to maintain high performance under concurrent updates from multiple shards. -// As a result, its counters are eventually consistent and may be transiently inaccurate during high-contention races. +// It atomically updates the relevant band's stats, the shard's total stats, and propagates the delta to the parent +// registry. func (s *registryShard) propagateStatsDelta(priority uint, lenDelta, byteSizeDelta int64) { + // This read is safe because the `priorityBands` map structure is immutable after initialization. band, ok := s.priorityBands[priority] if !ok { // This should be impossible if the `managedQueue` calling this is correctly registered. @@ -359,24 +336,11 @@ func (s *registryShard) propagateStatsDelta(priority uint, lenDelta, byteSizeDel band.len.Add(lenDelta) band.byteSize.Add(byteSizeDelta) - newTotalLen := s.totalLen.Add(lenDelta) + s.totalLen.Add(lenDelta) s.totalByteSize.Add(byteSizeDelta) - // Following the strict bottom-up signaling pattern, we evaluate and signal our own state change *before* propagating - // the statistics to the parent registry. - s.evaluateDrainingState(newTotalLen) - s.parentCallbacks.propagateStatsDelta(priority, lenDelta, byteSizeDelta) -} - -// evaluateDrainingState checks if the shard has transitioned to the Drained state and signals the parent. -func (s *registryShard) evaluateDrainingState(currentLen int64) { - if currentLen == 0 { - // Attempt transition from Draining to Drained atomically. - // This acts as the exactly-once latch. If it succeeds, this goroutine is solely responsible for signaling. - if s.status.CompareAndSwap(int32(componentStatusDraining), int32(componentStatusDrained)) { - s.parentCallbacks.signalShardState(s.id, shardStateSignalBecameDrained) - } - } + // Propagate the delta up to the parent registry. This propagation is lock-free and eventually consistent. 
+ s.onStatsDelta(priority, lenDelta, byteSizeDelta) } // --- `priorityBandAccessor` --- diff --git a/pkg/epp/flowcontrol/registry/shard_test.go b/pkg/epp/flowcontrol/registry/shard_test.go index aa1f06d9f..214497e41 100644 --- a/pkg/epp/flowcontrol/registry/shard_test.go +++ b/pkg/epp/flowcontrol/registry/shard_test.go @@ -17,6 +17,8 @@ limitations under the License. package registry import ( + "errors" + "fmt" "sync" "testing" @@ -26,7 +28,6 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/contracts" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework" - frameworkmocks "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/mocks" inter "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/interflow/dispatch" intra "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/policies/intraflow/dispatch" "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/framework/plugins/queue" @@ -34,330 +35,362 @@ import ( "sigs.k8s.io/gateway-api-inference-extension/pkg/epp/flowcontrol/types/mocks" ) +const ( + // highPriority is the priority level for the "High" priority band in the test harness config. + highPriority uint = 10 + // lowPriority is the priority level for the "Low" priority band in the test harness config. + lowPriority uint = 20 + // nonExistentPriority is a priority that is known not to exist in the test harness config. + nonExistentPriority uint = 99 +) + // --- Test Harness and Mocks --- -// shardTestHarness holds the components needed for a `registryShard` test. +// shardTestHarness holds all components for a `registryShard` test. type shardTestHarness struct { - t *testing.T - globalConfig *Config - shard *registryShard - shardSignaler *mockShardSignalRecorder - statsPropagator *mockStatsPropagator - callbacks shardCallbacks // Callbacks to be passed to `newShard` + t *testing.T + shard *registryShard + statsPropagator *mockStatsPropagator + highPriorityKey1 types.FlowKey + highPriorityKey2 types.FlowKey + lowPriorityKey types.FlowKey } -// newShardTestHarness creates a new test harness for testing the `registryShard`. -// It correctly simulates the parent registry's behavior by creating a global `Config`, partitioning it, and then -// creating a shard from the partitioned `ShardConfig`. +// newShardTestHarness initializes a `shardTestHarness` with a default configuration. func newShardTestHarness(t *testing.T) *shardTestHarness { t.Helper() - globalConfig, err := NewConfig(Config{ PriorityBands: []PriorityBandConfig{ - {Priority: 10, PriorityName: "High"}, - {Priority: 20, PriorityName: "Low"}, + {Priority: highPriority, PriorityName: "High"}, + {Priority: lowPriority, PriorityName: "Low"}, }, }) require.NoError(t, err, "Test setup: validating and defaulting config should not fail") - shardSignaler := &mockShardSignalRecorder{} statsPropagator := &mockStatsPropagator{} - callbacks := shardCallbacks{ - propagateStatsDelta: statsPropagator.propagate, - // For most tests, we don't care about queue signals, so this is a no-op. - signalQueueState: func(string, types.FlowKey, queueStateSignal) {}, - signalShardState: shardSignaler.signal, - } - - // Partition the global config to create a shard-specific config. 
shardConfig := globalConfig.partition(0, 1) - - shard, err := newShard("test-shard-1", shardConfig, logr.Discard(), callbacks, inter.NewPolicyFromName) - require.NoError(t, err, "Test setup: newShard should not return an error") - - return &shardTestHarness{ - t: t, - globalConfig: globalConfig, - shard: shard, - shardSignaler: shardSignaler, - statsPropagator: statsPropagator, - callbacks: callbacks, + shard, err := newShard( + "test-shard-1", + shardConfig, logr.Discard(), + statsPropagator.propagate, + inter.NewPolicyFromName, + ) + require.NoError(t, err, "Test setup: newShard should not return an error with valid configuration") + + h := &shardTestHarness{ + t: t, + shard: shard, + statsPropagator: statsPropagator, + highPriorityKey1: types.FlowKey{ID: "hp-flow-1", Priority: highPriority}, + highPriorityKey2: types.FlowKey{ID: "hp-flow-2", Priority: highPriority}, + lowPriorityKey: types.FlowKey{ID: "lp-flow-1", Priority: lowPriority}, } + // Automatically sync some default flows for convenience. + h.synchronizeFlow(h.highPriorityKey1) + h.synchronizeFlow(h.highPriorityKey2) + h.synchronizeFlow(h.lowPriorityKey) + return h } -// synchronizeFlowWithMocks is a test helper that simulates the parent registry's logic for calling `synchronizeFlow` on -// the shard, but with mock plugins to ensure test isolation. -func (h *shardTestHarness) synchronizeFlowWithMocks(key types.FlowKey, q framework.SafeQueue) { +// synchronizeFlow simulates the registry synchronizing a flow with a real queue. +func (h *shardTestHarness) synchronizeFlow(key types.FlowKey) { h.t.Helper() spec := types.FlowSpecification{Key: key} - mockPolicy := &frameworkmocks.MockIntraFlowDispatchPolicy{} - h.shard.synchronizeFlow(spec, mockPolicy, q) -} - -// synchronizeFlowWithRealQueue is a test helper that uses a real queue implementation. -// This is essential for integration and concurrency tests where the interaction between the shard and a real queue is -// being validated. -func (h *shardTestHarness) synchronizeFlowWithRealQueue(key types.FlowKey) { - h.t.Helper() - spec := types.FlowSpecification{Key: key} - - // This logic correctly mimics the parent registry's instantiation process. policy, err := intra.NewPolicyFromName(defaultIntraFlowDispatchPolicy) - require.NoError(h.t, err, "Test setup: failed to create real intra-flow policy") + require.NoError(h.t, err, "Helper synchronizeFlow: failed to create real intra-flow policy for synchronization") q, err := queue.NewQueueFromName(defaultQueue, policy.Comparator()) - require.NoError(h.t, err, "Test setup: failed to create real queue") + require.NoError(h.t, err, "Helper synchronizeFlow: failed to create real queue for synchronization") h.shard.synchronizeFlow(spec, policy, q) } -// addItem adds an item to a specific flow on the shard, failing the test on any error. +// addItem adds an item to a specific flow's queue on the shard. 
func (h *shardTestHarness) addItem(key types.FlowKey, size uint64) types.QueueItemAccessor { h.t.Helper() mq, err := h.shard.ManagedQueue(key) - require.NoError(h.t, err, "Helper addItem: failed to get queue for flow %v", key) + require.NoError(h.t, err, "Helper addItem: failed to get queue for flow %s; ensure flow is synchronized", key) item := mocks.NewMockQueueItemAccessor(size, "req", key) - require.NoError(h.t, mq.Add(item), "Helper addItem: failed to add item to queue") + require.NoError(h.t, mq.Add(item), "Helper addItem: failed to add item to queue for flow %s", key) return item } -// removeItem removes an item from a specific flow, failing the test on any error. +// removeItem removes an item from a specific flow's queue. func (h *shardTestHarness) removeItem(key types.FlowKey, item types.QueueItemAccessor) { h.t.Helper() mq, err := h.shard.ManagedQueue(key) - require.NoError(h.t, err, "Helper removeItem: failed to get queue for flow %v", key) + require.NoError(h.t, err, "Helper removeItem: failed to get queue for flow %s; ensure flow is synchronized", key) _, err = mq.Remove(item.Handle()) - require.NoError(h.t, err, "Helper removeItem: failed to remove item from queue") -} - -// mockShardSignalRecorder is a thread-safe helper for recording shard state signals. -type mockShardSignalRecorder struct { - mu sync.Mutex - signals []shardStateSignal -} - -func (r *mockShardSignalRecorder) signal(_ string, signal shardStateSignal) { - r.mu.Lock() - defer r.mu.Unlock() - r.signals = append(r.signals, signal) -} - -func (r *mockShardSignalRecorder) getSignals() []shardStateSignal { - r.mu.Lock() - defer r.mu.Unlock() - signalsCopy := make([]shardStateSignal, len(r.signals)) - copy(signalsCopy, r.signals) - return signalsCopy + require.NoError(h.t, err, "Helper removeItem: failed to remove item from queue for flow %s", key) } // --- Basic Tests --- func TestShard_New(t *testing.T) { t.Parallel() - h := newShardTestHarness(t) - assert.Equal(t, "test-shard-1", h.shard.ID(), "ID should be set correctly") - assert.True(t, h.shard.IsActive(), "A new shard should be active by default") - assert.Equal(t, []uint{10, 20}, h.shard.AllOrderedPriorityLevels(), "Priority levels should be correctly ordered") - - bandHigh, ok := h.shard.priorityBands[10] - require.True(t, ok, "High priority band should exist") - assert.Equal(t, "High", bandHigh.config.PriorityName, "High priority band should have correct name") - assert.NotNil(t, bandHigh.interFlowDispatchPolicy, "Inter-flow policy should be instantiated") - assert.Equal(t, string(defaultInterFlowDispatchPolicy), bandHigh.interFlowDispatchPolicy.Name(), - "Correct default inter-flow policy should be used") -} + t.Run("ShouldInitializeCorrectly_WithDefaultConfig", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) -func TestShard_New_ErrorPaths(t *testing.T) { - t.Parallel() + assert.Equal(t, "test-shard-1", h.shard.ID(), "Shard ID must match the value provided during construction") + assert.True(t, h.shard.IsActive(), "A newly created shard must be initialized in the Active state") + assert.Equal(t, []uint{highPriority, lowPriority}, h.shard.AllOrderedPriorityLevels(), + "Shard must report configured priority levels sorted numerically (highest priority first)") + + bandHigh, ok := h.shard.priorityBands[highPriority] + require.True(t, ok, "Priority band %d (High) must be initialized", highPriority) + assert.Equal(t, "High", bandHigh.config.PriorityName, "Priority band name must match the configuration") + require.NotNil(t, 
bandHigh.interFlowDispatchPolicy, "Inter-flow policy must be instantiated during construction") + assert.Equal(t, string(defaultInterFlowDispatchPolicy), bandHigh.interFlowDispatchPolicy.Name(), + "The default inter-flow policy implementation must be used when not overridden") + }) - t.Run("ShouldFail_WhenInterFlowPolicyIsMissing", func(t *testing.T) { + t.Run("ShouldFail_WhenInterFlowPolicyFactoryFails", func(t *testing.T) { t.Parallel() - shardConfig := &ShardConfig{ - PriorityBands: []ShardPriorityBandConfig{{Priority: 10, PriorityName: "High"}}, + shardConfig, _ := NewConfig(Config{PriorityBands: []PriorityBandConfig{ + {Priority: highPriority, PriorityName: "High"}, + }}) + failingFactory := func(inter.RegisteredPolicyName) (framework.InterFlowDispatchPolicy, error) { + return nil, errors.New("policy not found") } - _, err := newShard("test-shard-1", shardConfig, logr.Discard(), shardCallbacks{}, inter.NewPolicyFromName) - assert.Error(t, err, "newShard should fail when missing a configured inter-flow policy") + _, err := newShard("test-shard-1", shardConfig.partition(0, 1), logr.Discard(), nil, failingFactory) + require.Error(t, err, "newShard must fail if the inter-flow policy cannot be instantiated during initialization") }) } func TestShard_Stats(t *testing.T) { t.Parallel() h := newShardTestHarness(t) - flowKeyHigh := types.FlowKey{ID: "flow1", Priority: 10} - h.synchronizeFlowWithMocks(flowKeyHigh, &frameworkmocks.MockSafeQueue{}) - h.addItem(flowKeyHigh, 100) - h.addItem(flowKeyHigh, 50) + h.addItem(h.highPriorityKey1, 100) + h.addItem(h.highPriorityKey1, 50) stats := h.shard.Stats() - assert.Equal(t, uint64(2), stats.TotalLen, "Total length should be 2") - assert.Equal(t, uint64(150), stats.TotalByteSize, "Total byte size should be 150") + assert.Equal(t, uint64(2), stats.TotalLen, "Total shard length must aggregate counts from all bands") + assert.Equal(t, uint64(150), stats.TotalByteSize, "Total shard byte size must aggregate sizes from all bands") - bandHighStats := stats.PerPriorityBandStats[10] - assert.Equal(t, uint64(2), bandHighStats.Len, "High priority band length should be 2") - assert.Equal(t, uint64(150), bandHighStats.ByteSize, "High priority band byte size should be 150") - - bandLowStats := stats.PerPriorityBandStats[20] - assert.Zero(t, bandLowStats.Len, "Low priority band length should be 0") + bandHighStats, ok := stats.PerPriorityBandStats[highPriority] + require.True(t, ok, "Stats snapshot must include entries for all configured priority bands (e.g., %d)", highPriority) + assert.Equal(t, uint64(2), bandHighStats.Len, "Priority band length must reflect the items queued at that level") + assert.Equal(t, uint64(150), bandHighStats.ByteSize, + "Priority band byte size must reflect the items queued at that level") } -func TestShard_Accessors_SuccessPaths(t *testing.T) { +func TestShard_Accessors(t *testing.T) { t.Parallel() - h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "test-flow", Priority: 10} - h.synchronizeFlowWithRealQueue(flowKey) - t.Run("ManagedQueue_ShouldReturnCorrectInstance", func(t *testing.T) { + t.Run("SuccessPaths", func(t *testing.T) { t.Parallel() - mq, err := h.shard.ManagedQueue(flowKey) + h := newShardTestHarness(t) - require.NoError(t, err, "ManagedQueue should not error for an existing flow") - require.NotNil(t, mq, "Returned ManagedQueue should not be nil") - accessor := mq.FlowQueueAccessor() - assert.Equal(t, flowKey, accessor.FlowKey(), "Returned queue should have the correct flow key") - }) + t.Run("ManagedQueue", 
func(t *testing.T) { + t.Parallel() + mq, err := h.shard.ManagedQueue(h.highPriorityKey1) + require.NoError(t, err, "ManagedQueue accessor must succeed for a synchronized flow") + require.NotNil(t, mq, "Returned ManagedQueue must not be nil") + assert.Equal(t, h.highPriorityKey1, mq.FlowQueueAccessor().FlowKey(), + "The returned queue instance must correspond to the requested FlowKey") + }) - t.Run("IntraFlowDispatchPolicy_ShouldReturnCorrectInstance", func(t *testing.T) { - t.Parallel() - policy, err := h.shard.IntraFlowDispatchPolicy(flowKey) + t.Run("IntraFlowDispatchPolicy", func(t *testing.T) { + t.Parallel() + policy, err := h.shard.IntraFlowDispatchPolicy(h.highPriorityKey1) + require.NoError(t, err, "IntraFlowDispatchPolicy accessor must succeed for a synchronized flow") + require.NotNil(t, policy, "Returned policy must not be nil (guaranteed by contract)") + assert.Equal(t, string(defaultIntraFlowDispatchPolicy), policy.Name(), + "Must return the default intra-flow policy implementation") + }) - require.NoError(t, err, "IntraFlowDispatchPolicy should not error for an existing flow") - require.NotNil(t, policy, "Returned policy should not be nil") - assert.Equal(t, string(defaultIntraFlowDispatchPolicy), policy.Name(), - "Should return the correct default policy for the band") + t.Run("InterFlowDispatchPolicy", func(t *testing.T) { + t.Parallel() + policy, err := h.shard.InterFlowDispatchPolicy(highPriority) + require.NoError(t, err, "InterFlowDispatchPolicy accessor must succeed for a configured priority band") + require.NotNil(t, policy, "Returned policy must not be nil (guaranteed by contract)") + assert.Equal(t, string(defaultInterFlowDispatchPolicy), policy.Name(), + "Must return the default inter-flow policy implementation") + }) }) - t.Run("InterFlowDispatchPolicy_ShouldReturnCorrectInstance", func(t *testing.T) { + t.Run("ErrorPaths", func(t *testing.T) { t.Parallel() - policy, err := h.shard.InterFlowDispatchPolicy(uint(10)) - - require.NoError(t, err, "InterFlowDispatchPolicy should not error for an existing priority") - require.NotNil(t, policy, "Returned policy should not be nil") - assert.Equal(t, string(defaultInterFlowDispatchPolicy), policy.Name(), - "Should return the correct default policy for the band") - }) -} - -func TestShard_Accessors_ErrorPaths(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "flow-a", Priority: 10} - h.synchronizeFlowWithMocks(flowKey, &frameworkmocks.MockSafeQueue{}) - - testCases := []struct { - name string - action func() error - expectErr error - }{ - { - name: "ManagedQueue_PriorityNotFound", - action: func() error { _, err := h.shard.ManagedQueue(types.FlowKey{ID: "flow-a", Priority: 99}); return err }, - expectErr: contracts.ErrPriorityBandNotFound, - }, - { - name: "ManagedQueue_FlowNotFound", - action: func() error { _, err := h.shard.ManagedQueue(types.FlowKey{ID: "missing", Priority: 10}); return err }, - expectErr: contracts.ErrFlowInstanceNotFound, - }, - { - name: "InterFlowDispatchPolicy_PriorityNotFound", - action: func() error { _, err := h.shard.InterFlowDispatchPolicy(99); return err }, - expectErr: contracts.ErrPriorityBandNotFound, - }, - { - name: "IntraFlowDispatchPolicy_PriorityNotFound", - action: func() error { - _, err := h.shard.IntraFlowDispatchPolicy(types.FlowKey{ID: "flow-a", Priority: 99}) - return err + testCases := []struct { + name string + action func(s *registryShard) error + expectErr error + }{ + { + name: "ManagedQueue_PriorityNotFound", + action: func(s 
*registryShard) error { + _, err := s.ManagedQueue(types.FlowKey{Priority: nonExistentPriority}) + return err + }, + expectErr: contracts.ErrPriorityBandNotFound, }, - expectErr: contracts.ErrPriorityBandNotFound, - }, - { - name: "IntraFlowDispatchPolicy_FlowNotFound", - action: func() error { - _, err := h.shard.IntraFlowDispatchPolicy(types.FlowKey{ID: "missing", Priority: 10}) - return err + { + name: "ManagedQueue_FlowNotFound", + action: func(s *registryShard) error { + _, err := s.ManagedQueue(types.FlowKey{ID: "missing", Priority: highPriority}) + return err + }, + expectErr: contracts.ErrFlowInstanceNotFound, }, - expectErr: contracts.ErrFlowInstanceNotFound, - }, - } - - for _, tc := range testCases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - err := tc.action() - require.Error(t, err, "Action should have returned an error") - assert.ErrorIs(t, err, tc.expectErr, "Error should wrap the expected sentinel error") - }) - } + { + name: "InterFlowDispatchPolicy_PriorityNotFound", + action: func(s *registryShard) error { + _, err := s.InterFlowDispatchPolicy(nonExistentPriority) + return err + }, + expectErr: contracts.ErrPriorityBandNotFound, + }, + { + name: "IntraFlowDispatchPolicy_PriorityNotFound", + action: func(s *registryShard) error { + _, err := s.IntraFlowDispatchPolicy(types.FlowKey{Priority: nonExistentPriority}) + return err + }, + expectErr: contracts.ErrPriorityBandNotFound, + }, + { + name: "IntraFlowDispatchPolicy_FlowNotFound", + action: func(s *registryShard) error { + _, err := s.IntraFlowDispatchPolicy(types.FlowKey{ID: "missing", Priority: highPriority}) + return err + }, + expectErr: contracts.ErrFlowInstanceNotFound, + }, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + err := tc.action(h.shard) + require.Error(t, err, "The accessor method must return an error for this scenario") + assert.ErrorIs(t, err, tc.expectErr, + "The error must wrap the specific sentinel error defined in the contracts package") + }) + } + }) } func TestShard_PriorityBandAccessor(t *testing.T) { t.Parallel() - // The shard will have two flows in the "High" priority band and one in the "Low" band. 
- h := newShardTestHarness(t) - highPriorityKey1 := types.FlowKey{ID: "flow1", Priority: 10} - highPriorityKey2 := types.FlowKey{ID: "flow2", Priority: 10} - lowPriorityKey := types.FlowKey{ID: "flow3", Priority: 20} - - h.synchronizeFlowWithMocks(highPriorityKey1, &frameworkmocks.MockSafeQueue{}) - h.synchronizeFlowWithMocks(highPriorityKey2, &frameworkmocks.MockSafeQueue{}) - h.synchronizeFlowWithMocks(lowPriorityKey, &frameworkmocks.MockSafeQueue{}) - t.Run("ShouldFail_WhenPriorityDoesNotExist", func(t *testing.T) { t.Parallel() - _, err := h.shard.PriorityBandAccessor(99) - require.Error(t, err, "PriorityBandAccessor should fail for a non-existent priority") - assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, "Error should be ErrPriorityBandNotFound") + h := newShardTestHarness(t) + _, err := h.shard.PriorityBandAccessor(nonExistentPriority) + assert.ErrorIs(t, err, contracts.ErrPriorityBandNotFound, + "Requesting an accessor for an unconfigured priority must fail with ErrPriorityBandNotFound") }) t.Run("ShouldSucceed_WhenPriorityExists", func(t *testing.T) { t.Parallel() - accessor, err := h.shard.PriorityBandAccessor(10) - require.NoError(t, err, "PriorityBandAccessor should not fail for an existing priority") - require.NotNil(t, accessor, "Accessor should not be nil") + h := newShardTestHarness(t) + accessor, err := h.shard.PriorityBandAccessor(h.highPriorityKey1.Priority) + require.NoError(t, err, "Requesting an accessor for a configured priority must succeed") + require.NotNil(t, accessor, "The returned accessor instance must not be nil") t.Run("Properties_ShouldReturnCorrectValues", func(t *testing.T) { t.Parallel() - assert.Equal(t, uint(10), accessor.Priority(), "Accessor should have the correct priority") - assert.Equal(t, "High", accessor.PriorityName(), "Accessor should have the correct priority name") + assert.Equal(t, h.highPriorityKey1.Priority, accessor.Priority(), + "Accessor Priority() must match the configured numerical priority") + assert.Equal(t, "High", accessor.PriorityName(), "Accessor PriorityName() must match the configured name") }) t.Run("FlowKeys_ShouldReturnAllKeysInBand", func(t *testing.T) { t.Parallel() keys := accessor.FlowKeys() - expectedKeys := []types.FlowKey{highPriorityKey1, highPriorityKey2} + expectedKeys := []types.FlowKey{h.highPriorityKey1, h.highPriorityKey2} assert.ElementsMatch(t, expectedKeys, keys, - "Accessor should return all flow keys for the priority band") + "FlowKeys() must return a complete snapshot of all flows registered in this band") }) t.Run("Queue_ShouldReturnCorrectAccessor", func(t *testing.T) { t.Parallel() - q := accessor.Queue("flow1") - require.NotNil(t, q, "Accessor should return a non-nil accessor for an existing flow") - assert.Equal(t, highPriorityKey1, q.FlowKey(), "Queue accessor should have the correct flow key") - assert.Nil(t, accessor.Queue("non-existent"), "Accessor should return nil for a non-existent flow") + q := accessor.Queue(h.highPriorityKey1.ID) + require.NotNil(t, q, "Queue() must return a non-nil accessor for a registered flow ID") + assert.Equal(t, h.highPriorityKey1, q.FlowKey(), "The returned queue accessor must have the correct FlowKey") + assert.Nil(t, accessor.Queue("non-existent"), "Queue() must return nil if the flow ID is not found in this band") }) - t.Run("IterateQueues_ShouldVisitAllQueuesInBand", func(t *testing.T) { + t.Run("IterateQueues", func(t *testing.T) { t.Parallel() - var iteratedKeys []types.FlowKey - accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool 
{ - iteratedKeys = append(iteratedKeys, queue.FlowKey()) - return true + + t.Run("ShouldVisitAllQueuesInBand", func(t *testing.T) { + t.Parallel() + var iteratedKeys []types.FlowKey + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iteratedKeys = append(iteratedKeys, queue.FlowKey()) + return true + }) + expectedKeys := []types.FlowKey{h.highPriorityKey1, h.highPriorityKey2} + assert.ElementsMatch(t, expectedKeys, iteratedKeys, + "IterateQueues must visit every registered flow in the band exactly once") + }) + + t.Run("ShouldExitEarly_WhenCallbackReturnsFalse", func(t *testing.T) { + t.Parallel() + var iterationCount int + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + iterationCount++ + return false + }) + assert.Equal(t, 1, iterationCount, "IterateQueues must terminate immediately when the callback returns false") + }) + + t.Run("ShouldBeSafe_DuringConcurrentMapModification", func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) // Isolated harness to avoid corrupting the state for other parallel tests + accessor, err := h.shard.PriorityBandAccessor(highPriority) + require.NoError(t, err) + + var wg sync.WaitGroup + wg.Add(2) + + // Goroutine A: The Iterator (constantly reading) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { + // Accessing data should not panic or race. + _ = queue.FlowKey() + return true + }) + } + }() + + // Goroutine B: The Modifier (constantly writing) + go func() { + defer wg.Done() + for i := 0; i < 100; i++ { + key := types.FlowKey{ID: fmt.Sprintf("new-flow-%d", i), Priority: highPriority} + h.synchronizeFlow(key) + h.shard.deleteFlow(key) + } + }() + + // The primary assertion is that this test completes without the race detector firing, which proves the + // `RLock/WLock` separation is correct. + wg.Wait() }) - expectedKeys := []types.FlowKey{highPriorityKey1, highPriorityKey2} - assert.ElementsMatch(t, expectedKeys, iteratedKeys, "IterateQueues should visit all flows in the band") }) - t.Run("IterateQueues_ShouldExitEarly_WhenCallbackReturnsFalse", func(t *testing.T) { + t.Run("OnEmptyBand", func(t *testing.T) { t.Parallel() - var iterationCount int + h := newShardTestHarness(t) + h.shard.deleteFlow(h.lowPriorityKey) + accessor, err := h.shard.PriorityBandAccessor(lowPriority) + require.NoError(t, err, "Setup: getting an accessor for an empty band must succeed") + + keys := accessor.FlowKeys() + assert.NotNil(t, keys, "FlowKeys() on an empty band must return a non-nil slice") + assert.Empty(t, keys, "FlowKeys() on an empty band must return an empty slice") + + var callbackExecuted bool accessor.IterateQueues(func(queue framework.FlowQueueAccessor) bool { - iterationCount++ - return false // Signal to stop iteration immediately. + callbackExecuted = true + return true }) - assert.Equal(t, 1, iterationCount, "IterateQueues should exit after the first item") + assert.False(t, callbackExecuted, "IterateQueues must not execute the callback for an empty band") }) }) } @@ -367,148 +400,138 @@ func TestShard_PriorityBandAccessor(t *testing.T) { func TestShard_SynchronizeFlow(t *testing.T) { t.Parallel() h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "flow1", Priority: 10} + flowKey := types.FlowKey{ID: "flow1", Priority: highPriority} - // Synchronization should create the queue. 
- h.synchronizeFlowWithMocks(flowKey, &frameworkmocks.MockSafeQueue{}) + h.synchronizeFlow(flowKey) mq1, err := h.shard.ManagedQueue(flowKey) - require.NoError(t, err, "Queue should exist after first synchronization") - assert.Contains(t, h.shard.priorityBands[10].queues, flowKey.ID, - "Queue should be in the correct priority band map") + require.NoError(t, err, "Flow instance should be accessible after synchronization") - // Synchronization should be idempotent. - h.synchronizeFlowWithMocks(flowKey, &frameworkmocks.MockSafeQueue{}) + h.synchronizeFlow(flowKey) mq2, err := h.shard.ManagedQueue(flowKey) - require.NoError(t, err, "Queue should still exist after second synchronization") - assert.Same(t, mq1, mq2, "Queue instance should not have been replaced on second sync") + require.NoError(t, err, "Flow instance should remain accessible after idempotent re-synchronization") + assert.Same(t, mq1, mq2, "Idempotent synchronization must not replace the existing queue instance") } -func TestShard_GarbageCollect(t *testing.T) { +func TestShard_DeleteFlow(t *testing.T) { t.Parallel() h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "flow1", Priority: 10} - h.synchronizeFlowWithMocks(flowKey, &frameworkmocks.MockSafeQueue{}) - require.Contains(t, h.shard.priorityBands[10].queues, flowKey.ID, "Test setup: queue must exist before GC") + _, err := h.shard.ManagedQueue(h.highPriorityKey1) + require.NoError(t, err, "Test setup: flow instance must exist before deletion") - h.shard.garbageCollectLocked(flowKey) - assert.NotContains(t, h.shard.priorityBands[10].queues, flowKey.ID, - "Queue should have been removed from the priority band") -} + h.shard.deleteFlow(h.highPriorityKey1) -// --- Draining Lifecycle and Concurrency Tests --- + _, err = h.shard.ManagedQueue(h.highPriorityKey1) + require.Error(t, err, "Flow instance should not be accessible after deletion") + assert.ErrorIs(t, err, contracts.ErrFlowInstanceNotFound, + "Accessing a deleted flow must return ErrFlowInstanceNotFound") +} -func TestShard_Lifecycle_Draining(t *testing.T) { +func TestShard_MarkAsDraining(t *testing.T) { t.Parallel() + h := newShardTestHarness(t) + assert.True(t, h.shard.IsActive(), "Shard should be active initially") - t.Run("ShouldTransitionToDraining_WhenMarkedWhileNonEmpty", func(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "flow1", Priority: 10} - h.synchronizeFlowWithMocks(flowKey, &frameworkmocks.MockSafeQueue{}) - h.addItem(flowKey, 100) - - h.shard.markAsDraining() - assert.False(t, h.shard.IsActive(), "Shard should no longer be active") - assert.Equal(t, componentStatusDraining, componentStatus(h.shard.status.Load()), "Shard status should be Draining") - }) - - t.Run("ShouldTransitionToDrainedAndSignal_WhenMarkedWhileEmpty", func(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) - h.shard.markAsDraining() - assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), "Shard status should be Drained") - assert.Equal(t, []shardStateSignal{shardStateSignalBecameDrained}, h.shardSignaler.getSignals(), - "Should have sent BecameDrained signal") - }) + h.shard.markAsDraining() + assert.False(t, h.shard.IsActive(), "Shard must report IsActive as false after being marked for draining") - t.Run("ShouldTransitionToDrainedAndSignal_WhenLastItemIsRemoved", func(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "flow1", Priority: 10} - h.synchronizeFlowWithRealQueue(flowKey) - item := 
h.addItem(flowKey, 100) - - h.shard.markAsDraining() - require.Equal(t, componentStatusDraining, componentStatus(h.shard.status.Load()), - "Shard should be Draining while it contains items") - - h.removeItem(flowKey, item) - assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), - "Shard should become Drained after last item is removed") - assert.Len(t, h.shardSignaler.getSignals(), 1, "A signal should have been sent") - }) + h.shard.markAsDraining() + assert.False(t, h.shard.IsActive(), "Marking as draining should be idempotent") } -// TestShard_Concurrency_DrainingRace targets the race between `markAsDraining()` and the shard becoming empty via -// concurrent item removals. It proves the atomic CAS correctly arbitrates the race and signals exactly once. -func TestShard_Concurrency_DrainingRace(t *testing.T) { +// --- Concurrency Test --- + +// TestShard_Concurrency_MixedWorkload is a general stability test that simulates a realistic workload by having +// concurrent readers (e.g., dispatchers) and writers operating on the same shard. +// It provides high confidence that the fine-grained locking strategy is free of deadlocks and data races under +// sustained, mixed contention. +func TestShard_Concurrency_MixedWorkload(t *testing.T) { t.Parallel() + const ( + numReaders = 5 + numWriters = 2 + opsPerWriter = 100 + ) + h := newShardTestHarness(t) - flowKey := types.FlowKey{ID: "flow1", Priority: 10} - h.synchronizeFlowWithRealQueue(flowKey) - item := h.addItem(flowKey, 1) - - var wg sync.WaitGroup - wg.Add(2) - - // Goroutine 1: Attempts to mark the shard as draining. - go func() { - defer wg.Done() - h.shard.markAsDraining() - }() - - // Goroutine 2: Concurrently removes the single item. - go func() { - defer wg.Done() - h.removeItem(flowKey, item) - }() - - wg.Wait() - - // Verification: No matter which operation "won" the race, the final state must be Drained, and the signal must have - // been sent exactly once. - assert.Equal(t, componentStatusDrained, componentStatus(h.shard.status.Load()), "Final state must be Drained") - assert.Len(t, h.shardSignaler.getSignals(), 1, "BecameDrained signal must be sent exactly once") -} + stopCh := make(chan struct{}) + var readersWg, writersWg sync.WaitGroup + + readersWg.Add(numReaders) + for range numReaders { + go func() { + defer readersWg.Done() + for { + select { + case <-stopCh: + return + default: + for _, priority := range h.shard.AllOrderedPriorityLevels() { + accessor, err := h.shard.PriorityBandAccessor(priority) + if err == nil { + accessor.IterateQueues(func(q framework.FlowQueueAccessor) bool { return true }) + } + } + } + } + }() + } -// --- Invariant Panic Tests --- + writersWg.Add(numWriters) + for range numWriters { + go func() { + defer writersWg.Done() + for j := range opsPerWriter { + // Alternate writing to different flows and priorities to increase contention. + if j%2 == 0 { + item := h.addItem(h.highPriorityKey1, 10) + h.removeItem(h.highPriorityKey1, item) + } else { + item := h.addItem(h.lowPriorityKey, 5) + h.removeItem(h.lowPriorityKey, item) + } + } + }() + } -func TestShard_PanicOnCorruption(t *testing.T) { - t.Parallel() + // Wait for all writers to complete first. + writersWg.Wait() - t.Run("ShouldPanic_WhenStatsPropagatedForUnknownPriority", func(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) - invalidPriority := uint(99) - assert.Panics(t, func() { - // This simulates a corrupted callback from a rogue `managedQueue`. 
- h.shard.propagateStatsDelta(invalidPriority, 1, 1) - }, "propagateStatsDelta must panic for an unknown priority") - }) + // Now stop the readers and wait for them to exit. + close(stopCh) + readersWg.Wait() - t.Run("ShouldPanic_WhenUpdatingConfigWithMissingPriority", func(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) + // The primary assertion is that this test completes without the race detector firing; however, we can make some final + // assertions on state consistency. + finalStats := h.shard.Stats() + assert.Zero(t, finalStats.TotalLen, "After all paired add/remove operations, the total length should be zero") + assert.Zero(t, finalStats.TotalByteSize, "After all paired add/remove operations, the total byte size should be zero") +} - // Create a new config that is missing one of the shard's existing bands (priority 20). - newGlobalConfig, _ := NewConfig(Config{ - PriorityBands: []PriorityBandConfig{{Priority: 10, PriorityName: "High-Updated"}}, - }) - newShardConfig := newGlobalConfig.partition(0, 1) +// --- Invariant Tests --- - assert.Panics(t, func() { - h.shard.updateConfig(newShardConfig) - }, "updateConfig must panic when an existing priority is missing from the new config") - }) +func TestShard_InvariantPanics(t *testing.T) { + t.Parallel() - t.Run("ShouldPanic_WhenSynchronizingFlowWithMissingPriority", func(t *testing.T) { - t.Parallel() - h := newShardTestHarness(t) - assert.Panics(t, func() { - h.shard.synchronizeFlow( - types.FlowSpecification{Key: types.FlowKey{ID: "flow", Priority: 99}}, - &frameworkmocks.MockIntraFlowDispatchPolicy{}, - &frameworkmocks.MockSafeQueue{}) - }, "synchronizeFlow must panic for an unknown priority") - }) + testCases := []struct { + name string + action func(h *shardTestHarness) + }{ + { + name: "OnStatsPropagatedForUnknownPriority", + action: func(h *shardTestHarness) { h.shard.propagateStatsDelta(nonExistentPriority, 1, 1) }, + }, + { + name: "OnSynchronizingFlowWithMissingPriority", + action: func(h *shardTestHarness) { h.synchronizeFlow(types.FlowKey{ID: "flow", Priority: nonExistentPriority}) }, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + h := newShardTestHarness(t) + assert.Panics(t, func() { tc.action(h) }, + "The action must trigger a panic when a critical system invariant is violated") + }) + } }
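For readers who want the shape of the new shard concurrency model without wading through the hunks above: `registryShard` now pairs a `sync.RWMutex` that guards the topology maps with atomics for the draining flag and the aggregate statistics. The sketch below is a deliberately simplified, hypothetical reduction of that pattern. `miniShard`, `registerFlow`, and `add` are invented names, not part of this package, and the real implementation additionally delegates per-item synchronization to `managedQueue` and propagates stat deltas to the parent registry; this is only an illustration of the locking trade-off.

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
)

// miniShard is a hypothetical, stripped-down analogue of the registryShard locking model:
// an RWMutex guards the topology (the flows map), while atomics carry the high-frequency
// state (the draining flag and the aggregate length counter).
type miniShard struct {
	mu    sync.RWMutex
	flows map[string]*atomic.Int64 // flow ID -> per-flow queue length

	isDraining atomic.Bool
	totalLen   atomic.Int64
}

func newMiniShard() *miniShard {
	return &miniShard{flows: make(map[string]*atomic.Int64)}
}

// IsActive is a lock-free status read, analogous to the hot-path check in the diff.
func (s *miniShard) IsActive() bool { return !s.isDraining.Load() }

// markAsDraining is a lock-free status write; draining is a one-way transition.
func (s *miniShard) markAsDraining() { s.isDraining.Store(true) }

// registerFlow is an idempotent, write-locked topology change (cf. synchronizeFlow).
func (s *miniShard) registerFlow(id string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	if _, ok := s.flows[id]; !ok {
		s.flows[id] = &atomic.Int64{}
	}
}

// deleteFlow is a write-locked topology change (cf. the shard's deleteFlow).
func (s *miniShard) deleteFlow(id string) {
	s.mu.Lock()
	defer s.mu.Unlock()
	delete(s.flows, id)
}

// add is the hot path: a read lock for the map lookup only, then lock-free counter
// updates (per-flow and shard-wide), mirroring the stats-delta propagation.
func (s *miniShard) add(id string) bool {
	s.mu.RLock()
	defer s.mu.RUnlock()
	perFlow, ok := s.flows[id]
	if !ok {
		return false
	}
	perFlow.Add(1)
	s.totalLen.Add(1)
	return true
}

func main() {
	s := newMiniShard()
	s.registerFlow("flow-a")

	var wg sync.WaitGroup
	for range 4 { // concurrent data-path writers (range-over-int requires Go 1.22+)
		wg.Add(1)
		go func() {
			defer wg.Done()
			for range 1000 {
				s.add("flow-a")
			}
		}()
	}
	wg.Wait()

	s.markAsDraining()
	fmt.Println("active:", s.IsActive(), "total:", s.totalLen.Load()) // active: false total: 4000
}

The property the sketch preserves is the one the TODO in `shard.go` is negotiating: the data path (`add`) only ever takes the read lock and updates counters atomically, so it contends only with the infrequent, write-locked topology changes (`registerFlow`, `deleteFlow`). The proposed per-band locks would refine that same trade-off further.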